-
Notifications
You must be signed in to change notification settings - Fork 99
add MajajHong2015-spatial_correlation benchmark and metric #2289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mschrimpf
wants to merge
3
commits into
brain-score:master
Choose a base branch
from
mschrimpf:mschrimpf/majajhong2015spatial
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
brainscore_vision/benchmarks/majajhong2015_spatial/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| # developed in Schrimpf et al. 2024 https://www.biorxiv.org/content/10.1101/2024.01.09.572970 | ||
|
|
||
| from brainscore_vision import benchmark_registry | ||
|
|
||
| from .benchmark import DicarloMajajHong2015ITSpatialCorrelation | ||
|
|
||
| benchmark_registry['MajajHong2015.IT-spatial_correlation'] = DicarloMajajHong2015ITSpatialCorrelation |
83 changes: 83 additions & 0 deletions
83
brainscore_vision/benchmarks/majajhong2015_spatial/benchmark.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| from brainio.assemblies import NeuroidAssembly, walk_coords | ||
| from brainscore_core import Score | ||
|
|
||
| import brainscore_vision | ||
| from brainscore_vision.benchmarks import BenchmarkBase | ||
| from brainscore_vision.metrics.spatial_correlation.metric import inv_ks_similarity | ||
| from brainscore_vision.model_interface import BrainModel | ||
| from brainscore_vision.utils import LazyLoad | ||
| from ..majajhong2015.benchmark import BIBTEX | ||
|
|
||
| SPATIAL_BIN_SIZE_MM = .1 # .1 mm is an arbitrary choice | ||
|
|
||
|
|
||
| class DicarloMajajHong2015ITSpatialCorrelation(BenchmarkBase): | ||
| def __init__(self): | ||
| """ | ||
| This benchmark compares the distribution of pairwise response correlation as a function of distance between the | ||
| data recorded in Majaj* and Hong* et al. 2015 and a candidate model. | ||
| """ | ||
| self._assembly = self._load_assembly() | ||
| self._metric = brainscore_vision.load_metric( | ||
| 'spatial_correlation', | ||
| similarity_function=inv_ks_similarity, | ||
| bin_size_mm=SPATIAL_BIN_SIZE_MM, | ||
| num_bootstrap_samples=100_000, | ||
| num_sample_arrays=10) | ||
| ceiler = brainscore_vision.load_metric('inter_individual_helper', self._metric.compare_statistics) | ||
| target_statistic = LazyLoad(lambda: self._metric.compute_global_tissue_statistic_target(self._assembly)) | ||
| super().__init__(identifier='dicarlo.MajajHong2015.IT-spatial_correlation', | ||
| ceiling_func=lambda: ceiler(target_statistic), | ||
| version=1, | ||
| parent='IT', | ||
| bibtex=BIBTEX) | ||
|
|
||
| def _load_assembly(self) -> NeuroidAssembly: | ||
| assembly = brainscore_vision.load_dataset('MajajHong2015').sel(region='IT') | ||
| assembly = self.squeeze_time(assembly) | ||
| assembly = self.tissue_update(assembly) | ||
| return assembly | ||
|
|
||
| def __call__(self, candidate: BrainModel) -> Score: | ||
| """ | ||
| This computes the statistics, i.e. the pairwise response correlation of candidate and target, respectively and | ||
| computes a ceiling-normalized score based on the ks similarity of the two resulting distributions. | ||
| :param candidate: BrainModel | ||
| :return: average inverted ks similarity for the pairwise response correlation compared to the MajajHong assembly | ||
| """ | ||
| candidate.start_recording(recording_target=BrainModel.RecordingTarget.IT, | ||
| time_bins=[(70, 170)], | ||
| # "we implanted each monkey with three arrays in the left cerebral hemisphere" | ||
| ) | ||
| candidate_assembly = candidate.look_at(self._assembly.stimulus_set) | ||
| candidate_assembly = self.squeeze_time(candidate_assembly) | ||
|
|
||
| raw_score = self._metric(candidate_assembly, self._assembly) | ||
| score = raw_score / self.ceiling | ||
| score.attrs['raw'] = raw_score | ||
| score.attrs['ceiling'] = self.ceiling | ||
| return score | ||
|
|
||
| @staticmethod | ||
| def tissue_update(assembly): | ||
| """ | ||
| The current MajajHong2015 assembly has x and y coordinates of each array electrode stored as | ||
| coordinates `x` and `y` rather than the preferred `tissue_x` and `tissue_y`. Add these coordinates here. | ||
| """ | ||
| if not hasattr(assembly, 'tissue_x'): | ||
| assembly['tissue_x'] = assembly['x'] | ||
| assembly['tissue_y'] = assembly['y'] | ||
| # re-index | ||
| attrs = assembly.attrs | ||
| assembly = type(assembly)(assembly.values, coords={ | ||
| coord: (dims, values) for coord, dims, values in walk_coords(assembly)}, dims=assembly.dims) | ||
| assembly.attrs = attrs | ||
| return assembly | ||
|
|
||
| @staticmethod | ||
| def squeeze_time(assembly): | ||
| if 'time_bin' in assembly.dims: | ||
| assembly = assembly.squeeze('time_bin') | ||
| if hasattr(assembly, "time_step"): | ||
| assembly = assembly.squeeze("time_step") | ||
| return assembly | ||
13 changes: 13 additions & 0 deletions
13
brainscore_vision/benchmarks/majajhong2015_spatial/test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import pytest | ||
| from pytest import approx | ||
|
|
||
| import brainscore_vision | ||
| from brainscore_vision.benchmark_helpers import PrecomputedFeatures | ||
|
|
||
|
|
||
| def test_benchmark_runs(): | ||
| benchmark = brainscore_vision.load_benchmark('MajajHong2015.IT-spatial_correlation') | ||
| source = benchmark._assembly.copy() | ||
| source = {benchmark._assembly.stimulus_set.identifier: source} | ||
| features = PrecomputedFeatures(source, visual_degrees=8) | ||
| benchmark(features) |
4 changes: 4 additions & 0 deletions
4
brainscore_vision/metrics/inter_individual_stats_ceiling/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from brainscore_vision import metric_registry | ||
| from .ceiling import InterIndividualStatisticsCeiling | ||
|
|
||
| metric_registry['inter_individual_helper'] = InterIndividualStatisticsCeiling |
42 changes: 42 additions & 0 deletions
42
brainscore_vision/metrics/inter_individual_stats_ceiling/ceiling.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| from xarray import DataArray | ||
|
|
||
| from brainscore_vision.metric_helpers.transformations import apply_aggregate | ||
| from brainscore_vision import Ceiling | ||
| from brainscore_vision import Score | ||
|
|
||
|
|
||
| class InterIndividualStatisticsCeiling(Ceiling): | ||
| """ | ||
| Cross-validation-like, animal-wise computation of ceiling | ||
| """ | ||
|
|
||
| def __init__(self, metric): | ||
| """ | ||
| :param metric: used to compute the ceiling | ||
| """ | ||
| self._metric = metric | ||
|
|
||
| def __call__(self, statistic: DataArray) -> Score: | ||
| """ | ||
| Applies given metric to dataset, comparing data from one animal to all remaining animals, i.e.: | ||
| For each animal: metric({dataset\animal_i}, animal_i): cross validation like | ||
| :param statistic: xarray structure with values & and corresponding meta information: distances, source | ||
| :return: ceiling | ||
| """ | ||
| assert len(set(statistic.source.data)) > 1, 'your stats contain less than 2 subjects' | ||
| self.statistic = statistic | ||
|
|
||
| monkey_scores = [] | ||
| for heldout_monkey in sorted(set(self.statistic.source.data)): | ||
| monkey_pool = self.statistic.where(self.statistic.source != heldout_monkey, drop=True) | ||
| heldout = self.statistic.sel(source=heldout_monkey) | ||
| score = self._metric(monkey_pool, heldout) | ||
|
|
||
| score = score.expand_dims('monkey') | ||
| score['monkey'] = [heldout_monkey] | ||
| monkey_scores.append(score) | ||
| # aggregate | ||
| scores = Score.merge(*monkey_scores) | ||
| return apply_aggregate(lambda s: s.mean('monkey'), scores) | ||
|
|
||
| # Note: this should probably be more general for arbitrary subjects instead of 'monkey' |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| # developed in Schrimpf et al. 2024 https://www.biorxiv.org/content/10.1101/2024.01.09.572970 | ||
|
|
||
| from brainscore_vision import metric_registry | ||
| from .metric import SpatialCorrelationSimilarity | ||
|
|
||
| metric_registry['spatial_correlation'] = SpatialCorrelationSimilarity |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.