Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainscore_vision/benchmark_helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def visual_degrees(self) -> int:
def start_task(self, task, fitting_stimuli=None):
pass

def start_recording(self, region, *args, **kwargs):
def start_recording(self, recording_target: BrainModel.RecordingTarget, *args, **kwargs):
pass

def look_at(self, stimuli, number_of_trials=1):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# developed in Schrimpf et al. 2024 https://www.biorxiv.org/content/10.1101/2024.01.09.572970

from brainscore_vision import benchmark_registry

from .benchmark import DicarloMajajHong2015ITSpatialCorrelation

benchmark_registry['MajajHong2015.IT-spatial_correlation'] = DicarloMajajHong2015ITSpatialCorrelation
83 changes: 83 additions & 0 deletions brainscore_vision/benchmarks/majajhong2015_spatial/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from brainscore_core.supported_data_standards.brainio.assemblies import NeuroidAssembly, walk_coords
from brainscore_core import Score

import brainscore_vision
from brainscore_vision.benchmarks import BenchmarkBase
from brainscore_vision.metrics.spatial_correlation.metric import inv_ks_similarity
from brainscore_vision.model_interface import BrainModel
from brainscore_vision.utils import LazyLoad
from ..majajhong2015.benchmark import BIBTEX

SPATIAL_BIN_SIZE_MM = .1 # .1 mm is an arbitrary choice


class DicarloMajajHong2015ITSpatialCorrelation(BenchmarkBase):
def __init__(self):
"""
This benchmark compares the distribution of pairwise response correlation as a function of distance between the
data recorded in Majaj* and Hong* et al. 2015 and a candidate model.
"""
self._assembly = self._load_assembly()
self._metric = brainscore_vision.load_metric(
'spatial_correlation',
similarity_function=inv_ks_similarity,
bin_size_mm=SPATIAL_BIN_SIZE_MM,
num_bootstrap_samples=100_000,
num_sample_arrays=10)
ceiler = brainscore_vision.load_metric('inter_individual_helper', self._metric.compare_statistics)
target_statistic = LazyLoad(lambda: self._metric.compute_global_tissue_statistic_target(self._assembly))
super().__init__(identifier='dicarlo.MajajHong2015.IT-spatial_correlation',
ceiling_func=lambda: ceiler(target_statistic),
version=1,
parent='IT',
bibtex=BIBTEX)

def _load_assembly(self) -> NeuroidAssembly:
assembly = brainscore_vision.load_dataset('MajajHong2015').sel(region='IT')
assembly = self.squeeze_time(assembly)
assembly = self.tissue_update(assembly)
return assembly

def __call__(self, candidate: BrainModel) -> Score:
"""
This computes the statistics, i.e. the pairwise response correlation of candidate and target, respectively and
computes a ceiling-normalized score based on the ks similarity of the two resulting distributions.
:param candidate: BrainModel
:return: average inverted ks similarity for the pairwise response correlation compared to the MajajHong assembly
"""
candidate.start_recording(recording_target=BrainModel.RecordingTarget.IT,
time_bins=[(70, 170)],
# "we implanted each monkey with three arrays in the left cerebral hemisphere"
)
candidate_assembly = candidate.look_at(self._assembly.stimulus_set)
candidate_assembly = self.squeeze_time(candidate_assembly)

raw_score = self._metric(candidate_assembly, self._assembly)
score = raw_score / self.ceiling
score.attrs['raw'] = raw_score
score.attrs['ceiling'] = self.ceiling
return score

@staticmethod
def tissue_update(assembly):
"""
The current MajajHong2015 assembly has x and y coordinates of each array electrode stored as
coordinates `x` and `y` rather than the preferred `tissue_x` and `tissue_y`. Add these coordinates here.
"""
if not hasattr(assembly, 'tissue_x'):
assembly['tissue_x'] = assembly['x']
assembly['tissue_y'] = assembly['y']
# re-index
attrs = assembly.attrs
assembly = type(assembly)(assembly.values, coords={
coord: (dims, values) for coord, dims, values in walk_coords(assembly)}, dims=assembly.dims)
assembly.attrs = attrs
return assembly

@staticmethod
def squeeze_time(assembly):
if 'time_bin' in assembly.dims:
assembly = assembly.squeeze('time_bin')
if hasattr(assembly, "time_step"):
assembly = assembly.squeeze("time_step")
return assembly
13 changes: 13 additions & 0 deletions brainscore_vision/benchmarks/majajhong2015_spatial/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
from pytest import approx

import brainscore_vision
from brainscore_vision.benchmark_helpers import PrecomputedFeatures


def test_benchmark_runs():
benchmark = brainscore_vision.load_benchmark('MajajHong2015.IT-spatial_correlation')
source = benchmark._assembly.copy()
source = {benchmark._assembly.stimulus_set.identifier: source}
features = PrecomputedFeatures(source, visual_degrees=8)
benchmark(features)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from brainscore_vision import metric_registry
from .ceiling import InterIndividualStatisticsCeiling

metric_registry['inter_individual_helper'] = InterIndividualStatisticsCeiling
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from xarray import DataArray

from brainscore_vision.metric_helpers.transformations import apply_aggregate
from brainscore_vision import Ceiling
from brainscore_vision import Score


class InterIndividualStatisticsCeiling(Ceiling):
"""
Cross-validation-like, animal-wise computation of ceiling
"""

def __init__(self, metric):
"""
:param metric: used to compute the ceiling
"""
self._metric = metric

def __call__(self, statistic: DataArray) -> Score:
"""
Applies given metric to dataset, comparing data from one animal to all remaining animals, i.e.:
For each animal: metric({dataset\animal_i}, animal_i): cross validation like
:param statistic: xarray structure with values & and corresponding meta information: distances, source
:return: ceiling
"""
assert len(set(statistic.source.data)) > 1, 'your stats contain less than 2 subjects'
self.statistic = statistic

monkey_scores = []
for heldout_monkey in sorted(set(self.statistic.source.data)):
monkey_pool = self.statistic.where(self.statistic.source != heldout_monkey, drop=True)
heldout = self.statistic.sel(source=heldout_monkey)
score = self._metric(monkey_pool, heldout)

score = score.expand_dims('monkey')
score['monkey'] = [heldout_monkey]
monkey_scores.append(score)
# aggregate
scores = Score.merge(*monkey_scores)
return apply_aggregate(lambda s: s.mean('monkey'), scores)

# Note: this should probably be more general for arbitrary subjects instead of 'monkey'
6 changes: 6 additions & 0 deletions brainscore_vision/metrics/spatial_correlation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# developed in Schrimpf et al. 2024 https://www.biorxiv.org/content/10.1101/2024.01.09.572970

from brainscore_vision import metric_registry
from .metric import SpatialCorrelationSimilarity

metric_registry['spatial_correlation'] = SpatialCorrelationSimilarity
Loading
Loading