diff --git a/tests/distillation/feature_extraction_test.py b/tests/distillation/feature_extraction_test.py new file mode 100644 index 00000000..b7fe875b --- /dev/null +++ b/tests/distillation/feature_extraction_test.py @@ -0,0 +1,276 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for distillation feature extraction helpers. + +Covers: +- Avg-pooling array utilities (VALID vs SAME and pad-count behavior) +- Sowed module wrap/pop/unwrap behavior +- Feature projection setup/removal integration +""" + +from __future__ import annotations +from absl.testing import absltest +from absl.testing import parameterized +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np + +from tunix.distillation import feature_extraction +from tunix.distillation.feature_extraction import pooling +from tunix.distillation.feature_extraction import sowed_module + +class _FeatureLayer(nnx.Module): + """A tiny feature block used in toy models for unit testing.""" + + def __init__(self, in_dim: int, feat_dim: int, *, rngs: nnx.Rngs): + self._proj = nnx.Linear(in_dim, feat_dim, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return jax.nn.relu(self._proj(x)) + + +class _ToyClassifier(nnx.Module): + """Toy classifier: FeatureLayer -> Linear head.""" + def __init__( + self, + in_dim: int, + feat_dim: int, + num_classes: int, + *, + rngs: nnx.Rngs, + ): + self.feature = _FeatureLayer(in_dim, feat_dim, rngs=rngs) + self.head = nnx.Linear(feat_dim, num_classes, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.head(self.feature(x)) + + +class _ToyDeepClassifier(nnx.Module): + """Toy classifier with two feature layers: FeatureLayer -> FeatureLayer -> head. + Used to test multi-leaf sowing behavior. + """ + def __init__( + self, + in_dim: int, + feat_dim1: int, + feat_dim2: int, + num_classes: int, + *, + rngs: nnx.Rngs, + ): + self.feature1 = _FeatureLayer(in_dim, feat_dim1, rngs=rngs) + self.feature2 = _FeatureLayer(feat_dim1, feat_dim2, rngs=rngs) + self.head = nnx.Linear(feat_dim2, num_classes, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.feature1(x) + x = self.feature2(x) + return self.head(x) + + +def _pop_leaves(model: nnx.Module) -> list[jax.Array]: + """Pop sowed outputs and return leaves as a Python list.""" + state = sowed_module.pop_sowed_intermediate_outputs(model) + return list(jax.tree.leaves(state)) if state else [] + + +class FeatureExtractionTest(parameterized.TestCase): + """Unit tests for feature extraction and sowed module helpers.""" + def test_avg_pool_valid_1d_exact(self): + # Shape (1, B, D) so pooling touches only the last axis. + x = jnp.arange(1, 1 + 1 * 2 * 6, dtype=jnp.float32).reshape(1, 2, 6) + y = feature_extraction.avg_pool_array_to_target_shape( + x, + target_shape=(1, 2, 2), + padding_mode=pooling.PaddingMode.VALID, + ) + self.assertEqual(y.shape, (1, 2, 2)) + + # For D=6 -> target=2: stride=3, window=3 => chunk means. 
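+ # With D=6 pooled down to target=2, each output element averages a disjoint chunk of 3 inputs, so the expected values are simply the per-chunk means computed below.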
+ x_np = np.asarray(x) + expected = np.stack( + [ + x_np[:, :, 0:3].mean(axis=-1), + x_np[:, :, 3:6].mean(axis=-1), + ], + axis=-1, + ) + np.testing.assert_allclose(np.asarray(y), expected, atol=1e-6) + + def test_avg_pool_same_include_pad_changes_result(self): + x = jnp.array([1, 2, 3, 4, 5], dtype=jnp.float32).reshape(1, 1, 5) + + y_exclude = feature_extraction.avg_pool_array_to_target_shape( + x, + target_shape=(1, 1, 2), + padding_mode=pooling.PaddingMode.SAME, + count_include_pad_for_same_padding=False, + ) + y_include = feature_extraction.avg_pool_array_to_target_shape( + x, + target_shape=(1, 1, 2), + padding_mode=pooling.PaddingMode.SAME, + count_include_pad_for_same_padding=True, + ) + + self.assertEqual(y_exclude.shape, (1, 1, 2)) + self.assertEqual(y_include.shape, (1, 1, 2)) + + # window1: [1,2,3] => 2 + # window2: [4,5,0] + # exclude pad => (4+5)/2 = 4.5 + # include pad => (4+5+0)/3 = 3 + np.testing.assert_allclose(np.asarray(y_exclude)[0, 0, 0], 2.0, atol=1e-6) + np.testing.assert_allclose(np.asarray(y_exclude)[0, 0, 1], 4.5, atol=1e-6) + np.testing.assert_allclose(np.asarray(y_include)[0, 0, 0], 2.0, atol=1e-6) + np.testing.assert_allclose(np.asarray(y_include)[0, 0, 1], 3.0, atol=1e-6) + + def test_avg_pool_rank_mismatch_raises(self): + x = jnp.ones((2, 3), dtype=jnp.float32) + with self.assertRaises(ValueError): + feature_extraction.avg_pool_array_to_target_shape(x, target_shape=(2, 3, 1)) + + def test_avg_pool_invalid_target_dim_raises(self): + x = jnp.ones((2, 3), dtype=jnp.float32) + with self.assertRaises(ValueError): + feature_extraction.avg_pool_array_to_target_shape(x, target_shape=(2, 4)) + + def test_wrap_pop_unwrap_sowed_modules(self): + rngs = nnx.Rngs(0) + model = _ToyClassifier(in_dim=4, feat_dim=8, num_classes=3, rngs=rngs) + x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4) + + original_feature_module = model.feature + sowed_module.wrap_model_with_sowed_modules(model, [_FeatureLayer]) + + _ = model(x) + + captured = sowed_module.pop_sowed_intermediate_outputs(model) + leaves = jax.tree.leaves(captured) + self.assertLen(leaves, 1) + + captured_feat = leaves[0] + expected_feat = original_feature_module(x) + np.testing.assert_allclose( + np.asarray(captured_feat), + np.asarray(expected_feat), + atol=1e-6, + ) + + sowed_module.unwrap_sowed_modules(model) + self.assertIs(model.feature, original_feature_module) + + def test_wrap_pop_unwrap_sowed_modules_multiple_leaves(self): + rngs = nnx.Rngs(0) + model = _ToyDeepClassifier( + in_dim=4, feat_dim1=6, feat_dim2=8, num_classes=3, rngs=rngs + ) + x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4) + + # Baseline forward. + baseline_logits = model(x) + + # Wrap both _FeatureLayer instances. 
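+ # Wrapping matches on module type, so feature1 and feature2 are each replaced by a capturing wrapper; the wrapped model's logits should be identical to the baseline.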
+ sowed_module.wrap_model_with_sowed_modules(model, [_FeatureLayer]) + + logits = model(x) + np.testing.assert_allclose( + np.asarray(logits), np.asarray(baseline_logits), atol=1e-6 + ) + + leaves = _pop_leaves(model) + self.assertLen(leaves, 2) + + expected1 = model.feature1(x) + expected2 = model.feature2(expected1) + + leaves_np = [np.asarray(a) for a in leaves] + exp_np = [np.asarray(expected1), np.asarray(expected2)] + + def _matches_any(arr, candidates): + return any( + (arr.shape == c.shape) and np.allclose(arr, c, atol=1e-6) + for c in candidates + ) + + self.assertTrue(_matches_any(exp_np[0], leaves_np), "feature1 not captured") + self.assertTrue(_matches_any(exp_np[1], leaves_np), "feature2 not captured") + + sowed_module.unwrap_sowed_modules(model) + for _, m in model.iter_modules(): + self.assertNotIsInstance(m, sowed_module.SowedModule) + + def test_setup_and_remove_feature_projection(self): + batch_size = 2 + in_dim = 4 + num_classes = 3 + + student = _ToyClassifier( + in_dim=in_dim, + feat_dim=4, + num_classes=num_classes, + rngs=nnx.Rngs(0), + ) + teacher = _ToyClassifier( + in_dim=in_dim, + feat_dim=8, + num_classes=num_classes, + rngs=nnx.Rngs(1), + ) + + dummy_x = jnp.ones((batch_size, in_dim), dtype=jnp.float32) + + student_wrapped, teacher_wrapped = ( + feature_extraction.setup_models_with_feature_projection( + student_model=student, + teacher_model=teacher, + student_layer_to_capture=_FeatureLayer, + teacher_layer_to_capture=_FeatureLayer, + dummy_student_input={"x": dummy_x}, + dummy_teacher_input={"x": dummy_x}, + rngs=nnx.Rngs(42), + ) + ) + + self.assertIsInstance( + student_wrapped, feature_extraction.ModelWithFeatureProjection + ) + + logits, projected = student_wrapped(dummy_x) + self.assertEqual(logits.shape, (batch_size, num_classes)) + + _ = teacher_wrapped(dummy_x) + teacher_state = sowed_module.pop_sowed_intermediate_outputs(teacher_wrapped) + teacher_feats = jnp.stack(jax.tree.leaves(teacher_state)) + self.assertEqual(projected.shape, teacher_feats.shape) + + student_orig, teacher_orig = ( + feature_extraction.remove_feature_projection_from_models( + student_wrapped, teacher_wrapped + ) + ) + self.assertIsInstance(student_orig, _ToyClassifier) + self.assertIsInstance(teacher_orig, _ToyClassifier) + + for _, m in student_orig.iter_modules(): + self.assertNotIsInstance(m, sowed_module.SowedModule) + for _, m in teacher_orig.iter_modules(): + self.assertNotIsInstance(m, sowed_module.SowedModule) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/tests/distillation/strategies_test.py b/tests/distillation/strategies_test.py new file mode 100644 index 00000000..eed02e6a --- /dev/null +++ b/tests/distillation/strategies_test.py @@ -0,0 +1,301 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for distillation strategies. 
+ +Smoke tests for: +- LogitStrategy +- FeaturePoolingStrategy +- FeatureProjectionStrategy +- ContrastiveRepresentationDistillationStrategy (CRD) +""" + +from __future__ import annotations +from typing import Any + +from absl.testing import absltest +from absl.testing import parameterized +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +import optax + +from tunix.distillation.feature_extraction import sowed_module +from tunix.distillation.strategies.crd_strategy import ( + ContrastiveRepresentationDistillationStrategy, +) +from tunix.distillation.strategies.feature_pooling import FeaturePoolingStrategy +from tunix.distillation.strategies.feature_projection import FeatureProjectionStrategy +from tunix.distillation.strategies.logit import LogitStrategy + +# Tiny toy models for testing. +class _FeatureLayer(nnx.Module): + def __init__(self, in_dim: int, feat_dim: int, *, rngs: nnx.Rngs): + self.proj = nnx.Linear(in_dim, feat_dim, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return jax.nn.tanh(self.proj(x)) + + +class _ToyClassifier(nnx.Module): + def __init__(self, in_dim: int, feat_dim: int, num_classes: int, *, rngs: nnx.Rngs): + self.feature = _FeatureLayer(in_dim, feat_dim, rngs=rngs) + self.head = nnx.Linear(feat_dim, num_classes, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + h = self.feature(x) + return self.head(h) + + +def _forward_logits(model: nnx.Module, x: jax.Array, labels: jax.Array) -> jax.Array: + del labels + return model(x) + + +def _forward_logits_and_features(model: nnx.Module, x: jax.Array, labels: jax.Array): + del labels + # For FeatureProjectionStrategy, processed student model returns: + # (logits, projected_features) + return model(x) + + +def _forward_logits_and_embedding(model: nnx.Module, x: jax.Array, labels: jax.Array): + del labels + # For CRD, processed student model returns: + # (logits, z_student) + return model(x) + + +def _labels_passthrough(x: jax.Array, labels: jax.Array) -> jax.Array: + del x + return labels + + +def _has_sowed_modules(model: nnx.Module) -> bool: + return any(isinstance(m, sowed_module.SowedModule) for _, m in model.iter_modules()) + + +def _assert_no_sowed_modules(testcase: absltest.TestCase, model: nnx.Module): + for _, m in model.iter_modules(): + testcase.assertNotIsInstance(m, sowed_module.SowedModule) + + +def _assert_teacher_output_has_batch( + testcase: absltest.TestCase, + teacher_out: Any, + batch_size: int, +): + """Accept either stacked-array teacher outputs or PyTree sowed state.""" + # Case 1: Array-like output with .shape (e.g., (N, B, ...)). + if hasattr(teacher_out, "shape"): + shape = teacher_out.shape + testcase.assertGreaterEqual(len(shape), 2) + testcase.assertEqual(shape[1], batch_size) + return + + # Case 2: PyTree output (e.g., sowed state). Leaves should be (B, ...). + leaves = list(jax.tree.leaves(teacher_out)) + testcase.assertNotEmpty(leaves) + for leaf in leaves: + testcase.assertTrue(hasattr(leaf, "shape")) + testcase.assertGreaterEqual(len(leaf.shape), 1) + testcase.assertEqual(leaf.shape[0], batch_size) + +class StrategiesTest(parameterized.TestCase): + def test_logit_strategy_matches_manual_kl(self): + # Sanity check: compare to a manual computation of the same formula. 
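+ # expected = alpha * mean(T^2 * KL(teacher_T || student_T)) + (1 - alpha) * mean(CE(student_logits, labels)), where the *_T distributions use logits / temperature.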
+ temperature = 2.0 + alpha = 0.7 + strat = LogitStrategy( + _forward_logits, + _forward_logits, + _labels_passthrough, + temperature=temperature, + alpha=alpha, + ) + + student_logits = jnp.array([[1.0, 0.0, -1.0]], dtype=jnp.float32) + teacher_logits = jnp.array([[0.5, 0.25, -0.75]], dtype=jnp.float32) + labels = jax.nn.one_hot(jnp.array([0]), 3) + + loss = strat.compute_loss(student_logits, teacher_logits, labels) + self.assertEqual(loss.shape, ()) + + log_student_probs_temp = jax.nn.log_softmax( + student_logits / temperature, axis=-1 + ) + teacher_probs_temp = jax.nn.softmax(teacher_logits / temperature, axis=-1) + kl = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp) * ( + temperature**2 + ) + distill_loss = jnp.mean(kl) + task_loss = jnp.mean(optax.softmax_cross_entropy(student_logits, labels)) + expected = alpha * distill_loss + (1.0 - alpha) * task_loss + + np.testing.assert_allclose(np.asarray(loss), np.asarray(expected), atol=1e-6) + + def test_feature_pooling_strategy_smoke(self): + batch_size = 2 + in_dim = 4 + num_classes = 3 + + student = _ToyClassifier( + in_dim, feat_dim=4, num_classes=num_classes, rngs=nnx.Rngs(0) + ) + teacher = _ToyClassifier( + in_dim, feat_dim=8, num_classes=num_classes, rngs=nnx.Rngs(1) + ) + + strat = FeaturePoolingStrategy( + student_forward_fn=_forward_logits, + teacher_forward_fn=_forward_logits, + labels_fn=_labels_passthrough, + feature_layer=_FeatureLayer, + alpha=0.5, + ) + + student_p, teacher_p = strat.pre_process_models(student, teacher) + self.assertTrue(_has_sowed_modules(student_p)) + self.assertTrue(_has_sowed_modules(teacher_p)) + + x = ( + jnp.arange(batch_size * in_dim, dtype=jnp.float32).reshape( + batch_size, in_dim + ) + / 10.0 + ) + labels = jax.nn.one_hot(jnp.array([0, 2]), num_classes) + + teacher_out = strat.get_teacher_outputs(teacher_p, {"x": x, "labels": labels}) + _assert_teacher_output_has_batch(self, teacher_out, batch_size) + + loss = strat.get_train_loss(student_p, teacher_out, {"x": x, "labels": labels}) + self.assertEqual(loss.shape, ()) + + eval_loss = strat.get_eval_loss(student_p, {"x": x, "labels": labels}) + self.assertEqual(eval_loss.shape, ()) + + student_o, teacher_o = strat.post_process_models(student_p, teacher_p) + _assert_no_sowed_modules(self, student_o) + _assert_no_sowed_modules(self, teacher_o) + + def test_feature_projection_strategy_smoke(self): + # Keep training batch == dummy batch in this test. 
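+ # The feature projection is presumably sized from the dummy forward pass (see dummy_input below), so the training batch reuses the same shape to avoid any mismatch.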
+ batch_size = 2 + in_dim = 4 + num_classes = 3 + + student = _ToyClassifier( + in_dim, feat_dim=4, num_classes=num_classes, rngs=nnx.Rngs(0) + ) + teacher = _ToyClassifier( + in_dim, feat_dim=8, num_classes=num_classes, rngs=nnx.Rngs(1) + ) + + dummy_x = jnp.ones((batch_size, in_dim), dtype=jnp.float32) + + strat = FeatureProjectionStrategy( + student_forward_fn=_forward_logits_and_features, + teacher_forward_fn=_forward_logits, + labels_fn=_labels_passthrough, + feature_layer=_FeatureLayer, + dummy_input={"x": dummy_x}, + rngs=nnx.Rngs(42), + alpha=0.5, + ) + + student_p, teacher_p = strat.pre_process_models(student, teacher) + self.assertTrue(_has_sowed_modules(student_p)) + self.assertTrue(_has_sowed_modules(teacher_p)) + + x = ( + jnp.arange(batch_size * in_dim, dtype=jnp.float32).reshape( + batch_size, in_dim + ) + + 1.0 + ) / 10.0 + labels = jax.nn.one_hot(jnp.array([1, 0]), num_classes) + + teacher_out = strat.get_teacher_outputs(teacher_p, {"x": x, "labels": labels}) + _assert_teacher_output_has_batch(self, teacher_out, batch_size) + + loss = strat.get_train_loss(student_p, teacher_out, {"x": x, "labels": labels}) + self.assertEqual(loss.shape, ()) + + student_o, teacher_o = strat.post_process_models(student_p, teacher_p) + self.assertIsInstance(student_o, _ToyClassifier) + self.assertIsInstance(teacher_o, _ToyClassifier) + _assert_no_sowed_modules(self, student_o) + _assert_no_sowed_modules(self, teacher_o) + + def test_contrastive_representation_distillation_strategy_smoke(self): + batch_size = 2 + in_dim = 4 + num_classes = 3 + + student = _ToyClassifier( + in_dim, feat_dim=4, num_classes=num_classes, rngs=nnx.Rngs(0) + ) + teacher = _ToyClassifier( + in_dim, feat_dim=8, num_classes=num_classes, rngs=nnx.Rngs(1) + ) + + dummy_x = jnp.ones((batch_size, in_dim), dtype=jnp.float32) + + strat = ContrastiveRepresentationDistillationStrategy( + student_forward_fn=_forward_logits_and_embedding, + teacher_forward_fn=_forward_logits, + labels_fn=_labels_passthrough, + student_layer_to_capture=_FeatureLayer, + teacher_layer_to_capture=_FeatureLayer, + dummy_student_input={"x": dummy_x}, + dummy_teacher_input={"x": dummy_x}, + rngs=nnx.Rngs(42), + embedding_dim=16, + mlp_hidden_dim=32, + temperature=0.2, + alpha=0.5, + symmetric=False, + ) + + student_p, teacher_p = strat.pre_process_models(student, teacher) + self.assertTrue(_has_sowed_modules(student_p)) + self.assertTrue(_has_sowed_modules(teacher_p)) + + x = ( + jnp.arange(batch_size * in_dim, dtype=jnp.float32).reshape( + batch_size, in_dim + ) + + 1.0 + ) / 10.0 + labels = jax.nn.one_hot(jnp.array([1, 0]), num_classes) + + teacher_out = strat.get_teacher_outputs(teacher_p, {"x": x, "labels": labels}) + _assert_teacher_output_has_batch(self, teacher_out, batch_size) + + loss = strat.get_train_loss(student_p, teacher_out, {"x": x, "labels": labels}) + self.assertEqual(loss.shape, ()) + + eval_loss = strat.get_eval_loss(student_p, {"x": x, "labels": labels}) + self.assertEqual(eval_loss.shape, ()) + + student_o, teacher_o = strat.post_process_models(student_p, teacher_p) + self.assertIsInstance(student_o, _ToyClassifier) + self.assertIsInstance(teacher_o, _ToyClassifier) + _assert_no_sowed_modules(self, student_o) + _assert_no_sowed_modules(self, teacher_o) + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/tunix/distillation/__init__.py b/tunix/distillation/__init__.py index dfe76e59..b94482eb 100644 --- a/tunix/distillation/__init__.py +++ b/tunix/distillation/__init__.py @@ -18,3 +18,6 @@ from 
tunix.distillation.distillation_trainer import DistillationTrainer from tunix.distillation.distillation_trainer import TrainingConfig from tunix.distillation.distillation_trainer import TrainingInput +from tunix.distillation.strategies.crd_strategy import ( + ContrastiveRepresentationDistillationStrategy, +) \ No newline at end of file diff --git a/tunix/distillation/strategies/crd_strategy.py b/tunix/distillation/strategies/crd_strategy.py new file mode 100644 index 00000000..a7143932 --- /dev/null +++ b/tunix/distillation/strategies/crd_strategy.py @@ -0,0 +1,451 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contrastive Representation Distillation (CRD) strategy. + +Implements an InfoNCE-style contrastive loss between student and teacher +representations captured from intermediate layers via `sowed_module`. +""" + +from __future__ import annotations +from typing import Any, Callable + +from flax import nnx +import jax +import jax.numpy as jnp +import optax +from typing_extensions import override + +from tunix.distillation.feature_extraction import sowed_module +from tunix.distillation.strategies import base_strategy +ModelForwardCallable = base_strategy.ModelForwardCallable + +def _l2_normalize( + x: jax.Array, + axis: int = -1, + eps: float = 1e-6, +) -> jax.Array: + """L2-normalize along `axis` with epsilon for numerical stability.""" + norm = jnp.linalg.norm(x, axis=axis, keepdims=True) + return x / (norm + eps) + + +def _pool_feature_to_representation( + feat: jax.Array, + *, + input_mask: jax.Array | None = None, + mask_axis: int = 1, + eps: float = 1e-6, +) -> jax.Array: + """Convert a single sowed feature into (B, C). + + Supported: + (B, C) -> identity + (B, T, C) with mask (B, T) -> masked mean + (B, ..., C) -> mean over all non-batch, non-channel dims + """ + if feat.ndim < 2: + raise ValueError( + f"Feature must have at least 2 dims (B, ...). Got {feat.shape}" + ) + + if feat.ndim == 2: + return feat + + if input_mask is not None and feat.ndim == 3: + if input_mask.ndim == 2 and input_mask.shape[:2] == feat.shape[:2]: + mask = input_mask.astype(feat.dtype)[..., None] # (B, T, 1) + summed = jnp.sum(feat * mask, axis=mask_axis) # (B, C) + denom = jnp.sum(mask, axis=mask_axis) # (B, 1) + return summed / (denom + eps) + + reduce_axes = tuple(range(1, feat.ndim - 1)) + return jnp.mean(feat, axis=reduce_axes) + + +def _sowed_state_to_pooled_stack( + sowed_state: Any, + *, + input_mask: jax.Array | None, + eps: float, +) -> jax.Array: + leaves = list(jax.tree.leaves(sowed_state)) + if not leaves: + raise ValueError("No sowed intermediates found.") + + pooled = [ + _pool_feature_to_representation(x, input_mask=input_mask, eps=eps) + for x in leaves + ] # list of (B, C) + + # Validate consistent channel dimension to allow stacking. + c0 = int(pooled[0].shape[-1]) + for i, p in enumerate(pooled): + if p.ndim != 2: + raise ValueError( + f"Expected pooled rep (B, C) but got {p.shape} at leaf {i}." 
+ ) + if int(p.shape[-1]) != c0: + raise ValueError( + "CRD currently requires all captured leaves to have the same channel " + "dimension after pooling so they can be stacked. " + f"Got C0={c0} and leaf[{i}].C={int(p.shape[-1])}. " + "Suggestion: capture a single layer type, or ensure consistent dims." + ) + + return jnp.stack(pooled, axis=0) # (N, B, C) + +def _stacked_pooled_to_representation(stacked: jax.Array) -> jax.Array: + """Convert pooled stack (N, B, C) into (B, C) by averaging N.""" + if stacked.ndim != 3: + raise ValueError( + f"Expected stacked pooled features (N, B, C). Got {stacked.shape}" + ) + if stacked.shape[0] == 1: + return stacked[0] + return jnp.mean(stacked, axis=0) + +class _ProjectionHead(nnx.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + *, + rngs: nnx.Rngs, + hidden_dim: int | None = None, + ): + self.hidden_dim = hidden_dim + if hidden_dim is None: + self.fc = nnx.Linear(in_dim, out_dim, rngs=rngs) + else: + self.fc1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs) + self.fc2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + if self.hidden_dim is None: + return self.fc(x) + x = jax.nn.gelu(self.fc1(x)) + return self.fc2(x) + + +class StudentModelWithCRDHeads(nnx.Module): + """Wrap a student model and add CRD projection heads. + """ + + def __init__( + self, + model: nnx.Module, + *, + student_rep_dim: int, + teacher_rep_dim: int, + embedding_dim: int, + rngs: nnx.Rngs, + mlp_hidden_dim: int | None = 512, + eps: float = 1e-6, + mask_key: str = "input_mask", + ): + self.model = model + self.student_head = _ProjectionHead( + student_rep_dim, + embedding_dim, + rngs=rngs, + hidden_dim=mlp_hidden_dim, + ) + self.teacher_head = _ProjectionHead( + teacher_rep_dim, + embedding_dim, + rngs=rngs, + hidden_dim=mlp_hidden_dim, + ) + self.eps = eps + self.mask_key = mask_key + + def __call__(self, *args, **kwargs) -> tuple[jax.Array, jax.Array]: + logits = self.model(*args, **kwargs) + + s_state = sowed_module.pop_sowed_intermediate_outputs(self.model) + if not s_state: + raise ValueError( + "No sowed intermediates found for student model. " + "Did you wrap the intended layers with sowed_module?" 
+ ) + + input_mask = kwargs.get(self.mask_key) + s_stacked = _sowed_state_to_pooled_stack( + s_state, input_mask=input_mask, eps=self.eps + ) + s_rep = _stacked_pooled_to_representation(s_stacked) # (B, C) + + z_s = _l2_normalize(self.student_head(s_rep), eps=self.eps) + return logits, z_s + + def embed_teacher_features( + self, + teacher_stacked_features: jax.Array, + *, + input_mask: jax.Array | None = None, + ) -> jax.Array: + del input_mask + t_rep = _stacked_pooled_to_representation(teacher_stacked_features) # (B, C) + z_t = _l2_normalize(self.teacher_head(t_rep), eps=self.eps) + return z_t + +def _infonce_loss( + z_student: jax.Array, + z_teacher: jax.Array, + *, + temperature: float, +) -> jax.Array: + """InfoNCE with in-batch negatives: positive pairs are i->i.""" + logits = (z_student @ z_teacher.T) / temperature # (B, B) + bsz = logits.shape[0] + labels = jnp.arange(bsz) + onehot = jax.nn.one_hot(labels, bsz) + loss = optax.softmax_cross_entropy(logits=logits, labels=onehot) + return jnp.mean(loss) + +def _setup_models_for_crd( + student_model: nnx.Module, + teacher_model: nnx.Module, + *, + student_layer_to_capture: type[nnx.Module], + teacher_layer_to_capture: type[nnx.Module], + dummy_student_input: dict[str, Any], + dummy_teacher_input: dict[str, Any], + rngs: nnx.Rngs, + embedding_dim: int, + mlp_hidden_dim: int | None, + mask_key: str, + eps: float = 1e-6, +) -> tuple[StudentModelWithCRDHeads, nnx.Module]: + """Wrap with sowed capture, infer rep dims via dummy runs, add CRD heads.""" + sowed_module.wrap_model_with_sowed_modules( + student_model, [student_layer_to_capture] + ) + sowed_module.wrap_model_with_sowed_modules( + teacher_model, [teacher_layer_to_capture] + ) + + student_model(**dummy_student_input) + teacher_model(**dummy_teacher_input) + + s_state = sowed_module.pop_sowed_intermediate_outputs(student_model) + t_state = sowed_module.pop_sowed_intermediate_outputs(teacher_model) + + if not s_state: + raise ValueError( + "No sowed intermediates found for student dummy run. " + "Check student_layer_to_capture." + ) + if not t_state: + raise ValueError( + "No sowed intermediates found for teacher dummy run. " + "Check teacher_layer_to_capture." + ) + + s_mask = dummy_student_input.get(mask_key) + t_mask = dummy_teacher_input.get(mask_key) + + s_stacked = _sowed_state_to_pooled_stack( + s_state, input_mask=s_mask, eps=eps + ) # (N, B, C) + t_stacked = _sowed_state_to_pooled_stack( + t_state, input_mask=t_mask, eps=eps + ) # (N, B, C) + + s_rep = _stacked_pooled_to_representation(s_stacked) # (B, C) + t_rep = _stacked_pooled_to_representation(t_stacked) # (B, C) + + if s_rep.ndim != 2 or t_rep.ndim != 2: + raise ValueError( + "CRD expects pooled representations to be 2D (B, C). 
" + f"Got student_rep={s_rep.shape}, teacher_rep={t_rep.shape}" + ) + + wrapped_student = StudentModelWithCRDHeads( + student_model, + student_rep_dim=int(s_rep.shape[-1]), + teacher_rep_dim=int(t_rep.shape[-1]), + embedding_dim=int(embedding_dim), + rngs=rngs, + mlp_hidden_dim=mlp_hidden_dim, + eps=eps, + mask_key=mask_key, + ) + return wrapped_student, teacher_model + + +def _remove_crd_from_models( + student_model: nnx.Module, + teacher_model: nnx.Module, +) -> tuple[nnx.Module, nnx.Module]: + """Unwrap sowed modules and return original models.""" + if isinstance(student_model, StudentModelWithCRDHeads): + base_student = student_model.model + else: + base_student = student_model + + sowed_module.unwrap_sowed_modules(base_student) + sowed_module.unwrap_sowed_modules(teacher_model) + return base_student, teacher_model + +class ContrastiveRepresentationDistillationStrategy(base_strategy.BaseStrategy): + """CRD: contrastive loss between student/teacher intermediates.""" + def __init__( + self, + student_forward_fn: ModelForwardCallable[Any], + teacher_forward_fn: ModelForwardCallable[Any], + labels_fn: Callable[..., jax.Array], + *, + student_layer_to_capture: type[nnx.Module], + teacher_layer_to_capture: type[nnx.Module], + dummy_student_input: dict[str, jax.Array], + dummy_teacher_input: dict[str, jax.Array], + rngs: nnx.Rngs, + embedding_dim: int = 128, + mlp_hidden_dim: int | None = 512, + temperature: float = 0.2, + alpha: float = 0.75, + symmetric: bool = False, + mask_key: str = "input_mask", + eps: float = 1e-6, + ): + super().__init__(student_forward_fn, teacher_forward_fn, labels_fn) + + if temperature <= 0: + raise ValueError(f"temperature must be > 0, got {temperature}") + if not 0.0 <= alpha <= 1.0: + raise ValueError(f"alpha must be in [0, 1], got {alpha}") + + self.student_layer_to_capture = student_layer_to_capture + self.teacher_layer_to_capture = teacher_layer_to_capture + self.dummy_student_input = dummy_student_input + self.dummy_teacher_input = dummy_teacher_input + self.rngs = rngs + self.embedding_dim = int(embedding_dim) + self.mlp_hidden_dim = mlp_hidden_dim + self.temperature = float(temperature) + self.alpha = float(alpha) + self.symmetric = bool(symmetric) + self.mask_key = mask_key + self.eps = float(eps) + + @override + def pre_process_models( + self, + student_model: nnx.Module, + teacher_model: nnx.Module, + ) -> tuple[nnx.Module, nnx.Module]: + return _setup_models_for_crd( + student_model, + teacher_model, + student_layer_to_capture=self.student_layer_to_capture, + teacher_layer_to_capture=self.teacher_layer_to_capture, + dummy_student_input=self.dummy_student_input, + dummy_teacher_input=self.dummy_teacher_input, + rngs=self.rngs, + embedding_dim=self.embedding_dim, + mlp_hidden_dim=self.mlp_hidden_dim, + mask_key=self.mask_key, + eps=self.eps, + ) + + @override + def post_process_models( + self, + student_model: nnx.Module, + teacher_model: nnx.Module, + ) -> tuple[nnx.Module, nnx.Module]: + return _remove_crd_from_models(student_model, teacher_model) + + @override + def get_teacher_outputs( + self, + teacher_model: nnx.Module, + inputs: dict[str, jax.Array], + ) -> jax.Array: + self._teacher_forward_fn(teacher_model, **inputs) + + t_state = sowed_module.pop_sowed_intermediate_outputs(teacher_model) + if not t_state: + raise ValueError( + "No sowed intermediates found for teacher forward pass. " + "Did you wrap the intended layers with sowed_module?" + ) + + # Produce pooled stacked features (N, B, C) and stop gradient. 
+ input_mask = inputs.get(self.mask_key) + t_stacked = _sowed_state_to_pooled_stack( + t_state, input_mask=input_mask, eps=self.eps + ) + return jax.lax.stop_gradient(t_stacked) + + @override + def get_train_loss( + self, + student_model: nnx.Module, + teacher_output: jax.Array, + inputs: dict[str, jax.Array], + ) -> jax.Array: + if not isinstance(student_model, StudentModelWithCRDHeads): + raise TypeError( + "CRD expects student_model to be StudentModelWithCRDHeads. " + "Did pre_process_models run?" + ) + + student_logits, z_s = self._student_forward_fn(student_model, **inputs) + + z_t = student_model.embed_teacher_features( + jax.lax.stop_gradient(teacher_output) + ) + + crd_loss = _infonce_loss(z_s, z_t, temperature=self.temperature) + if self.symmetric: + crd_loss = 0.5 * ( + crd_loss + _infonce_loss(z_t, z_s, temperature=self.temperature) + ) + + labels = self._labels_fn(**inputs) + task_loss = jnp.mean( + optax.softmax_cross_entropy(logits=student_logits, labels=labels) + ) + return (self.alpha * crd_loss) + ((1.0 - self.alpha) * task_loss) + + @override + def get_eval_loss( + self, + student_model: nnx.Module, + inputs: dict[str, jax.Array], + ) -> jax.Array: + out = self._student_forward_fn(student_model, **inputs) + student_logits = out[0] if isinstance(out, (tuple, list)) else out + labels = self._labels_fn(**inputs) + return jnp.mean( + optax.softmax_cross_entropy(logits=student_logits, labels=labels) + ) + + def compute_loss( + self, + student_output: Any, + teacher_output: Any, + labels: jax.Array, + ) -> jax.Array: + raise NotImplementedError("CRD uses get_train_loss override.") + + def compute_eval_loss( + self, + student_output: Any, + labels: jax.Array, + ) -> jax.Array: + raise NotImplementedError("CRD uses get_eval_loss override.") \ No newline at end of file