6 changes: 6 additions & 0 deletions brainscore_vision/benchmarks/muzellec2026/__init__.py
@@ -0,0 +1,6 @@
from brainscore_vision import benchmark_registry
from .benchmark import MajajHongV4PublicBenchmark, MajajHongITPublicBenchmark

benchmark_registry['MajajHong2015public.V4-reverse_pls'] = MajajHongV4PublicBenchmark
benchmark_registry['MajajHong2015public.IT-reverse_pls'] = MajajHongITPublicBenchmark

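A minimal usage sketch (not part of the diff) of how the two registered identifiers are picked up through the plugin registry; the model name "alexnet" is only an illustrative placeholder:

from brainscore_vision import load_benchmark, load_model

benchmark = load_benchmark("MajajHong2015public.IT-reverse_pls")
model = load_model("alexnet")              # any registered BrainModel; "alexnet" is a placeholder
score = benchmark(model)                   # benchmarks are callables on a BrainModel candidate
print(benchmark.identifier, float(score))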
82 changes: 82 additions & 0 deletions brainscore_vision/benchmarks/muzellec2026/benchmark.py
@@ -0,0 +1,82 @@
from brainscore_core import Metric
from brainscore_vision import load_metric, load_dataset
from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition

import brainscore_vision.metrics.predictor_consistency
from brainscore_vision.metrics.predictor_consistency.ceiling import SplitHalfPredictorConsistency

VISUAL_DEGREES = 8
NUMBER_OF_TRIALS = 50
BIBTEX = """@article{muzellec2025reverse,
title={Reverse Predictivity: Going Beyond One-Way Mapping to Compare Artificial Neural Network Models and Brains},
author={Muzellec, Sabine and Kar, Kohitij},
journal={bioRxiv},
pages={2025--08},
year={2025},
publisher={Cold Spring Harbor Laboratory}
}"""

crossvalidation_kwargs = dict(stratification_coord="object_name")

MY_METRIC_ID = "reverse_pls"


def _MajajHong2015PublicRegion(
region: str,
identifier_metric_suffix: str,
similarity_metric: Metric,
):
assembly_repetition = load_assembly(average_repetitions=False, region=region, access="public")
assembly = load_assembly(average_repetitions=True, region=region, access="public")
benchmark_identifier = f"MajajHong2015public.{region}"

predictor_ceiler = SplitHalfPredictorConsistency()

return NeuralBenchmark(
identifier=f"{benchmark_identifier}-{identifier_metric_suffix}",
version=4,
assembly=assembly,
similarity_metric=similarity_metric,
visual_degrees=VISUAL_DEGREES,
number_of_trials=NUMBER_OF_TRIALS,
ceiling_func=lambda: predictor_ceiler(assembly_repetition),
parent=region,
bibtex=BIBTEX,
)


def MajajHongV4PublicBenchmark():
similarity_metric = load_metric(MY_METRIC_ID, crossvalidation_kwargs=crossvalidation_kwargs)
return _MajajHong2015PublicRegion(
region="V4",
identifier_metric_suffix=MY_METRIC_ID,
similarity_metric=similarity_metric,
)


def MajajHongITPublicBenchmark():
similarity_metric = load_metric(MY_METRIC_ID, crossvalidation_kwargs=crossvalidation_kwargs)
return _MajajHong2015PublicRegion(
region="IT",
identifier_metric_suffix=MY_METRIC_ID,
similarity_metric=similarity_metric,
)


def load_assembly(average_repetitions: bool, region: str, access: str = "public"):
assembly = load_dataset(f"MajajHong2015.{access}")

if "time_bin" in assembly.dims:
assembly = assembly.squeeze("time_bin")

assembly = assembly.sel(region=region)
assembly["region"] = ("neuroid", [region] * len(assembly["neuroid"]))

assembly.load()
assembly = assembly.transpose("presentation", "neuroid", ...)

if average_repetitions:
assembly = average_repetition(assembly)

return assembly
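
For reference, a minimal smoke sketch (assuming the public MajajHong2015 data can be loaded in the current environment) that constructs one benchmark directly and inspects its split-half ceiling, mirroring what the tests below do:

from brainscore_vision.benchmarks.muzellec2026.benchmark import MajajHongITPublicBenchmark

benchmark = MajajHongITPublicBenchmark()
ceiling = benchmark._ceiling_func()        # split-half predictor-consistency Score
print(benchmark.identifier, float(ceiling), float(ceiling.attrs["error"]))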
64 changes: 64 additions & 0 deletions brainscore_vision/benchmarks/muzellec2026/metadata.yaml
@@ -0,0 +1,64 @@
benchmarks:

MajajHong2015public.IT-reverse_pls:
stimulus_set:
num_stimuli: 2560
datatype: image
stimuli_subtype: null
total_size_mb: 2.1
brainscore_link: https://github.com/brain-score/vision/tree/master/brainscore_vision/data/majajhong2015
extra_notes: null

data:
benchmark_type: neural
task: null
region: IT
hemisphere: L
num_recording_sites: null
duration_ms: null
species: macaque
datatype: spike_rate
num_subjects: null
pre_processing: trial-averaged responses (per benchmark definition)
brainscore_link: https://github.com/brain-score/vision/tree/master/brainscore_vision/data/majajhong2015
extra_notes: Reverse benchmark implemented in brainscore_vision/benchmarks/muzellec2026
data_publicly_available: true

metric:
type: reverse_pls
reference: null
public: true
brainscore_link: https://github.com/brain-score/vision/tree/master/brainscore_vision/metrics/regression_correlation
extra_notes: Neural-to-model reverse predictivity using PLS regression


MajajHong2015public.V4-reverse_pls:
stimulus_set:
num_stimuli: 2560
datatype: image
stimuli_subtype: null
total_size_mb: 2.1
brainscore_link: https://github.com/brain-score/vision/tree/master/brainscore_vision/data/majajhong2015
extra_notes: null

data:
benchmark_type: neural
task: null
region: V4
hemisphere: L
num_recording_sites: null
duration_ms: null
species: macaque
datatype: spike_rate
num_subjects: null
pre_processing: trial-averaged responses (per benchmark definition)
brainscore_link: https://github.com/brain-score/vision/tree/master/brainscore_vision/data/majajhong2015
extra_notes: Reverse benchmark implemented in brainscore_vision/benchmarks/muzellec2026
data_publicly_available: true

metric:
type: reverse_pls
reference: null
public: true
brainscore_link: https://github.com/brain-score/vision/tree/master/brainscore_vision/metrics/regression_correlation
extra_notes: Neural-to-model reverse predictivity using PLS regression
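
A short sketch (illustrative only) of reading this metadata back, e.g. to sanity-check the two benchmark entries; the path assumes the repository root as working directory:

import yaml

with open("brainscore_vision/benchmarks/muzellec2026/metadata.yaml") as f:
    meta = yaml.safe_load(f)
print(sorted(meta["benchmarks"]))                                                # both reverse_pls identifiers
print(meta["benchmarks"]["MajajHong2015public.IT-reverse_pls"]["metric"]["type"])  # -> "reverse_pls"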
80 changes: 80 additions & 0 deletions brainscore_vision/benchmarks/muzellec2026/test.py
@@ -0,0 +1,80 @@
import pytest
from pytest import approx

from brainscore_vision import benchmark_registry, load_benchmark
from brainscore_vision.benchmark_helpers import PrecomputedFeatures
from brainscore_vision.benchmark_helpers.test_helper import NumberOfTrialsTests

from brainscore_vision.benchmarks.muzellec2026.benchmark import (
MajajHongV4PublicBenchmark,
MajajHongITPublicBenchmark,
)

num_trials_test = NumberOfTrialsTests()

BENCHMARKS = [
"MajajHong2015public.V4-reverse_pls",
"MajajHong2015public.IT-reverse_pls",
]


@pytest.mark.parametrize("benchmark_identifier", BENCHMARKS)
def test_registered(benchmark_identifier):
# Note: registry may not be populated until the benchmarks package is imported.
assert benchmark_identifier in benchmark_registry


@pytest.mark.parametrize("benchmark_ctor", [MajajHongV4PublicBenchmark, MajajHongITPublicBenchmark])
def test_constructs(benchmark_ctor):
b = benchmark_ctor()
assert b is not None
assert hasattr(b, "_assembly")
assert b._assembly is not None


@pytest.mark.parametrize("benchmark_identifier", BENCHMARKS)
def test_repetitions_metadata(benchmark_identifier):
# This checks that benchmark.number_of_trials etc are consistent with the assembly structure
# and will fail if the benchmark lost repetition information needed for trial averaging.
num_trials_test.repetitions_test(benchmark_identifier)


@pytest.mark.parametrize("benchmark_identifier", BENCHMARKS)
def test_ceiling_smoke_and_metadata(benchmark_identifier):
b = load_benchmark(benchmark_identifier)

c_direct = b._ceiling_func()
assert float(c_direct) == c_direct.values # scalar Score/DataArray
assert c_direct.attrs.get("ceiling") == "predictor_consistency_split_half"
assert "error" in c_direct.attrs
assert float(c_direct.attrs["error"]) >= 0


@pytest.mark.parametrize("benchmark_identifier", BENCHMARKS)
def test_ceiling_cached_matches_direct(benchmark_identifier):
"""
Requires: you bumped version in benchmark.py after changing ceiling_func.
Otherwise cached value can reflect an older ceiling.
"""
b = load_benchmark(benchmark_identifier)
c_direct = b._ceiling_func()
c_cached = b.ceiling

assert float(c_cached) == approx(float(c_direct), abs=3e-3)


@pytest.mark.parametrize("benchmark_identifier", BENCHMARKS)
def test_self_scoring_runs(benchmark_identifier):
"""
Self-score: feed the benchmark its own assembly as "model features".
Should be finite and should attach raw/ceiling attrs.
"""
b = load_benchmark(benchmark_identifier)
src = b._assembly.copy()
src = {b._assembly.stimulus_set.identifier: src}
score = b(PrecomputedFeatures(src, visual_degrees=8))

assert score is not None
assert float(score) == float(score.values)
assert score.attrs.get("ceiling") is not None
assert score.attrs.get("raw") is not None
4 changes: 4 additions & 0 deletions brainscore_vision/metrics/predictor_consistency/__init__.py
@@ -0,0 +1,4 @@
from brainscore_vision import metric_registry
from .ceiling import SplitHalfPredictorConsistency

metric_registry["predictor_consistency"] = lambda *a, **k: SplitHalfPredictorConsistency()
106 changes: 106 additions & 0 deletions brainscore_vision/metrics/predictor_consistency/ceiling.py
@@ -0,0 +1,106 @@
import numpy as np
import xarray as xr
from brainscore_core.metrics import Score
from brainscore_vision.metrics import Ceiling

def _get_presentation_level(X, name: str):
"""
Return a 1D array aligned to 'presentation' for `name`, whether it is:
- a real coordinate on 'presentation', or
- a level of the presentation MultiIndex.
"""
if name in X.coords and "presentation" in X[name].dims:
return np.asarray(X[name].values)

if "presentation" in X.dims and "presentation" in X.indexes:
idx = X.indexes["presentation"]
if hasattr(idx, "names") and name in idx.names:
return np.asarray(idx.get_level_values(name))

raise ValueError(
f"Could not find '{name}' as coord-on-presentation or MultiIndex level. "
f"dims={X.dims} coords={list(X.coords)} "
f"presentation_index_names={getattr(X.indexes.get('presentation', None), 'names', None)}"
)

def _pearson_per_neuroid(a, b):
a = a - a.mean(axis=0, keepdims=True)
b = b - b.mean(axis=0, keepdims=True)
denom = np.linalg.norm(a, axis=0) * np.linalg.norm(b, axis=0)
return (a * b).sum(axis=0) / np.where(denom == 0, np.nan, denom)

def _spearman_brown(r):
return (2 * r) / (1 + r)

class SplitHalfPredictorConsistency(Ceiling):
def __init__(self, n_splits=10, seed=0, image_level="image_id", rep_level="repetition"):
self.n_splits = n_splits
self.seed = seed
self.image_level = image_level
self.rep_level = rep_level

def __call__(self, X) -> Score:
if "time_bin" in X.dims and X.sizes.get("time_bin", 1) == 1:
X = X.squeeze("time_bin")
X = X.transpose("presentation", "neuroid")

reps = _get_presentation_level(X, self.rep_level).astype(int)
imgs = _get_presentation_level(X, self.image_level)

unique_imgs, inv = np.unique(imgs, return_inverse=True)
rng = np.random.default_rng(self.seed)

split_rs = []
for _ in range(self.n_splits):
A_means, B_means = [], []
kept = 0

for img_idx in range(len(unique_imgs)):
idx = np.where(inv == img_idx)[0]
if idx.size < 2:
continue

rvals = reps[idx]
uniq_r = np.unique(rvals)
if uniq_r.size < 2:
continue

perm = rng.permutation(uniq_r)
half = uniq_r.size // 2
A_r, B_r = perm[:half], perm[half:]
if A_r.size == 0 or B_r.size == 0:
continue

A_idx = idx[np.isin(rvals, A_r)]
B_idx = idx[np.isin(rvals, B_r)]
if A_idx.size == 0 or B_idx.size == 0:
continue

A_means.append(X.isel(presentation=A_idx).mean("presentation").values)
B_means.append(X.isel(presentation=B_idx).mean("presentation").values)
kept += 1

if kept < 10:
split_rs.append(np.full((X.sizes["neuroid"],), np.nan))
continue

A = np.stack(A_means, axis=0)
B = np.stack(B_means, axis=0)

r = _pearson_per_neuroid(A, B)
split_rs.append(_spearman_brown(r))

split_rs = np.stack(split_rs, axis=0) # (split, neuroid)
per_split = np.nanmedian(split_rs, axis=1)
value = np.nanmean(per_split)
err = np.nanstd(per_split) / np.sqrt(np.sum(np.isfinite(per_split)))

score = Score(value)
score.attrs["raw"] = xr.DataArray(
split_rs,
dims=("split", "neuroid"),
coords={"split": np.arange(self.n_splits), "neuroid": X["neuroid"].values},
)
score.attrs["error"] = Score(err)
score.attrs["ceiling"] = "predictor_consistency_split_half"
return score
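
To make the ceiling's behavior concrete, here is a minimal synthetic sketch (not part of the PR) with 12 images x 4 repetitions x 5 neuroids; all names and numbers are illustrative:

import numpy as np
import xarray as xr
from brainscore_vision.metrics.predictor_consistency.ceiling import SplitHalfPredictorConsistency

rng = np.random.default_rng(1)
n_images, n_reps, n_neuroids = 12, 4, 5
signal = rng.normal(size=(n_images, 1, n_neuroids))            # image-driven component
noise = 0.5 * rng.normal(size=(n_images, n_reps, n_neuroids))  # per-trial noise
responses = (signal + noise).reshape(n_images * n_reps, n_neuroids)

assembly = xr.DataArray(
    responses,
    dims=("presentation", "neuroid"),
    coords={
        "image_id": ("presentation", np.repeat([f"img{i}" for i in range(n_images)], n_reps)),
        "repetition": ("presentation", np.tile(np.arange(n_reps), n_images)),
        "neuroid": np.arange(n_neuroids),
    },
)

ceiler = SplitHalfPredictorConsistency(n_splits=5, seed=0)
ceiling = ceiler(assembly)
print(float(ceiling), float(ceiling.attrs["error"]))
# ceiling.attrs["raw"] holds per-split, per-neuroid Spearman-Brown-corrected correlations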
20 changes: 19 additions & 1 deletion brainscore_vision/metrics/regression_correlation/__init__.py
@@ -1,6 +1,7 @@
from brainscore_vision import metric_registry
from .metric import CrossRegressedCorrelation, pls_regression, ridge_cv_regression, ridge_regression, single_regression, linear_regression,\
pearsonr_correlation
pearsonr_correlation, ReverseCrossRegressedCorrelation, ReverseTrainTestSplitCorrelation


#metrics using cross-validation to generate multiple train-test splits from a monolithic dataset

@@ -26,13 +27,21 @@
metric_registry['ridgecv_split'] = lambda *args, **kwargs: TrainTestSplitCorrelation(
regression=ridge_cv_regression(**kwargs), correlation=pearsonr_correlation(), *args, **kwargs)

metric_registry["reverse_pls_cv"] = lambda *args, **kwargs: ReverseCrossRegressedCorrelation(
regression=pls_regression(), correlation=pearsonr_correlation(), *args, **kwargs)

metric_registry['reverse_pls_split'] = lambda *args, **kwargs: ReverseTrainTestSplitCorrelation(
regression=pls_regression(), correlation=pearsonr_correlation(), *args, **kwargs)


#backwards compatibility
metric_registry['pls'] = metric_registry['pls_cv']
metric_registry['ridge'] = metric_registry['ridge_cv']
metric_registry['neuron_to_neuron'] = metric_registry['neuron_to_neuron_cv']
metric_registry['linear_predictivity'] = metric_registry['linear_predictivity_cv']

metric_registry['reverse_pls'] = metric_registry['reverse_pls_cv']


# temporal metrics
from .metric import SpanTimeCrossRegressedCorrelation
@@ -69,3 +78,12 @@
year={2018},
institution={Center for Brains, Minds and Machines (CBMM)}
}"""

BIBTEX_REVERSE_PLS = """@article{muzellec2025reverse,
title={Reverse Predictivity: Going Beyond One-Way Mapping to Compare Artificial Neural Network Models and Brains},
author={Muzellec, Sabine and Kar, Kohitij},
journal={bioRxiv},
pages={2025--08},
year={2025},
publisher={Cold Spring Harbor Laboratory}
}"""