Source code for neuroconv.datainterfaces.behavior.sleap.sleapdatainterface

import warnings
from pathlib import Path

import numpy as np
from pydantic import FilePath, validate_call
from pynwb.file import NWBFile

from .sleap_utils import extract_timestamps
from ....basetemporalalignmentinterface import BaseTemporalAlignmentInterface


class SLEAPInterface(BaseTemporalAlignmentInterface):
    """Data interface for SLEAP datasets."""

    display_name = "SLEAP"
    keywords = ("pose estimation", "tracking", "video")
    associated_suffixes = (".slp", ".mp4")
    info = "Interface for SLEAP pose estimation datasets."

    @classmethod
    def get_source_schema(cls) -> dict:
        """
        Return the source schema of the parent class with SLEAP-specific descriptions.

        Only the ``description`` entries of the ``file_path`` and ``video_file_path``
        properties are overwritten; everything else comes from ``super()``.
        """
        source_schema = super().get_source_schema()
        source_schema["properties"]["file_path"]["description"] = "Path to the .slp file (the output of sleap)"
        source_schema["properties"]["video_file_path"][
            "description"
        ] = "Path of the video for extracting timestamps (optional)."
        return source_schema

    @validate_call
    def __init__(
        self,
        file_path: FilePath,
        *args,  # TODO: change to * (keyword only) on or after August 2026
        video_file_path: FilePath | None = None,
        verbose: bool = False,
        frames_per_second: float | None = None,
    ):
        """
        Interface for writing sleap .slp files to nwb using the sleap-io library.

        Parameters
        ----------
        file_path : FilePath
            Path to the .slp file (the output of sleap)
        video_file_path : FilePath, optional
            The file path of the video for extracting timestamps.
        verbose : bool, default: False
            Controls verbosity.
        frames_per_second : float, optional
            The frames per second (fps) or sampling rate of the video.

        Raises
        ------
        TypeError
            If more positional arguments are given than the deprecated positional
            signature (``video_file_path``, ``verbose``, ``frames_per_second``) accepts.

        Warns
        -----
        FutureWarning
            If any argument after ``file_path`` is passed positionally.
        """
        # Handle deprecated positional arguments
        if args:
            # Order in which the deprecated positional signature accepted these arguments.
            parameter_names = [
                "video_file_path",
                "verbose",
                "frames_per_second",
            ]
            num_positional_args_before_args = 1  # file_path
            if len(args) > len(parameter_names):
                # The extra "+ 1" accounts for ``self``, mirroring Python's own TypeError wording.
                raise TypeError(
                    f"__init__() takes at most {len(parameter_names) + num_positional_args_before_args + 1} positional arguments but "
                    f"{len(args) + num_positional_args_before_args + 1} were given. "
                    "Note: Positional arguments are deprecated and will be removed on or after August 2026. "
                    "Please use keyword arguments."
                )
            # zip() stops at the shorter sequence, so only the names actually supplied are mapped.
            positional_values = dict(zip(parameter_names, args))
            passed_as_positional = list(positional_values)
            warnings.warn(
                f"Passing arguments positionally to SLEAPInterface.__init__() is deprecated "
                f"and will be removed on or after August 2026. "
                f"The following arguments were passed positionally: {passed_as_positional}. "
                "Please use keyword arguments instead.",
                FutureWarning,
                stacklevel=2,
            )
            # A positionally supplied value takes precedence over the keyword value.
            video_file_path = positional_values.get("video_file_path", video_file_path)
            verbose = positional_values.get("verbose", verbose)
            frames_per_second = positional_values.get("frames_per_second", frames_per_second)

        # This import is to assure that the ndx_pose is in the global namespace when an pynwb.io object is created
        # For more detail, see https://github.com/rly/ndx-pose/issues/36
        import ndx_pose  # noqa: F401

        self.file_path = Path(file_path)
        self.video_file_path = video_file_path
        self.video_sample_rate = frames_per_second
        self.verbose = verbose
        # Aligned timestamps; stays None until set_aligned_timestamps() is called.
        self._timestamps = None
        super().__init__(file_path=file_path)

    def get_original_timestamps(self) -> np.ndarray:
        """
        Extract the timestamps from the video given at initialization.

        Returns
        -------
        np.ndarray
            The result of ``extract_timestamps(self.video_file_path)`` as an array.

        Raises
        ------
        ValueError
            If no ``video_file_path`` was specified when initializing the interface.
        """
        if self.video_file_path is None:
            raise ValueError(
                "Unable to fetch the original timestamps from the video! "
                "Please specify 'video_file_path' when initializing the interface."
            )
        return np.array(extract_timestamps(self.video_file_path))

    def get_timestamps(self) -> np.ndarray:
        """
        Return the aligned timestamps if they were set, otherwise the original ones.

        Falls back to ``get_original_timestamps()``, which raises ``ValueError`` when
        no ``video_file_path`` is available.
        """
        if self._timestamps is not None:
            return self._timestamps
        return self.get_original_timestamps()

    def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
        """Store ``aligned_timestamps`` so that ``get_timestamps()`` returns them."""
        self._timestamps = aligned_timestamps

    def add_to_nwbfile(
        self,
        nwbfile: NWBFile,
        metadata: dict | None = None,
    ):
        """
        Conversion from SLEAP output files to nwb. Derived from sleap-io library.

        Parameters
        ----------
        nwbfile: NWBFile
            nwb file to which the recording information is to be added
        metadata: dict
            metadata info for constructing the nwb file (optional).
            Not read by this method.
        """
        pose_estimation_metadata = dict()

        # Compare against None explicitly: ``_timestamps`` may be a NumPy array, and the
        # truth value of an array with more than one element raises ValueError.
        if self.video_file_path is not None or self._timestamps is not None:
            video_timestamps = self.get_timestamps()
            pose_estimation_metadata.update(video_timestamps=video_timestamps)

        # A sample rate of None (or 0) is omitted from the metadata.
        if self.video_sample_rate:
            pose_estimation_metadata.update(video_sample_rate=self.video_sample_rate)

        # Imported lazily so that sleap-io is only required when data is actually written.
        from sleap_io import load_slp
        from sleap_io.io.nwb_predictions import append_nwb_data

        labels = load_slp(self.file_path)
        append_nwb_data(labels=labels, nwbfile=nwbfile, pose_estimation_metadata=pose_estimation_metadata)