Source code for neuroconv.datainterfaces.ecephys.baserecordingextractorinterface
from typing import Literal
import numpy as np
from pynwb import NWBFile
from pynwb.device import Device
from pynwb.ecephys import ElectricalSeries, ElectrodeGroup
from ...baseextractorinterface import BaseExtractorInterface
from ...utils import (
DeepDict,
get_base_schema,
get_schema_from_hdmf_class,
)
class BaseRecordingExtractorInterface(BaseExtractorInterface):
"""Parent class for all RecordingExtractorInterfaces."""
keywords = ("extracellular electrophysiology", "voltage", "recording")
def _initialize_extractor(self, interface_kwargs: dict):
"""
Initialize and return the extractor instance for recording interfaces.
Extends the base implementation to also remove the 'es_key' parameter
which is specific to the recording interface, not the extractor.
Also adds 'all_annotations=True' to ensure all metadata is loaded.
Parameters
----------
interface_kwargs : dict
The source data parameters passed to the interface constructor.
Returns
-------
extractor_instance
An initialized recording extractor instance.
"""
self.extractor_kwargs = interface_kwargs.copy()
self.extractor_kwargs.pop("verbose", None)
self.extractor_kwargs.pop("es_key", None)
self.extractor_kwargs["all_annotations"] = True
extractor_class = self.get_extractor_class()
extractor_instance = extractor_class(**self.extractor_kwargs)
return extractor_instance
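    # Illustration of the kwargs handling above (names are hypothetical, not from this
    # module): a hypothetical interface constructed as
    #     MyRecordingInterface(file_path="data.bin", verbose=True, es_key="ElectricalSeriesAP")
    # would instantiate its extractor roughly as
    #     MyRecordingExtractor(file_path="data.bin", all_annotations=True)
    # i.e., 'verbose' and 'es_key' are consumed by the interface and never reach the extractor.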
def __init__(self, verbose: bool = False, es_key: str = "ElectricalSeries", **source_data):
"""
Parameters
----------
verbose : bool, default: False
If True, will print out additional information.
es_key : str, default: "ElectricalSeries"
The key of this ElectricalSeries in the metadata dictionary.
source_data : dict
The key-value pairs of extractor-specific arguments.
"""
super().__init__(**source_data)
self.recording_extractor = self._extractor_instance
property_names = self.recording_extractor.get_property_keys()
# TODO remove this and go and change all the uses of channel_name once spikeinterface > 0.101.0 is released
if "channel_name" not in property_names and "channel_names" in property_names:
channel_names = self.recording_extractor.get_property("channel_names")
self.recording_extractor.set_property("channel_name", channel_names)
self.recording_extractor.delete_property("channel_names")
self.verbose = verbose
self.es_key = es_key
self._number_of_segments = self.recording_extractor.get_num_segments()
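    # Example construction through a concrete subclass (SpikeGLXRecordingInterface is a
    # real neuroconv interface; the file path is a placeholder):
    #     from neuroconv.datainterfaces import SpikeGLXRecordingInterface
    #     interface = SpikeGLXRecordingInterface(
    #         file_path="path/to/run_g0_t0.imec0.ap.bin", es_key="ElectricalSeriesAP"
    #     )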
def get_metadata_schema(self) -> dict:
"""
Compile metadata schema for the RecordingExtractor.
Returns
-------
dict
The metadata schema dictionary containing definitions for Device, ElectrodeGroup,
Electrodes, and optionally ElectricalSeries.
"""
metadata_schema = super().get_metadata_schema()
metadata_schema["properties"]["Ecephys"] = get_base_schema(tag="Ecephys")
metadata_schema["properties"]["Ecephys"]["required"] = ["Device", "ElectrodeGroup"]
metadata_schema["properties"]["Ecephys"]["properties"] = dict(
Device=dict(type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/Device"}),
ElectrodeGroup=dict(
type="array", minItems=1, items={"$ref": "#/properties/Ecephys/definitions/ElectrodeGroup"}
),
Electrodes=dict(
type="array",
minItems=0,
renderForm=False,
items={"$ref": "#/properties/Ecephys/definitions/Electrodes"},
),
)
# Schema definition for arrays
metadata_schema["properties"]["Ecephys"]["definitions"] = dict(
Device=get_schema_from_hdmf_class(Device),
ElectrodeGroup=get_schema_from_hdmf_class(ElectrodeGroup),
Electrodes=dict(
type="object",
additionalProperties=False,
required=["name"],
properties=dict(
name=dict(type="string", description="name of this electrodes column"),
description=dict(type="string", description="description of this electrodes column"),
),
),
)
if self.es_key is not None:
metadata_schema["properties"]["Ecephys"]["properties"].update(
{self.es_key: get_schema_from_hdmf_class(ElectricalSeries)}
)
return metadata_schema
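    # The returned schema is a standard JSON-schema dict, so edited metadata can be
    # checked against it; a minimal sketch assuming the third-party 'jsonschema'
    # package is available:
    #     import jsonschema
    #     schema = interface.get_metadata_schema()
    #     jsonschema.validate(instance=dict(interface.get_metadata()), schema=schema)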
    def get_metadata(self) -> DeepDict:
        """
        Fill in default metadata for the Device, ElectrodeGroup, and (when 'es_key' is set) ElectricalSeries.
        Returns
        -------
        DeepDict
            The default metadata dictionary for this interface.
        """
metadata = super().get_metadata()
from ...tools.spikeinterface.spikeinterface import _get_group_name
channel_groups_array = _get_group_name(recording=self.recording_extractor)
unique_channel_groups = set(channel_groups_array) if channel_groups_array is not None else ["ElectrodeGroup"]
electrode_metadata = [
dict(name=str(group_id), description="no description", location="unknown", device="DeviceEcephys")
for group_id in unique_channel_groups
]
metadata["Ecephys"] = dict(
Device=[dict(name="DeviceEcephys", description="no description")],
ElectrodeGroup=electrode_metadata,
)
if self.es_key is not None:
metadata["Ecephys"][self.es_key] = dict(
name=self.es_key, description=f"Acquisition traces for the {self.es_key}."
)
return metadata
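    # Typical workflow: fetch these defaults, then overwrite them before conversion
    # (the replacement values below are purely illustrative):
    #     metadata = interface.get_metadata()
    #     metadata["Ecephys"]["Device"][0]["description"] = "384-channel silicon probe"
    #     metadata["Ecephys"]["ElectrodeGroup"][0]["location"] = "CA1"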
@property
def channel_ids(self):
"Gets the channel ids of the data."
return self.recording_extractor.get_channel_ids()
def get_original_timestamps(self) -> np.ndarray | list[np.ndarray]:
"""
Retrieve the original unaltered timestamps for the data in this interface.
This function should retrieve the data on-demand by re-initializing the IO.
Returns
-------
        timestamps : numpy.ndarray or list of numpy.ndarray
            The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned.
"""
new_recording = self._initialize_extractor(self.source_data)
if self._number_of_segments == 1:
return new_recording.get_times()
else:
return [
new_recording.get_times(segment_index=segment_index)
for segment_index in range(self._number_of_segments)
]
def get_timestamps(self) -> np.ndarray | list[np.ndarray]:
"""
Retrieve the timestamps for the data in this interface.
Returns
-------
        timestamps : numpy.ndarray or list of numpy.ndarray
            The timestamps for the data stream; if the recording has multiple segments, then a list of timestamps is returned.
"""
if self._number_of_segments == 1:
return self.recording_extractor.get_times()
else:
return [
self.recording_extractor.get_times(segment_index=segment_index)
for segment_index in range(self._number_of_segments)
]
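    # The distinction between the two getters: get_timestamps() reflects any alignment
    # applied in memory, while get_original_timestamps() re-reads from the source.
    # Sketch (single-segment case; the offset is illustrative):
    #     interface.set_aligned_starting_time(aligned_starting_time=1.23)
    #     interface.get_timestamps()[0]           # shifted by 1.23 seconds
    #     interface.get_original_timestamps()[0]  # unshifted, re-read from the source data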
    def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
        """
        Replace all timestamps for this interface with those aligned to the common session start time.
        Must be in units of seconds relative to the common 'session_start_time'.
        Parameters
        ----------
        aligned_timestamps : numpy.ndarray
            The synchronized timestamps for the data in this interface.
        """
        assert (
            self._number_of_segments == 1
        ), "This recording has multiple segments; please use 'set_aligned_segment_timestamps' instead."
        self.recording_extractor.set_times(times=aligned_timestamps, with_warning=False)
def set_aligned_segment_timestamps(self, aligned_segment_timestamps: list[np.ndarray]):
"""
Replace all timestamps for all segments in this interface with those aligned to the common session start time.
        Must be in units of seconds relative to the common 'session_start_time'.
Parameters
----------
aligned_segment_timestamps : list of numpy.ndarray
            The synchronized timestamps for each segment of data in this interface.
"""
        assert isinstance(
            aligned_segment_timestamps, list
        ), "Recording has multiple segments! Please pass a list of timestamps, one per segment."
assert (
len(aligned_segment_timestamps) == self._number_of_segments
), f"The number of timestamp vectors ({len(aligned_segment_timestamps)}) does not match the number of segments ({self._number_of_segments})!"
for segment_index in range(self._number_of_segments):
self.recording_extractor.set_times(
times=aligned_segment_timestamps[segment_index],
segment_index=segment_index,
with_warning=False,
)
    def set_aligned_starting_time(self, aligned_starting_time: float):
        """
        Align all timestamps in this interface by shifting them by the given starting time.
        Must be in units of seconds relative to the common 'session_start_time'.
        Parameters
        ----------
        aligned_starting_time : float
            The common starting time to add to all timestamps in this interface.
        """
if self._number_of_segments == 1:
self.set_aligned_timestamps(aligned_timestamps=self.get_timestamps() + aligned_starting_time)
else:
self.set_aligned_segment_timestamps(
aligned_segment_timestamps=[
segment_timestamps + aligned_starting_time for segment_timestamps in self.get_timestamps()
]
)
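    # Example: a recording whose acquisition began 5.0 seconds after the common
    # session start (the offset is illustrative):
    #     interface.set_aligned_starting_time(aligned_starting_time=5.0)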
def set_aligned_segment_starting_times(self, aligned_segment_starting_times: list[float]):
"""
Align the starting time for each segment in this interface relative to the common session start time.
        Must be in units of seconds relative to the common 'session_start_time'.
Parameters
----------
aligned_segment_starting_times : list of floats
The starting time for each segment of data in this interface.
"""
        assert len(aligned_segment_starting_times) == self._number_of_segments, (
            f"The length of the starting_times ({len(aligned_segment_starting_times)}) does not match the "
            f"number of segments ({self._number_of_segments})!"
        )
if self._number_of_segments == 1:
self.set_aligned_starting_time(aligned_starting_time=aligned_segment_starting_times[0])
else:
aligned_segment_timestamps = [
segment_timestamps + aligned_segment_starting_time
for segment_timestamps, aligned_segment_starting_time in zip(
self.get_timestamps(), aligned_segment_starting_times
)
]
self.set_aligned_segment_timestamps(aligned_segment_timestamps=aligned_segment_timestamps)
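    # Example for a two-segment recording (the offsets are illustrative): each
    # segment's timestamps are shifted by its own starting time.
    #     interface.set_aligned_segment_starting_times(
    #         aligned_segment_starting_times=[0.0, 120.5]
    #     )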
def set_probe(self, probe: "Probe | ProbeGroup", group_mode: Literal["by_shank", "by_probe"]):
"""
Set the probe information via a ProbeInterface object.
Parameters
----------
probe : probeinterface.Probe or probeinterface.ProbeGroup
The probe object(s). Can be a single Probe or a ProbeGroup containing multiple probes.
group_mode : {'by_shank', 'by_probe'}
How to group the channels for electrode group assignment in the NWB file:
- 'by_probe': Each probe becomes a separate electrode group. For a ProbeGroup with
multiple probes, each probe gets its own group (group 0, 1, 2, etc.). For a single
probe, all channels are assigned to group 0.
- 'by_shank': Each unique combination of probe and shank becomes a separate electrode
group. Requires that shank_ids are defined for all probes. Groups are assigned
sequentially for each unique (probe_index, shank_id) pair.
The resulting groups determine how electrode groups and electrodes are organized
in the NWB file, with each group corresponding to one ElectrodeGroup.
"""
# Set the probe to the recording extractor
self.recording_extractor._set_probes(
probe,
in_place=True,
group_mode=group_mode,
)
        # SpikeInterface sets the "group" property,
        # but neuroconv allows the "group_name" property to override SpikeInterface's "group" value,
        # so we re-set "group_name" here to avoid a conflict.
self.recording_extractor.set_property("group_name", self.recording_extractor.get_property("group").astype(str))
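    # Sketch using the probeinterface package (a real spikeinterface dependency; the
    # probe model and wiring are illustrative):
    #     from probeinterface import get_probe
    #     probe = get_probe(manufacturer="neuronexus", probe_name="A1x32-Poly3-10mm-50-177")
    #     probe.set_device_channel_indices(range(probe.get_contact_count()))
    #     interface.set_probe(probe=probe, group_mode="by_probe")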
def has_probe(self) -> bool:
"""
Check if the recording extractor has probe information.
Returns
-------
bool
True if the recording extractor has probe information, False otherwise.
"""
return self.recording_extractor.has_probe()
    def align_by_interpolation(
        self,
        unaligned_timestamps: np.ndarray,
        aligned_timestamps: np.ndarray,
    ):
        """
        Interpolate this interface's timestamps onto an aligned clock using paired reference times.
        Parameters
        ----------
        unaligned_timestamps : numpy.ndarray
            Reference event times as seen by this interface's original clock.
        aligned_timestamps : numpy.ndarray
            The same reference event times relative to the common 'session_start_time'.
        """
if self._number_of_segments == 1:
self.set_aligned_timestamps(
aligned_timestamps=np.interp(x=self.get_timestamps(), xp=unaligned_timestamps, fp=aligned_timestamps)
)
else:
raise NotImplementedError("Multi-segment support for aligning by interpolation has not been added yet.")
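    # Sketch: map this interface's clock onto a primary clock via shared sync pulses
    # (the pulse times are illustrative):
    #     pulses_here = np.array([1.0, 61.0, 121.0])     # pulse times on this recording's clock
    #     pulses_primary = np.array([1.2, 61.2, 121.2])  # the same pulses on the primary clock
    #     interface.align_by_interpolation(
    #         unaligned_timestamps=pulses_here, aligned_timestamps=pulses_primary
    #     )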
def add_to_nwbfile(
self,
nwbfile: NWBFile,
metadata: dict | None = None,
*,
stub_test: bool = False,
write_as: Literal["raw", "lfp", "processed"] = "raw",
write_electrical_series: bool = True,
iterator_type: str | None = "v2",
iterator_options: dict | None = None,
always_write_timestamps: bool = False,
):
"""
Primary function for converting raw (unprocessed) RecordingExtractor data to the NWB standard.
Parameters
----------
nwbfile : NWBFile
            NWBFile to which the recording information is to be added.
        metadata : dict, optional
            Metadata info for constructing the NWB file.
Should be of the format::
metadata['Ecephys']['ElectricalSeries'] = dict(name=my_name, description=my_description)
stub_test : bool, default: False
If True, will truncate the data to run the conversion faster and take up less memory.
        write_as : {'raw', 'processed', 'lfp'}, default: 'raw'
Specifies how to save the trace data in the NWB file. Options are:
- 'raw': Save the data in the acquisition group.
- 'processed': Save the data as FilteredEphys in a processing module.
- 'lfp': Save the data as LFP in a processing module.
        write_electrical_series : bool, default: True
            If True, the ElectricalSeries (traces) is written, to acquisition or a processing module
            depending on 'write_as'. If False, only the device, electrode groups, and electrodes
            are written to the NWBFile.
iterator_type : {'v2', None}, default: 'v2'
The type of iterator for chunked data writing.
'v2': Uses iterative write with control over chunking and progress bars.
None: Loads all data into memory before writing (not recommended for large datasets).
iterator_options : dict, optional
Options for controlling iterative write when iterator_type='v2'.
See the `pynwb tutorial on iterative write
<https://pynwb.readthedocs.io/en/stable/tutorials/advanced_io/plot_iterative_write.html#sphx-glr-tutorials-advanced-io-plot-iterative-write-py>`_
for more information on chunked data writing.
Available options:
* buffer_gb : float, default: 1.0
RAM to use for buffering data chunks in GB. Recommended to be as much free RAM as available.
* buffer_shape : tuple, optional
Manual specification of buffer shape. Must be a multiple of chunk_shape along each axis.
Cannot be set if buffer_gb is specified.
* display_progress : bool, default: False
Enable tqdm progress bar during data write.
* progress_bar_options : dict, optional
Additional options passed to tqdm progress bar.
See https://github.com/tqdm/tqdm#parameters for all tqdm options.
Note: To configure chunk size and compression, use the backend configuration system
via ``get_default_backend_configuration()`` and ``configure_backend()`` after calling
this method. See the backend configuration documentation for details.
always_write_timestamps : bool, default: False
Set to True to always write timestamps.
By default (False), the function checks if the timestamps are uniformly sampled, and if so, stores the data
using a regular sampling rate instead of explicit timestamps. If set to True, timestamps will be written
explicitly, regardless of whether the sampling rate is uniform.
"""
from ...tools.spikeinterface import (
_stub_recording,
add_recording_metadata_to_nwbfile,
add_recording_to_nwbfile,
)
recording = self.recording_extractor
if stub_test:
recording = _stub_recording(recording=recording)
metadata = metadata or self.get_metadata()
if write_electrical_series:
add_recording_to_nwbfile(
recording=recording,
nwbfile=nwbfile,
metadata=metadata,
write_as=write_as,
es_key=self.es_key,
iterator_type=iterator_type,
iterator_options=iterator_options,
always_write_timestamps=always_write_timestamps,
)
else:
add_recording_metadata_to_nwbfile(
recording=recording,
nwbfile=nwbfile,
metadata=metadata,
)
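# End-to-end sketch (the NWBFile here comes from pynwb's mock testing utilities; all
# values are illustrative):
#     from pynwb.testing.mock.file import mock_NWBFile
#     nwbfile = mock_NWBFile()
#     interface.add_to_nwbfile(
#         nwbfile=nwbfile,
#         metadata=interface.get_metadata(),
#         write_as="raw",
#         iterator_type="v2",
#     )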