Source code for neuroconv.datainterfaces.ecephys.baserecordingextractorinterface

from typing import Literal

import numpy as np
from pynwb import NWBFile
from pynwb.device import Device
from pynwb.ecephys import ElectricalSeries, ElectrodeGroup

from ...baseextractorinterface import BaseExtractorInterface
from ...utils import (
    DeepDict,
    get_base_schema,
    get_schema_from_hdmf_class,
)


class BaseRecordingExtractorInterface(BaseExtractorInterface):
    """Parent class for all RecordingExtractorInterfaces."""

    keywords = ("extracellular electrophysiology", "voltage", "recording")

    def _initialize_extractor(self, interface_kwargs: dict):
        """
        Initialize and return the extractor instance for recording interfaces.

        Extends the base implementation to also remove the 'es_key' parameter
        which is specific to the recording interface, not the extractor.
        Also adds 'all_annotations=True' to ensure all metadata is loaded.

        Parameters
        ----------
        interface_kwargs : dict
            The source data parameters passed to the interface constructor.

        Returns
        -------
        extractor_instance
            An initialized recording extractor instance.
        """
        # Strip interface-only arguments before forwarding to the extractor.
        kwargs = dict(interface_kwargs)
        kwargs.pop("verbose", None)
        kwargs.pop("es_key", None)
        kwargs["all_annotations"] = True
        self.extractor_kwargs = kwargs

        extractor_class = self.get_extractor_class()
        return extractor_class(**kwargs)

    def __init__(self, verbose: bool = False, es_key: str = "ElectricalSeries", **source_data):
        """
        Parameters
        ----------
        verbose : bool, default: False
            If True, will print out additional information.
        es_key : str, default: "ElectricalSeries"
            The key of this ElectricalSeries in the metadata dictionary.
        source_data : dict
            The key-value pairs of extractor-specific arguments.
        """
        super().__init__(**source_data)
        self.recording_extractor = self._extractor_instance

        extractor = self.recording_extractor
        existing_properties = extractor.get_property_keys()
        # TODO remove this and go and change all the uses of channel_name once spikeinterface > 0.101.0 is released
        if "channel_names" in existing_properties and "channel_name" not in existing_properties:
            extractor.set_property("channel_name", extractor.get_property("channel_names"))
            extractor.delete_property("channel_names")

        self.verbose = verbose
        self.es_key = es_key
        self._number_of_segments = extractor.get_num_segments()
[docs] def get_metadata_schema(self) -> dict: """ Compile metadata schema for the RecordingExtractor. Returns ------- dict The metadata schema dictionary containing definitions for Device, ElectrodeGroup, Electrodes, and optionally ElectricalSeries. """ metadata_schema = super().get_metadata_schema() metadata_schema["properties"]["Ecephys"] = get_base_schema(tag="Ecephys") metadata_schema["properties"]["Ecephys"]["required"] = ["Device", "ElectrodeGroup"] metadata_schema["properties"]["Ecephys"]["properties"] = dict( Device=dict(type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/Device"}), ElectrodeGroup=dict( type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/ElectrodeGroup"} ), Electrodes=dict( type="array", minItems=0, renderForm=False, items={"$ref": "#/properties/Ecephys/definitions/Electrodes"}, ), ) # Schema definition for arrays metadata_schema["properties"]["Ecephys"]["definitions"] = dict( Device=get_schema_from_hdmf_class(Device), ElectrodeGroup=get_schema_from_hdmf_class(ElectrodeGroup), Electrodes=dict( type="object", additionalProperties=False, required=["name"], properties=dict( name=dict(type="string", description="name of this electrodes column"), description=dict(type="string", description="description of this electrodes column"), ), ), ) if self.es_key is not None: metadata_schema["properties"]["Ecephys"]["properties"].update( {self.es_key: get_schema_from_hdmf_class(ElectricalSeries)} ) return metadata_schema
[docs] def get_metadata(self) -> DeepDict: metadata = super().get_metadata() from ...tools.spikeinterface.spikeinterface import _get_group_name channel_groups_array = _get_group_name(recording=self.recording_extractor) unique_channel_groups = set(channel_groups_array) if channel_groups_array is not None else ["ElectrodeGroup"] electrode_metadata = [ dict(name=str(group_id), description="no description", location="unknown", device="DeviceEcephys") for group_id in unique_channel_groups ] metadata["Ecephys"] = dict( Device=[dict(name="DeviceEcephys", description="no description")], ElectrodeGroup=electrode_metadata, ) if self.es_key is not None: metadata["Ecephys"][self.es_key] = dict( name=self.es_key, description=f"Acquisition traces for the {self.es_key}." ) return metadata
@property def channel_ids(self): "Gets the channel ids of the data." return self.recording_extractor.get_channel_ids()
[docs] def get_original_timestamps(self) -> np.ndarray | list[np.ndarray]: """ Retrieve the original unaltered timestamps for the data in this interface. This function should retrieve the data on-demand by re-initializing the IO. Returns ------- timestamps: numpy.ndarray or list of numpy.ndarray The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned. """ new_recording = self._initialize_extractor(self.source_data) if self._number_of_segments == 1: return new_recording.get_times() else: return [ new_recording.get_times(segment_index=segment_index) for segment_index in range(self._number_of_segments) ]
[docs] def get_timestamps(self) -> np.ndarray | list[np.ndarray]: """ Retrieve the timestamps for the data in this interface. Returns ------- timestamps: numpy.ndarray or list of numpy.ndarray The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned. """ if self._number_of_segments == 1: return self.recording_extractor.get_times() else: return [ self.recording_extractor.get_times(segment_index=segment_index) for segment_index in range(self._number_of_segments) ]
    def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
        """
        Replace all timestamps for this interface with those aligned to the common session start time.

        Must be in units seconds relative to the common 'session_start_time'.

        Parameters
        ----------
        aligned_timestamps : numpy.ndarray
            The synchronized timestamps for the data in this interface.
        """
        # Multi-segment recordings must be aligned per segment instead.
        assert (
            self._number_of_segments == 1
        ), "This recording has multiple segments; please use 'align_segment_timestamps' instead."

        # with_warning=False suppresses spikeinterface's warning about overwriting times.
        self.recording_extractor.set_times(times=aligned_timestamps, with_warning=False)
