Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/spikeinterface/curation/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
from spikeinterface.core.generate import inject_some_split_units
from spikeinterface.curation import train_model
from pathlib import Path

job_kwargs = dict(n_jobs=-1)

Expand Down Expand Up @@ -82,6 +84,31 @@ def sorting_analyzer_with_splits():
return make_sorting_analyzer_with_splits(sorting_analyzer)


def make_trained_pipeline():
"""
Makes a model saved at "./trained_pipeline" which will be used by other tests in the module.
If the model already exists, this function does nothing.
"""
trained_model_folder = Path(__file__).parent / Path("trained_pipeline")
if not trained_model_folder.is_dir():
analyzer = make_sorting_analyzer(sparse=True)
analyzer.compute(
{
"quality_metrics": {"metric_names": ["snr", "num_spikes"]},
"template_metrics": {"metric_names": ["half_width"]},
}
)
train_model(
analyzers=[analyzer] * 5,
labels=[[1, 0, 1, 0, 1]] * 5,
folder=trained_model_folder,
classifiers=["RandomForestClassifier"],
imputation_strategies=["median"],
scaling_techniques=["standard_scaler"],
)
return


if __name__ == "__main__":
sorting_analyzer = make_sorting_analyzer(sparse=False)
print(sorting_analyzer)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from pathlib import Path
from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, make_trained_pipeline
from spikeinterface.curation.model_based_curation import ModelBasedClassification
from spikeinterface.curation import auto_label_units, load_model
from spikeinterface.curation.train_manual_curation import _get_computed_metrics
Expand All @@ -19,6 +19,7 @@ def model():
It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with
the following labels: [1,0,1,0,1]."""

make_trained_pipeline()
model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"])
return model

Expand All @@ -45,6 +46,7 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, model):
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])

make_trained_pipeline()
model_folder = Path(__file__).parent / Path("trained_pipeline")

prediction_prob_dataframe_1 = auto_label_units(
Expand Down Expand Up @@ -148,6 +150,7 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
)
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])

make_trained_pipeline()
model_folder = Path(__file__).parent / Path("trained_pipeline")

model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"])
Expand Down
Binary file not shown.
21 changes: 0 additions & 21 deletions src/spikeinterface/curation/tests/trained_pipeline/labels.csv

This file was deleted.

This file was deleted.

60 changes: 0 additions & 60 deletions src/spikeinterface/curation/tests/trained_pipeline/model_info.json

This file was deleted.

This file was deleted.