Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainscore_vision/benchmark_helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def visual_degrees(self) -> int:
def start_task(self, task, fitting_stimuli=None):
    """No-op: this helper ignores task configuration — presumably because its
    features are precomputed for the relevant task already (TODO confirm)."""
    pass

def start_recording(self, region, *args, **kwargs):
def start_recording(self, recording_target: BrainModel.RecordingTarget, *args, **kwargs):
    """No-op: recordings appear to be precomputed, so the requested recording
    target and any extra arguments are accepted but ignored here."""
    pass

def look_at(self, stimuli, number_of_trials=1):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# developed in Schrimpf et al. 2024 https://www.biorxiv.org/content/10.1101/2024.01.09.572970

from brainscore_vision import benchmark_registry

from .benchmark import DicarloMajajHong2015ITSpatialCorrelation

# Register the IT spatial-correlation benchmark under its public lookup key.
benchmark_registry['MajajHong2015.IT-spatial_correlation'] = DicarloMajajHong2015ITSpatialCorrelation
83 changes: 83 additions & 0 deletions brainscore_vision/benchmarks/majajhong2015_spatial/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from brainio.assemblies import NeuroidAssembly, walk_coords
from brainscore_core import Score

import brainscore_vision
from brainscore_vision.benchmarks import BenchmarkBase
from brainscore_vision.metrics.spatial_correlation.metric import inv_ks_similarity
from brainscore_vision.model_interface import BrainModel
from brainscore_vision.utils import LazyLoad
from ..majajhong2015.benchmark import BIBTEX

SPATIAL_BIN_SIZE_MM = .1 # .1 mm is an arbitrary choice


class DicarloMajajHong2015ITSpatialCorrelation(BenchmarkBase):
    """
    Compares the distribution of pairwise response correlations as a function of
    (tissue) distance between the IT recordings of Majaj* and Hong* et al. 2015
    and a candidate model's simulated tissue, scoring via inverted KS similarity.
    """

    def __init__(self):
        """
        This benchmark compares the distribution of pairwise response correlation as a function of distance between the
        data recorded in Majaj* and Hong* et al. 2015 and a candidate model.
        """
        # Target data: MajajHong2015 IT recordings, time dimension squeezed and
        # electrode positions exposed as tissue_{x,y} coordinates.
        self._assembly = self._load_assembly()
        # Metric comparing spatial-correlation statistics, binned by electrode
        # distance in steps of SPATIAL_BIN_SIZE_MM.
        self._metric = brainscore_vision.load_metric(
            'spatial_correlation',
            similarity_function=inv_ks_similarity,
            bin_size_mm=SPATIAL_BIN_SIZE_MM,
            num_bootstrap_samples=100_000,
            num_sample_arrays=10)
        # Ceiling: leave-one-animal-out comparison of the target statistic.
        ceiler = brainscore_vision.load_metric('inter_individual_helper', self._metric.compare_statistics)
        # LazyLoad defers the (expensive) target-statistic computation until the ceiling is requested.
        target_statistic = LazyLoad(lambda: self._metric.compute_global_tissue_statistic_target(self._assembly))
        # NOTE(review): identifier carries the 'dicarlo.' prefix while the registry key
        # is 'MajajHong2015.IT-spatial_correlation' — confirm the mismatch is intentional.
        super().__init__(identifier='dicarlo.MajajHong2015.IT-spatial_correlation',
                         ceiling_func=lambda: ceiler(target_statistic),
                         version=1,
                         parent='IT',
                         bibtex=BIBTEX)

    def _load_assembly(self) -> NeuroidAssembly:
        """Load the MajajHong2015 IT assembly and normalize it for this benchmark
        (drop the time dimension, add tissue_x/tissue_y electrode coordinates)."""
        assembly = brainscore_vision.load_dataset('MajajHong2015').sel(region='IT')
        assembly = self.squeeze_time(assembly)
        assembly = self.tissue_update(assembly)
        return assembly

    def __call__(self, candidate: BrainModel) -> Score:
        """
        This computes the statistics, i.e. the pairwise response correlation of candidate and target, respectively and
        computes a ceiling-normalized score based on the ks similarity of the two resulting distributions.
        :param candidate: BrainModel
        :return: average inverted ks similarity for the pairwise response correlation compared to the MajajHong assembly
        """
        # 70-170 ms matches the recording window used for the target data.
        candidate.start_recording(recording_target=BrainModel.RecordingTarget.IT,
                                  time_bins=[(70, 170)],
                                  # "we implanted each monkey with three arrays in the left cerebral hemisphere"
                                  )
        candidate_assembly = candidate.look_at(self._assembly.stimulus_set)
        candidate_assembly = self.squeeze_time(candidate_assembly)

        # Ceiling-normalize the raw similarity, keeping both raw and ceiling attached.
        raw_score = self._metric(candidate_assembly, self._assembly)
        score = raw_score / self.ceiling
        score.attrs['raw'] = raw_score
        score.attrs['ceiling'] = self.ceiling
        return score

    @staticmethod
    def tissue_update(assembly):
        """
        The current MajajHong2015 assembly has x and y coordinates of each array electrode stored as
        coordinates `x` and `y` rather than the preferred `tissue_x` and `tissue_y`. Add these coordinates here.
        """
        if not hasattr(assembly, 'tissue_x'):
            assembly['tissue_x'] = assembly['x']
            assembly['tissue_y'] = assembly['y']
        # re-index: rebuild the assembly from its walked coordinates so the newly
        # added tissue_x/tissue_y become proper indexed coordinates; attrs are
        # preserved manually since the constructor does not carry them over.
        attrs = assembly.attrs
        assembly = type(assembly)(assembly.values, coords={
            coord: (dims, values) for coord, dims, values in walk_coords(assembly)}, dims=assembly.dims)
        assembly.attrs = attrs
        return assembly

    @staticmethod
    def squeeze_time(assembly):
        """Drop a singleton time dimension, whichever of the two naming
        conventions (`time_bin` dim or `time_step`) the assembly uses."""
        if 'time_bin' in assembly.dims:
            assembly = assembly.squeeze('time_bin')
        if hasattr(assembly, "time_step"):
            assembly = assembly.squeeze("time_step")
        return assembly
