From 4f90a50b438d3bcd31e3335987d2eaa4868cbab3 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 29 Jan 2025 12:45:01 -0500 Subject: [PATCH 1/4] ENH: Support BDF export --- mne/export/_edf.py | 8 ++- mne/export/_export.py | 5 +- mne/export/tests/test_export.py | 113 +++++++++++++++++------------- mne/utils/check.py | 6 +- tools/install_pre_requirements.sh | 2 +- 5 files changed, 77 insertions(+), 57 deletions(-) diff --git a/mne/export/_edf.py b/mne/export/_edf.py index e50b05f7056..65cbfc14980 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -11,7 +11,6 @@ from ..utils import _check_edfio_installed, warn _check_edfio_installed() -from edfio import Edf, EdfAnnotation, EdfSignal, Patient, Recording # noqa: E402 # copied from edfio (Apache license) @@ -29,12 +28,17 @@ def _round_float_to_8_characters( return round_func(value * factor) / factor -def _export_raw(fname, raw, physical_range, add_ch_type): +def _export_raw(fname, raw, physical_range, add_ch_type, *, fmt="edf"): """Export Raw objects to EDF files. TODO: if in future the Info object supports transducer or technician information, allow writing those here. """ + assert fmt in ("edf", "bdf"), fmt + _check_edfio_installed(min_version="0.4.6" if fmt == "bdf" else None) + + from edfio import Edf, EdfAnnotation, EdfSignal, Patient, Recording # noqa: E402 + # get voltage-based data in uV units = dict( eeg="uV", ecog="uV", seeg="uV", eog="uV", ecg="uV", emg="uV", bio="uV", dbs="uV" diff --git a/mne/export/_export.py b/mne/export/_export.py index 4b93fda917e..99b657edfa5 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -58,6 +58,7 @@ def export_raw( supported_export_formats = { # format : (extensions,) "eeglab": ("set",), "edf": ("edf",), + "bdf": ("bdf",), "brainvision": ( "eeg", "vmrk", @@ -77,10 +78,10 @@ def export_raw( from ._eeglab import _export_raw _export_raw(fname, raw) - elif fmt == "edf": + elif fmt in ("edf", "bdf"): from ._edf import _export_raw - _export_raw(fname, raw, physical_range, add_ch_type) + _export_raw(fname, raw, physical_range, add_ch_type, fmt=fmt) elif fmt == "brainvision": from ._brainvision import _export_raw diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 6f712923c7d..88cd05fb320 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -25,8 +25,8 @@ from mne.fixes import _compare_version from mne.io import ( RawArray, + read_raw, read_raw_brainvision, - read_raw_edf, read_raw_eeglab, read_raw_fif, ) @@ -190,10 +190,23 @@ def _create_raw_for_edf_tests(stim_channel_index=None): edfio_mark = pytest.mark.skipif( not _check_edfio_installed(strict=False), reason="requires edfio" ) +edfio_bdf_mark = pytest.mark.skipif( + not _check_edfio_installed(strict=False, min_version="0.4.6"), + reason="requires edfio with bdf support", +) + + +edf_params = pytest.mark.parametrize( + "fmt", + [ + pytest.param("edf", marks=edfio_mark), + pytest.param("bdf", marks=edfio_bdf_mark), + ], +) -@edfio_mark() -def test_double_export_edf(tmp_path): +@edf_params +def test_double_export_edf(tmp_path, fmt): """Test exporting an EDF file multiple times.""" raw = _create_raw_for_edf_tests(stim_channel_index=2) raw.info.set_meas_date("2023-09-04 14:53:09.000") @@ -212,13 +225,13 @@ def test_double_export_edf(tmp_path): ) # export once - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" raw.export(temp_fname, add_ch_type=True) - raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) + raw_read = read_raw(temp_fname, infer_types=True, preload=True) # export again raw_read.export(temp_fname, add_ch_type=True, overwrite=True) - raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) + raw_read = read_raw(temp_fname, infer_types=True, preload=True) assert raw.ch_names == raw_read.ch_names assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) @@ -233,8 +246,8 @@ def test_double_export_edf(tmp_path): assert_array_equal(orig_ch_types, read_ch_types) -@edfio_mark() -def test_edf_physical_range(tmp_path): +@edf_params +def test_edf_physical_range(tmp_path, fmt): """Test exporting an EDF file with different physical range settings.""" ch_types = ["eeg"] * 4 ch_names = np.arange(len(ch_types)).astype(str).tolist() @@ -247,22 +260,22 @@ def test_edf_physical_range(tmp_path): raw = RawArray(data, info) # export with physical range per channel type (default) - temp_fname = tmp_path / "test_auto.edf" + temp_fname = tmp_path / f"test_auto.{fmt}" raw.export(temp_fname) - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) with pytest.raises(AssertionError, match="Arrays are not almost equal"): assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) # export with physical range per channel - temp_fname = tmp_path / "test_per_channel.edf" + temp_fname = tmp_path / f"test_per_channel.{fmt}" raw.export(temp_fname, physical_range="channelwise") - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) -@edfio_mark() +@edf_params @pytest.mark.parametrize("pad_width", (1, 10, 100, 500, 999)) -def test_edf_padding(tmp_path, pad_width): +def test_edf_padding(tmp_path, pad_width, fmt): """Test exporting an EDF file with not-equal-length data blocks.""" ch_types = ["eeg"] * 4 ch_names = np.arange(len(ch_types)).astype(str).tolist() @@ -274,7 +287,7 @@ def test_edf_padding(tmp_path, pad_width): raw = RawArray(data, info) # export with physical range per channel type (default) - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" with pytest.warns( RuntimeWarning, match=( @@ -285,7 +298,7 @@ def test_edf_padding(tmp_path, pad_width): raw.export(temp_fname) # read in the file - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) assert raw.n_times == raw_read.n_times - pad_width edge_data = raw_read.get_data()[:, -pad_width - 1] pad_data = raw_read.get_data()[:, -pad_width:] @@ -301,9 +314,9 @@ def test_edf_padding(tmp_path, pad_width): assert_array_almost_equal(raw_read.annotations.duration[0], pad_width / fs) -@edfio_mark() +@edf_params @pytest.mark.parametrize("tmin", (0, 0.005, 0.03, 1)) -def test_export_edf_annotations(tmp_path, tmin): +def test_export_edf_annotations(tmp_path, tmin, fmt): """Test annotations in the exported EDF file. All annotations should be preserved and onset corrected. @@ -327,12 +340,12 @@ def test_export_edf_annotations(tmp_path, tmin): ) # export - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" with expectation: raw.export(temp_fname) # read in the file - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) assert raw_read.first_time == 0 # exportation resets first_time bad_annot = raw_read.annotations.description == "BAD_ACQ_SKIP" if bad_annot.any(): @@ -356,8 +369,8 @@ def test_export_edf_annotations(tmp_path, tmin): ) -@edfio_mark() -def test_rawarray_edf(tmp_path): +@edf_params +def test_rawarray_edf(tmp_path, fmt): """Test saving a Raw array with integer sfreq to EDF.""" raw = _create_raw_for_edf_tests() @@ -380,10 +393,10 @@ def test_rawarray_edf(tmp_path): tzinfo=timezone.utc, ) raw.set_meas_date(meas_date) - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" raw.export(temp_fname, add_ch_type=True) - raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) + raw_read = read_raw(temp_fname, infer_types=True, preload=True) assert raw.ch_names == raw_read.ch_names assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) @@ -395,39 +408,39 @@ def test_rawarray_edf(tmp_path): assert raw.info["meas_date"] == raw_read.info["meas_date"] -@edfio_mark() -def test_edf_export_non_voltage_channels(tmp_path): +@edf_params +def test_edf_export_non_voltage_channels(tmp_path, fmt): """Test saving a Raw array containing a non-voltage channel.""" - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" raw = _create_raw_for_edf_tests() raw.set_channel_types({"9": "hbr"}, on_unit_change="ignore") raw.export(temp_fname, overwrite=True) # data should match up to the non-accepted channel - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) assert raw.ch_names == raw_read.ch_names assert_array_almost_equal(raw.get_data()[:-1], raw_read.get_data()[:-1], decimal=10) assert_array_almost_equal(raw.get_data()[-1], raw_read.get_data()[-1], decimal=5) assert_array_equal(raw.times, raw_read.times) -@edfio_mark() -def test_channel_label_too_long_for_edf_raises_error(tmp_path): +@edf_params +def test_channel_label_too_long_for_edf_raises_error(tmp_path, fmt): """Test trying to save an EDF where a channel label is longer than 16 characters.""" raw = _create_raw_for_edf_tests() raw.rename_channels({"1": "abcdefghijklmnopqrstuvwxyz"}) with pytest.raises(RuntimeError, match="Signal label"): - raw.export(tmp_path / "test.edf") + raw.export(tmp_path / f"test.{fmt}") -@edfio_mark() -def test_measurement_date_outside_range_valid_for_edf(tmp_path): +@edf_params +def test_measurement_date_outside_range_valid_for_edf(tmp_path, fmt): """Test trying to save an EDF with a measurement date before 1985-01-01.""" raw = _create_raw_for_edf_tests() raw.set_meas_date(datetime(year=1984, month=1, day=1, tzinfo=timezone.utc)) - with pytest.raises(ValueError, match="EDF only allows dates from 1985 to 2084"): - raw.export(tmp_path / "test.edf", overwrite=True) + with pytest.raises(ValueError, match="DF only allows dates from 1985 to 2084"): + raw.export(tmp_path / f"test.{fmt}", overwrite=True) @pytest.mark.filterwarnings("ignore:Data has a non-integer:RuntimeWarning") @@ -438,33 +451,33 @@ def test_measurement_date_outside_range_valid_for_edf(tmp_path): ((0, 1e6), "minimum"), ], ) -@edfio_mark() -def test_export_edf_signal_clipping(tmp_path, physical_range, exceeded_bound): +@edf_params +def test_export_edf_signal_clipping(tmp_path, physical_range, exceeded_bound, fmt): """Test if exporting data exceeding physical min/max clips and emits a warning.""" raw = read_raw_fif(fname_raw) raw.pick(picks=["eeg", "ecog", "seeg"]).load_data() - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" with ( _record_warnings(), pytest.warns(RuntimeWarning, match=f"The {exceeded_bound}"), ): raw.export(temp_fname, physical_range=physical_range) - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) assert raw_read.get_data().min() >= physical_range[0] assert raw_read.get_data().max() <= physical_range[1] -@edfio_mark() -def test_export_edf_with_constant_channel(tmp_path): +@edf_params +def test_export_edf_with_constant_channel(tmp_path, fmt): """Test if exporting to edf works if a channel contains only constant values.""" - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" raw = RawArray(np.zeros((1, 10)), info=create_info(1, 1)) raw.export(temp_fname) - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) assert_array_equal(raw_read.get_data(), np.zeros((1, 10))) -@edfio_mark() +@edf_params @pytest.mark.parametrize( ("input_path", "warning_msg"), [ @@ -476,13 +489,13 @@ def test_export_edf_with_constant_channel(tmp_path): ), ], ) -def test_export_raw_edf(tmp_path, input_path, warning_msg): +def test_export_raw_edf(tmp_path, input_path, warning_msg, fmt): """Test saving a Raw instance to EDF format.""" raw = read_raw_fif(input_path) # only test with EEG channels raw.pick(picks=["eeg", "ecog", "seeg"]).load_data() - temp_fname = tmp_path / "test.edf" + temp_fname = tmp_path / f"test.{fmt}" with pytest.warns(RuntimeWarning, match=warning_msg): raw.export(temp_fname) @@ -490,7 +503,7 @@ def test_export_raw_edf(tmp_path, input_path, warning_msg): if "epoc" in raw.ch_names: raw.drop_channels(["epoc"]) - raw_read = read_raw_edf(temp_fname, preload=True) + raw_read = read_raw(temp_fname, preload=True) assert raw.ch_names == raw_read.ch_names # only compare the original length, since extra zeros are appended orig_raw_len = len(raw) @@ -513,8 +526,8 @@ def test_export_raw_edf(tmp_path, input_path, warning_msg): assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) -@edfio_mark() -def test_export_raw_edf_does_not_fail_on_empty_header_fields(tmp_path): +@edf_params +def test_export_raw_edf_does_not_fail_on_empty_header_fields(tmp_path, fmt): """Test writing a Raw instance with empty header fields to EDF.""" rng = np.random.RandomState(123456) @@ -531,7 +544,7 @@ def test_export_raw_edf_does_not_fail_on_empty_header_fields(tmp_path): data = rng.random(size=(len(ch_types), 1000)) * 1e-5 raw = RawArray(data, info) - raw.export(tmp_path / "test.edf", add_ch_type=True) + raw.export(tmp_path / f"test.{fmt}", add_ch_type=True) @pytest.mark.xfail(reason="eeglabio (usage?) bugs that should be fixed") diff --git a/mne/utils/check.py b/mne/utils/check.py index 085c51b6996..60a2322b9a3 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -446,9 +446,11 @@ def _check_eeglabio_installed(strict=True): return _soft_import("eeglabio", "exporting to EEGLab", strict=strict) -def _check_edfio_installed(strict=True): +def _check_edfio_installed(strict=True, *, min_version=None): """Aux function.""" - return _soft_import("edfio", "exporting to EDF", strict=strict) + return _soft_import( + "edfio", "exporting to EDF", min_version=min_version, strict=strict + ) def _check_pybv_installed(strict=True): diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh index c717b1b477b..1c323ecfd6b 100755 --- a/tools/install_pre_requirements.sh +++ b/tools/install_pre_requirements.sh @@ -75,7 +75,7 @@ pip install $STD_ARGS git+https://github.com/joblib/joblib echo "edfio" # Disable protection for Azure, see # https://github.com/mne-tools/mne-python/pull/12609#issuecomment-2115639369 -GIT_CLONE_PROTECTION_ACTIVE=false pip install $STD_ARGS git+https://github.com/the-siesta-group/edfio +GIT_CLONE_PROTECTION_ACTIVE=false pip install $STD_ARGS "git+https://github.com/larsoner/edfio@bdf" echo "h5io" pip install $STD_ARGS git+https://github.com/h5io/h5io From 8020ea541a2aa0788f710ad681a6c755fd338c99 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 29 Jan 2025 13:26:41 -0500 Subject: [PATCH 2/4] DOC: Change --- doc/changes/newfeature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/newfeature.rst diff --git a/doc/changes/newfeature.rst b/doc/changes/newfeature.rst new file mode 100644 index 00000000000..6bee0d7052f --- /dev/null +++ b/doc/changes/newfeature.rst @@ -0,0 +1 @@ +Add support for exporting BDF files in :func:`mne.export.export_raw` by `Eric Larson`_. From 73cf0e821e5498b49cbf189c5a9beef09d9565ba Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 29 Jan 2025 16:02:04 -0500 Subject: [PATCH 3/4] FIX: Closer --- mne/export/_edf.py | 12 ++++++++++-- mne/export/tests/test_export.py | 2 +- mne/io/edf/edf.py | 13 ++++--------- mne/io/edf/tests/test_edf.py | 3 --- mne/utils/__init__.pyi | 4 ++++ mne/utils/numerics.py | 18 ++++++++++++++++++ mne/utils/tests/test_numerics.py | 15 +++++++++++++++ 7 files changed, 52 insertions(+), 15 deletions(-) diff --git a/mne/export/_edf.py b/mne/export/_edf.py index 65cbfc14980..c153ce11802 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -8,7 +8,7 @@ import numpy as np from ..annotations import _sync_onset -from ..utils import _check_edfio_installed, warn +from ..utils import _check_edfio_installed, check_version, warn _check_edfio_installed() @@ -44,7 +44,13 @@ def _export_raw(fname, raw, physical_range, add_ch_type, *, fmt="edf"): eeg="uV", ecog="uV", seeg="uV", eog="uV", ecg="uV", emg="uV", bio="uV", dbs="uV" ) - digital_min, digital_max = -32767, 32767 + if fmt == "edf": + digital_min, digital_max = -32768, 32767 # 2 ** 15 - 1 + else: + digital_min, digital_max = -8388608, 8388607 # 2 ** 23 - 1 + fmt_kwargs = dict() + if check_version("edfio", "0.4.6"): + fmt_kwargs["fmt"] = fmt annotations = [] # load data first @@ -157,6 +163,7 @@ def _export_raw(fname, raw, physical_range, add_ch_type, *, fmt="edf"): physical_range=prange, digital_range=(digital_min, digital_max), prefiltering=filter_str_info, + **fmt_kwargs, ) ) @@ -230,4 +237,5 @@ def _export_raw(fname, raw, physical_range, add_ch_type, *, fmt="edf"): starttime=starttime, data_record_duration=data_record_duration, annotations=annotations, + **fmt_kwargs, ).write(fname) diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 88cd05fb320..13e869f8d0a 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -234,7 +234,7 @@ def test_double_export_edf(tmp_path, fmt): raw_read = read_raw(temp_fname, infer_types=True, preload=True) assert raw.ch_names == raw_read.ch_names - assert_array_almost_equal(raw.get_data(), raw_read.get_data(), decimal=10) + assert_array_almost_equal(raw_read.get_data(), raw.get_data(), decimal=10) assert_array_equal(raw.times, raw_read.times) # check info diff --git a/mne/io/edf/edf.py b/mne/io/edf/edf.py index 09ac24f753e..37d0f27b45b 100644 --- a/mne/io/edf/edf.py +++ b/mne/io/edf/edf.py @@ -17,7 +17,7 @@ from ..._fiff.utils import _blk_read_lims, _mult_cal_one from ...annotations import Annotations from ...filter import resample -from ...utils import _validate_type, fill_doc, logger, verbose, warn +from ...utils import _read_24bit, _validate_type, fill_doc, logger, verbose, warn from ..base import BaseRaw, _get_scaling # common channel type names mapped to internal ch types @@ -333,16 +333,11 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): ) -def _read_ch(fid, subtype, samp, dtype_byte, dtype=None): +def _read_ch(fid, subtype, samp, *, dtype=None): """Read a number of samples for a single channel.""" # BDF if subtype == "bdf": - ch_data = np.fromfile(fid, dtype=dtype, count=samp * dtype_byte) - ch_data = ch_data.reshape(-1, 3).astype(INT32) - ch_data = (ch_data[:, 0]) + (ch_data[:, 1] << 8) + (ch_data[:, 2] << 16) - # 24th bit determines the sign - ch_data[ch_data >= (1 << 23)] -= 1 << 24 - + ch_data = _read_24bit(fid, samp) # GDF data and EDF data else: ch_data = np.fromfile(fid, dtype=dtype, count=samp) @@ -397,7 +392,7 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, fid.seek(start_offset + block_offset, 0) # Read and reshape to (n_chunks_read, ch0_ch1_ch2_ch3...) many_chunk = _read_ch( - fid, subtype, ch_offsets[-1] * n_read, dtype_byte, dtype + fid, subtype, ch_offsets[-1] * n_read, dtype=dtype ).reshape(n_read, -1) r_sidx = r_lims[ai][0] r_eidx = buf_len * (n_read - 1) + r_lims[ai + n_read - 1][1] diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index ce671ca7e81..d7ef98b8ca1 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -352,7 +352,6 @@ def test_parse_annotation(tmp_path): subtype="EDF", dtype="= 128 + data = data.view(" 0).any() + assert (np.abs(data) >= 2**16).any() # some that require 24-bit depth + fname = tmp_path / "test.24" + with fname.open("wb") as fid: + numerics._write_24bit(fid, data) + with fname.open("rb") as fid: + data_read = numerics._read_24bit(fid, len(data)) + assert_array_equal(data, data_read) From b7ef7b74645eb38bccc1dfc3e34c725dc468bd17 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 30 Jan 2025 09:24:06 -0500 Subject: [PATCH 4/4] FIX: Justify true zero --- mne/export/_edf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/export/_edf.py b/mne/export/_edf.py index c153ce11802..8bb7d2ee1f9 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -45,9 +45,9 @@ def _export_raw(fname, raw, physical_range, add_ch_type, *, fmt="edf"): ) if fmt == "edf": - digital_min, digital_max = -32768, 32767 # 2 ** 15 - 1 + digital_min, digital_max = -32767, 32767 # 2 ** 15 - 1, symmetric (true zero) else: - digital_min, digital_max = -8388608, 8388607 # 2 ** 23 - 1 + digital_min, digital_max = -8388607, 8388607 # 2 ** 23 - 1 fmt_kwargs = dict() if check_version("edfio", "0.4.6"): fmt_kwargs["fmt"] = fmt