diff --git a/pyshimmer/reader/binary_reader.py b/pyshimmer/reader/binary_reader.py
index 1db54ce..5fa8ae5 100644
--- a/pyshimmer/reader/binary_reader.py
+++ b/pyshimmer/reader/binary_reader.py
@@ -15,6 +15,8 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
+
 import struct
 from typing import BinaryIO
 
@@ -47,20 +49,28 @@
     TRIAXCAL_OFFSET_SCALING,
     TRIAXCAL_GAIN_SCALING,
     TRIAXCAL_ALIGNMENT_SCALING,
+    SYNC_OFFSET_SIZE,
 )
 
 
-class ShimmerBinaryReader(FileIOBase):
+class ShimmerBinaryReader(FileIOBase, ABC):
 
-    def __init__(self, fp: BinaryIO):
+    def __init__(self, fp: BinaryIO, batch_size: int = -1):
         super().__init__(fp)
 
         self._sensors = []
         self._channels = []
+        self._channel_dtypes = []
         self._sr = 0
         self._rtc_diff = 0
         self._start_ts = 0
         self._trial_config = 0
+        self._batch_size = batch_size
+
+        # attributes required for filtering
+        self._active_sensors = []
+        self._active_channels = []
+        self._active_channel_dtypes = []
 
         self._read_header()
 
@@ -80,7 +90,13 @@ def _read_header(self) -> None:
         self._trial_config = self._read_trial_config()
         self._exg_regs = self._read_exg_regs()
 
-        self._samples_per_block, self._block_size = self._calculate_block_size()
+        self._active_sensors = self._sensors.copy()
+        self._active_channels = self._channels.copy()
+        self._active_channel_dtypes = self._channel_dtypes.copy()
+
+        self._samples_per_block, self._block_size, self.sample_size = (
+            self._calculate_block_size()
+        )
 
     def _read_sample_rate(self) -> int:
         self._seek(SR_OFFSET)
@@ -122,7 +138,7 @@ def _calculate_block_size(self):
         num_samples = int((BLOCK_LEN - sync_stamp) / sample_size)
         block_len = num_samples * sample_size + sync_stamp
 
-        return num_samples, block_len
+        return num_samples, block_len, sample_size
 
     def _read_sync_offset(self) -> int | None:
         # For this read operation we assume that every synchronization offset is
@@ -145,17 +161,18 @@ def _read_sample(self) -> list:
 
         for ch, dtype in zip(self._channels, self._channel_dtypes):
             val_bin = self._read(dtype.size)
-            ch_values.append(dtype.decode(val_bin))
+            if ch in self._active_channels:
+                ch_values.append(dtype.decode(val_bin))
 
         return ch_values
 
     def _read_data_block(self) -> tuple[list[list], int]:
-        sync_tuple = None
+        sync_offset = None
         samples = []
 
         try:
             if self.has_sync:
-                sync_tuple = self._read_sync_offset()
+                sync_offset = self._read_sync_offset()
 
             for i in range(self._samples_per_block):
                 sample = self._read_sample()
@@ -163,7 +180,7 @@ def _read_data_block(self) -> tuple[list[list], int]:
         except IOError:
             pass
 
-        return samples, sync_tuple
+        return samples, sync_offset
 
     def _read_contents(self) -> tuple[list, list[tuple[int, int]]]:
         sync_offsets = []
@@ -179,7 +196,6 @@ def _read_contents(self) -> tuple[list, list[tuple[int, int]]]:
             samples += block_samples
             sample_ctr += len(block_samples)
 
-
             if len(block_samples) < self.samples_per_block:
                 # We have reached EOF
                 break
@@ -208,12 +224,13 @@ def _read_triaxcal_params(
 
         return offset, gain, alignment
 
-    def read_data(self):
-        samples, sync_offsets = self._read_contents()
+    @abstractmethod
+    def read_data(self): ...
 
+    def _finalize_data(self, samples, sync_offsets):
         samples_per_ch = list(zip(*samples))
         arr_per_ch = [np.array(s) for s in samples_per_ch]
-        samples_dict = dict(zip(self._channels, arr_per_ch))
+        samples_dict = dict(zip(self._active_channels, arr_per_ch))
 
         if self.has_sync and len(sync_offsets) > 0:
             off_index, offset = list(zip(*sync_offsets))
@@ -221,7 +238,7 @@ def read_data(self):
             offset_arr = np.array(offset)
             sync_data = (off_index_arr, offset_arr)
         else:
-            sync_data = ((), ())
+            sync_data = ()
 
         return samples_dict, sync_data
 
@@ -254,11 +271,17 @@ def samples_per_block(self) -> int:
 
     @property
     def enabled_sensors(self) -> list[ESensorGroup]:
-        return self._sensors
+        return self._active_sensors
+
+    @enabled_sensors.setter
+    def enabled_sensors(self, sensor_filter: list[ESensorGroup]) -> None:
+        self._active_sensors = sensor_filter
+        self._active_channels = self.get_data_channels(self._active_sensors)
+        self._active_channel_dtypes = get_ch_dtypes(self._active_channels)
 
     @property
     def enabled_channels(self) -> list[EChannelType]:
-        return self._channels
+        return self._active_channels
 
     @property
     def has_global_clock(self) -> bool:
@@ -287,3 +310,72 @@ def exg_reg1(self) -> ExGRegister:
     @property
     def exg_reg2(self) -> ExGRegister:
         return self.get_exg_reg(1)
+
+
+class ShimmerBinaryFileReader(ShimmerBinaryReader):
+    def read_data(self):
+        samples, sync_offsets = self._read_contents()
+        return self._finalize_data(samples, sync_offsets)
+
+
+class ShimmerBinaryStreamReader(ShimmerBinaryReader):
+    def __init__(self, fp: BinaryIO, batch_size: int) -> None:
+        super().__init__(fp, batch_size)
+
+        self._buffer: list[list] = []
+        self._read_offset: int = 0
+        self._buffered_sync_offset: int = 0
+
+    def read_data(self):
+        samples, sync_offsets = self._get()
+        return self._finalize_data(samples, sync_offsets)
+
+    def _read_sample(self) -> list:
+        ch_values = super()._read_sample()
+        self._read_offset += self.sample_size
+        return ch_values
+
+    def _read_data_block(self) -> tuple[list[list], int]:
+        self._read_offset += SYNC_OFFSET_SIZE if self.has_sync else 0
+        return super()._read_data_block()
+
+    def _get(self) -> tuple[list, list[tuple[int, int]]]:
+        batch_size = self._batch_size
+        if batch_size <= 0:
+            return [], []
+
+        if batch_size <= len(self._buffer):
+            samples = self._buffer[:batch_size]
+            self._buffer = self._buffer[batch_size:]
+            return samples, [(0, self._buffered_sync_offset)] if self.has_sync else []
+
+        sync_offsets = [(0, self._buffered_sync_offset)]
+        samples = self._buffer
+        sample_ctr = len(samples)
+        batch_size -= sample_ctr
+
+        target_idx = DATA_LOG_OFFSET + self._read_offset
+        if not self._tell() == target_idx:
+            self._seek(target_idx)
+
+        while True:
+            block_samples, sync_offset = self._read_data_block()
+
+            if sync_offset is not None:
+                sync_offsets += [(sample_ctr, sync_offset)]
+
+            samples += block_samples
+            sample_ctr += len(block_samples)
+            batch_size -= len(block_samples)
+            if batch_size <= 0:
+                self._buffer = samples[batch_size:]
+                samples = samples[:batch_size]
+                self._buffered_sync_offset = sync_offset
+                break
+
+            if len(block_samples) < self.samples_per_block:
+                # We have reached EOF
+                self._batch_size = 0
+                break
+
+        return samples, sync_offsets
diff --git a/pyshimmer/reader/reader_const.py b/pyshimmer/reader/reader_const.py
index cf57399..7c2c09f 100644
--- a/pyshimmer/reader/reader_const.py
+++ b/pyshimmer/reader/reader_const.py
@@ -75,3 +75,5 @@
 
 EXG_ADC_OFFSET = 0.0
 EXG_ADC_REF_VOLT = 2.42  # Volts
+
+SYNC_OFFSET_SIZE = 9
diff --git a/pyshimmer/reader/shimmer_reader.py b/pyshimmer/reader/shimmer_reader.py
index 01ec761..fbf85c7 100644
--- a/pyshimmer/reader/shimmer_reader.py
+++ b/pyshimmer/reader/shimmer_reader.py
@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import BinaryIO
+from typing import BinaryIO, Any
 
 import numpy as np
 
@@ -27,7 +27,11 @@
     EChannelType,
 )
 from pyshimmer.dev.exg import is_exg_ch, get_exg_ch, ExGRegister
-from pyshimmer.reader.binary_reader import ShimmerBinaryReader
+from pyshimmer.reader.binary_reader import (
+    ShimmerBinaryReader,
+    ShimmerBinaryFileReader,
+    ShimmerBinaryStreamReader,
+)
 from pyshimmer.reader.reader_const import (
     EXG_ADC_REF_VOLT,
     EXG_ADC_OFFSET,
@@ -149,9 +153,14 @@ def __init__(
         sync: bool = True,
         post_process: bool = True,
         processors: list[ChannelPostProcessor] = None,
+        batch_size: int = -1,
     ):
         if fp is not None:
-            self._bin_reader = ShimmerBinaryReader(fp)
+            self._bin_reader = (
+                ShimmerBinaryFileReader(fp)
+                if batch_size < 1
+                else ShimmerBinaryStreamReader(fp, batch_size)
+            )
         elif bin_reader is not None:
             self._bin_reader = bin_reader
         else:
@@ -208,8 +217,9 @@ def _process_signals(
 
         return result
 
-    def load_file_data(self):
-        samples, sync_offsets = self._bin_reader.read_data()
+    def _finalize_data(
+        self, samples: dict[EChannelType, Any], sync_offsets: list[tuple[int, int]]
+    ) -> tuple[int | float | np.ndarray, dict[EChannelType, Any]]:
         ts_raw = samples.pop(EChannelType.TIMESTAMP)
         ts_unwrapped = unwrap_device_timestamps(ts_raw)
 
@@ -218,12 +228,29 @@
         if self._sync and self._bin_reader.has_sync:
             ts_sane = self._apply_synchronization(ts_sane, *sync_offsets)
 
+        output_samples = None
         if self._pp:
-            self._ch_samples = self._process_signals(samples)
+            output_samples = self._process_signals(samples)
         else:
-            self._ch_samples = samples
+            output_samples = samples
+
+        timestamps = ticks2sec(ts_sane)
+        return timestamps, output_samples
+
+    def get_batch(self):
+        samples, sync_offsets = self._bin_reader.read_data()
+        if len(samples) == len(sync_offsets) == 0:
+            return [], []
 
-        self._ts = ticks2sec(ts_sane)
+        return self._finalize_data(samples, sync_offsets)
+
+    def load_file_data(self) -> tuple | None:
+        timestamps, samples = self.get_batch()
+        if type(self._bin_reader) == ShimmerBinaryStreamReader:
+            return timestamps, samples
+
+        self._ts, self._ch_samples = timestamps, samples
+        return None
 
     def get_exg_reg(self, chip_id: int) -> ExGRegister:
         return self._bin_reader.get_exg_reg(chip_id)
@@ -231,8 +258,7 @@ def get_exg_reg(self, chip_id: int) -> ExGRegister:
     def __getitem__(self, item: EChannelType) -> np.ndarray:
         if item == EChannelType.TIMESTAMP:
             return self.timestamp
-
-        return self._ch_samples[item]
+        return self._ch_samples.get(item, np.empty(0))
 
     @property
     def timestamp(self) -> np.ndarray:
@@ -254,3 +280,11 @@ def exg_reg1(self) -> ExGRegister:
     @property
     def exg_reg2(self) -> ExGRegister:
         return self.get_exg_reg(1)
+
+    @property
+    def enabled_sensors(self) -> list[ESensorGroup]:
+        return self._bin_reader.enabled_sensors
+
+    @enabled_sensors.setter
+    def enabled_sensors(self, enabled_sensors: list[ESensorGroup]) -> None:
+        self._bin_reader.enabled_sensors = enabled_sensors
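
Usage sketch (not part of the diff): the snippet below shows how the new streaming path might be driven from user code. It assumes that ShimmerReader and ESensorGroup are importable from the top-level pyshimmer package, that the timestamp channel stays in the active channel set when a sensor filter is applied, and that the file name, batch size, and sensor selection are placeholders.

# Hypothetical example, not taken from the diff. A batch_size >= 1 selects the
# new ShimmerBinaryStreamReader backend; the default of -1 keeps the previous
# whole-file behaviour via ShimmerBinaryFileReader.
from pyshimmer import ShimmerReader, ESensorGroup

with open("recording.bin", "rb") as fp:          # placeholder file name
    reader = ShimmerReader(fp, batch_size=512)   # decode up to 512 samples per call

    # Optional: restrict decoding to a subset of the enabled sensors
    # (assumes the timestamp channel remains part of the active set).
    reader.enabled_sensors = [ESensorGroup.ACCEL_LN]

    while True:
        timestamps, samples = reader.get_batch()
        if len(timestamps) == 0:
            break  # get_batch() returns two empty lists once the stream is exhausted
        print(f"batch with {len(timestamps)} samples, channels: {list(samples)}")

Code that leaves batch_size at its default and keeps calling load_file_data() followed by indexing the reader is unaffected, since that path still populates the reader's internal timestamp and sample attributes.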