[docs] def set_aligned_segment_timestamps(self, aligned_segment_timestamps: list[np.ndarray]): """ Replace all timestamps for all segments in this interface with those aligned to the common session start time. Must be in units seconds relative to the common 'session_start_time'. Parameters ---------- aligned_segment_timestamps : list of numpy.ndarray The synchronized timestamps for segment of data in this interface. """ assert isinstance( aligned_segment_timestamps, list ), "Recording has multiple segment! Please pass a list of timestamps to align each segment." assert ( len(aligned_segment_timestamps) == self._number_of_segments ), f"The number of timestamp vectors ({len(aligned_segment_timestamps)}) does not match the number of segments ({self._number_of_segments})!" for segment_index in range(self._number_of_segments): self.recording_extractor.set_times( times=aligned_segment_timestamps[segment_index], segment_index=segment_index, with_warning=False, )
[docs] def set_aligned_starting_time(self, aligned_starting_time: float): if self._number_of_segments == 1: self.set_aligned_timestamps(aligned_timestamps=self.get_timestamps() + aligned_starting_time) else: self.set_aligned_segment_timestamps( aligned_segment_timestamps=[ segment_timestamps + aligned_starting_time for segment_timestamps in self.get_timestamps() ] )
[docs] def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float]): """ Align the starting time for each segment in this interface relative to the common session start time. Must be in units seconds relative to the common 'session_start_time'. Parameters ---------- aligned_segment_starting_times : list of floats The starting time for each segment of data in this interface. """ assert len(aligned_segment_starting_times) == self._number_of_segments, ( f"The length of the starting_times ({len(aligned_segment_starting_times)}) does not match the " "number of segments ({self._number_of_segments})!" ) if self._number_of_segments == 1: self.set_aligned_starting_time(aligned_starting_time=aligned_segment_starting_times[0]) else: aligned_segment_timestamps = [ segment_timestamps + aligned_segment_starting_time for segment_timestamps, aligned_segment_starting_time in zip( self.get_timestamps(), aligned_segment_starting_times ) ] self.set_aligned_segment_timestamps(aligned_segment_timestamps=aligned_segment_timestamps)
[docs] def set_probe(self, probe: "Probe | ProbeGroup", group_mode: Literal["by_shank", "by_probe"]): """ Set the probe information via a ProbeInterface object. Parameters ---------- probe : probeinterface.Probe or probeinterface.ProbeGroup The probe object(s). Can be a single Probe or a ProbeGroup containing multiple probes. group_mode : {'by_shank', 'by_probe'} How to group the channels for electrode group assignment in the NWB file: - 'by_probe': Each probe becomes a separate electrode group. For a ProbeGroup with multiple probes, each probe gets its own group (group 0, 1, 2, etc.). For a single probe, all channels are assigned to group 0. - 'by_shank': Each unique combination of probe and shank becomes a separate electrode group. Requires that shank_ids are defined for all probes. Groups are assigned sequentially for each unique (probe_index, shank_id) pair. The resulting groups determine how electrode groups and electrodes are organized in the NWB file, with each group corresponding to one ElectrodeGroup. """ # Set the probe to the recording extractor self.recording_extractor._set_probes( probe, in_place=True, group_mode=group_mode, ) # Spike interface sets the "group" property # But neuroconv allows "group_name" property to override spike interface "group" value # So we re-set this here to avoid a conflict self.recording_extractor.set_property("group_name", self.recording_extractor.get_property("group").astype(str))
[docs] def has_probe(self) -> bool: """ Check if the recording extractor has probe information. Returns ------- bool True if the recording extractor has probe information, False otherwise. """ return self.recording_extractor.has_probe()
    def align_by_interpolation(
        self,
        unaligned_timestamps: np.ndarray,
        aligned_timestamps: np.ndarray,
    ):
        """
        Align this interface's timestamps by linearly interpolating them onto an aligned time basis.

        Parameters
        ----------
        unaligned_timestamps : numpy.ndarray
            Timestamps expressed in this interface's original (unaligned) time basis.
        aligned_timestamps : numpy.ndarray
            The corresponding timestamps aligned to the common 'session_start_time'.

        Raises
        ------
        NotImplementedError
            If the recording has more than one segment.
        """
        if self._number_of_segments == 1:
            # np.interp maps each current timestamp from the unaligned basis onto the aligned one.
            self.set_aligned_timestamps(
                aligned_timestamps=np.interp(x=self.get_timestamps(), xp=unaligned_timestamps, fp=aligned_timestamps)
            )
        else:
            raise NotImplementedError("Multi-segment support for aligning by interpolation has not been added yet.")