13 changes: 13 additions & 0 deletions brainscore_vision/benchmarks/majajhong2015_spatial/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
from pytest import approx

import brainscore_vision
from brainscore_vision.benchmark_helpers import PrecomputedFeatures


def test_benchmark_runs():
    """Smoke test: the benchmark scores its own (precomputed) target data without error."""
    benchmark = brainscore_vision.load_benchmark('MajajHong2015.IT-spatial_correlation')
    recordings = benchmark._assembly.copy()
    precomputed = PrecomputedFeatures({benchmark._assembly.stimulus_set.identifier: recordings},
                                      visual_degrees=8)
    benchmark(precomputed)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from brainscore_vision import metric_registry
from .ceiling import InterIndividualStatisticsCeiling

# Ceiling helper: leave-one-subject-out comparison against the pooled remaining subjects.
metric_registry['inter_individual_helper'] = InterIndividualStatisticsCeiling
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from xarray import DataArray

from brainscore_vision.metric_helpers.transformations import apply_aggregate
from brainscore_vision import Ceiling
from brainscore_vision import Score


class InterIndividualStatisticsCeiling(Ceiling):
    """
    Cross-validation-like, animal-wise computation of ceiling.

    For each subject ("monkey") in the statistic's `source` coordinate, the metric
    compares that subject's statistic against the pooled statistic of all remaining
    subjects; the ceiling is the mean over held-out subjects.
    """

    def __init__(self, metric):
        """
        :param metric: used to compute the ceiling; invoked as
            `metric(pool_statistic, heldout_statistic)` and expected to return a Score
        """
        self._metric = metric

    def __call__(self, statistic: DataArray) -> Score:
        """
        Applies given metric to dataset, comparing data from one animal to all remaining animals, i.e.:
        For each animal: metric({dataset\\animal_i}, animal_i): cross validation like
        :param statistic: xarray structure with values & and corresponding meta information: distances, source
        :return: ceiling Score, averaged over held-out subjects
        :raises ValueError: if the statistic contains data from fewer than 2 subjects
        """
        # Need at least two subjects to hold one out against a pool.
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        subjects = sorted(set(statistic.source.data))
        if len(subjects) < 2:
            raise ValueError('your stats contain less than 2 subjects')
        self.statistic = statistic

        monkey_scores = []
        for heldout_monkey in subjects:
            # Pool: every entry not belonging to the held-out subject.
            monkey_pool = self.statistic.where(self.statistic.source != heldout_monkey, drop=True)
            heldout = self.statistic.sel(source=heldout_monkey)
            score = self._metric(monkey_pool, heldout)

            # Tag the per-subject score so merged scores keep subject identity.
            score = score.expand_dims('monkey')
            score['monkey'] = [heldout_monkey]
            monkey_scores.append(score)
        # aggregate: mean over held-out subjects, preserving per-subject raw values
        scores = Score.merge(*monkey_scores)
        return apply_aggregate(lambda s: s.mean('monkey'), scores)

# Note: this should probably be more general for arbitrary subjects instead of 'monkey'
6 changes: 6 additions & 0 deletions brainscore_vision/metrics/spatial_correlation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# developed in Schrimpf et al. 2024 https://www.biorxiv.org/content/10.1101/2024.01.09.572970

from brainscore_vision import metric_registry
from .metric import SpatialCorrelationSimilarity

# Similarity of spatial-correlation statistics between target and candidate tissue.
metric_registry['spatial_correlation'] = SpatialCorrelationSimilarity
Loading
Loading