[docs] def add_to_nwbfile( self, nwbfile: NWBFile, metadata: dict | None = None, *, stub_test: bool = False, write_as: Literal["raw", "lfp", "processed"] = "raw", write_electrical_series: bool = True, iterator_type: str | None = "v2", iterator_options: dict | None = None, always_write_timestamps: bool = False, ): """ Primary function for converting raw (unprocessed) RecordingExtractor data to the NWB standard. Parameters ---------- nwbfile : NWBFile NWBFile to which the recording information is to be added metadata : dict, optional metadata info for constructing the NWB file. Should be of the format:: metadata['Ecephys']['ElectricalSeries'] = dict(name=my_name, description=my_description) stub_test : bool, default: False If True, will truncate the data to run the conversion faster and take up less memory. write_as : {'raw', 'processed', 'lfp'}, default='raw' Specifies how to save the trace data in the NWB file. Options are: - 'raw': Save the data in the acquisition group. - 'processed': Save the data as FilteredEphys in a processing module. - 'lfp': Save the data as LFP in a processing module. write_electrical_series : bool, default: True Electrical series are written in acquisition. If False, only device, electrode_groups, and electrodes are written to NWB. iterator_type : {'v2', None}, default: 'v2' The type of iterator for chunked data writing. 'v2': Uses iterative write with control over chunking and progress bars. None: Loads all data into memory before writing (not recommended for large datasets). iterator_options : dict, optional Options for controlling iterative write when iterator_type='v2'. See the `pynwb tutorial on iterative write <https://pynwb.readthedocs.io/en/stable/tutorials/advanced_io/plot_iterative_write.html#sphx-glr-tutorials-advanced-io-plot-iterative-write-py>`_ for more information on chunked data writing. Available options: * buffer_gb : float, default: 1.0 RAM to use for buffering data chunks in GB. 
Recommended to be as much free RAM as available. * buffer_shape : tuple, optional Manual specification of buffer shape. Must be a multiple of chunk_shape along each axis. Cannot be set if buffer_gb is specified. * display_progress : bool, default: False Enable tqdm progress bar during data write. * progress_bar_options : dict, optional Additional options passed to tqdm progress bar. See https://github.com/tqdm/tqdm#parameters for all tqdm options. Note: To configure chunk size and compression, use the backend configuration system via ``get_default_backend_configuration()`` and ``configure_backend()`` after calling this method. See the backend configuration documentation for details. always_write_timestamps : bool, default: False Set to True to always write timestamps. By default (False), the function checks if the timestamps are uniformly sampled, and if so, stores the data using a regular sampling rate instead of explicit timestamps. If set to True, timestamps will be written explicitly, regardless of whether the sampling rate is uniform. """ from ...tools.spikeinterface import ( _stub_recording, add_recording_metadata_to_nwbfile, add_recording_to_nwbfile, ) recording = self.recording_extractor if stub_test: recording = _stub_recording(recording=recording) metadata = metadata or self.get_metadata() if write_electrical_series: add_recording_to_nwbfile( recording=recording, nwbfile=nwbfile, metadata=metadata, write_as=write_as, es_key=self.es_key, iterator_type=iterator_type, iterator_options=iterator_options, always_write_timestamps=always_write_timestamps, ) else: add_recording_metadata_to_nwbfile( recording=recording, nwbfile=nwbfile, metadata=metadata, )