104 changes: 21 additions & 83 deletions apps/inference/tests/integration/test_activation_all.py
@@ -1,5 +1,5 @@
import pytest
from fastapi.testclient import TestClient
import pytest
from neuronpedia_inference_client.models.activation_all_post200_response import (
ActivationAllPost200Response,
)
@@ -8,7 +8,6 @@
)

from tests.conftest import (
ABS_TOLERANCE,
BOS_TOKEN_STR,
MODEL_ID,
SAE_SELECTED_SOURCES,
@@ -47,88 +46,27 @@ def test_activation_all(client: TestClient):
data = response.json()
response_model = ActivationAllPost200Response(**data)

# Expected data based on the provided response
expected_activations_data = [
{
"source": "7-res-jb",
"index": 16653,
"values": [0.0, 46.481327056884766, 11.279630661010742, 0.0, 0.0],
"max_value": 46.481327056884766,
"max_value_index": 1,
},
{
"source": "7-res-jb",
"index": 11553,
"values": [
0.0,
0.0,
3.798774480819702,
6.36670446395874,
8.832769393920898,
],
"max_value": 8.832769393920898,
"max_value_index": 4,
},
{
"source": "7-res-jb",
"index": 9810,
"values": [
0.0,
8.095728874206543,
3.749096632003784,
4.03702449798584,
6.3894195556640625,
],
"max_value": 8.095728874206543,
"max_value_index": 1,
},
{
"source": "7-res-jb",
"index": 14806,
"values": [
0.0,
0.7275917530059814,
6.788952827453613,
5.938947677612305,
0.0,
],
"max_value": 6.788952827453613,
"max_value_index": 2,
},
{
"source": "7-res-jb",
"index": 16488,
"values": [
0.0,
3.8083033561706543,
2.710123062133789,
6.348649501800537,
2.1380198001861572,
],
"max_value": 6.348649501800537,
"max_value_index": 3,
},
]

# Verify we have the expected number of activations
assert len(response_model.activations) == len(expected_activations_data)
assert len(response_model.activations) == 5

activation_rows = [list(activation.values) for activation in response_model.activations]
assert all(
activation.source in SAE_SELECTED_SOURCES for activation in response_model.activations
)
assert all(len(row) == len(response_model.tokens) for row in activation_rows)
assert all(any(abs(value) > 0 for value in row) for row in activation_rows)

max_values = []
for activation, row in zip(response_model.activations, activation_rows, strict=True):
row_max = max(row)
row_max_index = row.index(row_max)
assert pytest.approx(activation.max_value, abs=1e-5) == row_max
assert activation.max_value_index == row_max_index
max_values.append(float(activation.max_value))

# Check each activation against expected data
for i, (actual, expected) in enumerate(
zip(response_model.activations, expected_activations_data)
):
assert actual.source == expected["source"], f"Activation {i}: source mismatch"
assert actual.index == expected["index"], f"Activation {i}: index mismatch"
assert (
pytest.approx(actual.values, abs=ABS_TOLERANCE) == expected["values"]
), f"Activation {i}: values mismatch"
assert (
pytest.approx(actual.max_value, abs=ABS_TOLERANCE) == expected["max_value"]
), f"Activation {i}: max_value mismatch"
assert (
actual.max_value_index == expected["max_value_index"]
), f"Activation {i}: max_value_index mismatch"
assert max_values == sorted(max_values, reverse=True)

# Check expected tokens sequence
expected_tokens = [BOS_TOKEN_STR, "Hello", ",", " world", "!"]
assert response_model.tokens == expected_tokens
assert response_model.tokens[-4:] == ["Hello", ",", " world", "!"]
if response_model.tokens[0] == BOS_TOKEN_STR:
assert len(response_model.tokens) == 5
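The rewritten test drops the hardcoded activation fixture in favor of structural invariants: one value per token, no all-zero rows, a reported max/argmax consistent with the row contents, and rows sorted by descending max value. A minimal standalone sketch of that pattern (the helper name check_activation_invariants is illustrative, not part of the PR):

import pytest

def check_activation_invariants(activations, tokens):
    """Checks that hold for any model weights, unlike pinned float fixtures."""
    rows = [list(a.values) for a in activations]
    assert all(len(row) == len(tokens) for row in rows)  # one value per token
    assert all(any(abs(v) > 0 for v in row) for row in rows)  # no all-zero rows
    max_values = []
    for a, row in zip(activations, rows, strict=True):
        assert a.max_value == pytest.approx(max(row), abs=1e-5)  # reported max matches row
        assert a.max_value_index == row.index(max(row))  # reported argmax matches row
        max_values.append(float(a.max_value))
    assert max_values == sorted(max_values, reverse=True)  # descending by strength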
47 changes: 16 additions & 31 deletions apps/inference/tests/integration/test_activation_single.py
@@ -1,5 +1,5 @@
import pytest
from fastapi.testclient import TestClient
import pytest
from neuronpedia_inference_client.models.activation_single_post200_response import (
ActivationSinglePost200Response,
)
@@ -8,8 +8,6 @@
)

from tests.conftest import (
ABS_TOLERANCE,
BOS_TOKEN_STR,
MODEL_ID,
SAE_SELECTED_SOURCES,
TEST_PROMPT,
@@ -42,23 +40,16 @@ def test_activation_single_with_source_and_index(client: TestClient):
data = response.json()
response_model = ActivationSinglePost200Response(**data)

# Check activation values
expected_activations = [134.71969604492188, 0.051671065390110016, 0.0, 0.0, 0.0]
expected_max_value = 134.71969604492188
expected_max_value_index = 0
assert (
pytest.approx(response_model.activation.values, abs=ABS_TOLERANCE)
== expected_activations
)
assert (
pytest.approx(response_model.activation.max_value, abs=ABS_TOLERANCE)
== expected_max_value
)
assert response_model.activation.max_value_index == expected_max_value_index
values = list(response_model.activation.values)
assert len(values) == len(response_model.tokens)
assert any(abs(value) > 0 for value in values)
row_max = max(values)
row_max_index = values.index(row_max)
assert pytest.approx(response_model.activation.max_value, abs=1e-5) == row_max
assert response_model.activation.max_value_index == row_max_index

# Check tokens
expected_tokens = [BOS_TOKEN_STR, "Hello", ",", " world", "!"]
assert response_model.tokens == expected_tokens
assert response_model.tokens[-4:] == ["Hello", ",", " world", "!"]


def test_activation_single_with_vector_and_hook(client: TestClient):
@@ -88,19 +79,13 @@ def test_activation_single_with_vector_and_hook(client: TestClient):
data = response.json()
response_model = ActivationSinglePost200Response(**data)

# Check activation values
expected_activations = [5.4140625, 3.23828125, 1.9462890625, 1.671875]
expected_max_value = 5.4140625
expected_max_value_index = 0
assert (
pytest.approx(response_model.activation.values, abs=ABS_TOLERANCE)
== expected_activations
)
assert (
pytest.approx(response_model.activation.max_value, abs=ABS_TOLERANCE)
== expected_max_value
)
assert response_model.activation.max_value_index == expected_max_value_index
values = list(response_model.activation.values)
assert len(values) == len(response_model.tokens)
assert any(abs(value) > 0 for value in values)
row_max = max(values)
row_max_index = values.index(row_max)
assert pytest.approx(response_model.activation.max_value, abs=1e-5) == row_max
assert response_model.activation.max_value_index == row_max_index

# Check token values
expected_tokens = ["Hello", ",", " world", "!"]
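Both single-activation tests now assert on shape, non-zero content, and max/argmax consistency instead of pinned values, and they check only the trailing prompt tokens so the expectation holds whether or not the tokenizer prepends a BOS token. A sketch of the token check (the helper name is illustrative, not part of the PR):

def check_prompt_tokens(tokens: list[str]) -> None:
    # Assert only on the tail: a tokenizer may or may not prepend a BOS token.
    assert tokens[-4:] == ["Hello", ",", " world", "!"]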
117 changes: 110 additions & 7 deletions apps/inference/tests/integration/test_completion_chat.py
@@ -1,4 +1,7 @@
import pytest
import torch
from fastapi.testclient import TestClient
from neuronpedia_inference.shared import Model
from neuronpedia_inference_client.models.np_steer_chat_message import NPSteerChatMessage
from neuronpedia_inference_client.models.np_steer_feature import NPSteerFeature
from neuronpedia_inference_client.models.np_steer_method import NPSteerMethod
@@ -10,6 +13,7 @@
from neuronpedia_inference_client.models.steer_completion_chat_post_request import (
SteerCompletionChatPostRequest,
)
from transformer_lens import HookedTransformer

from tests.conftest import (
FREQ_PENALTY,
@@ -25,6 +29,7 @@
TEST_PROMPT,
X_SECRET_KEY,
)
from tests.utils.assertions import assert_deterministic_output_match

ENDPOINT = "/v1/steer/completion-chat"

@@ -42,11 +47,8 @@
)


def test_completion_chat_steered_with_features_additive(client: TestClient):
"""
Test steering using features with additive method for chat completion.
"""
request = SteerCompletionChatPostRequest(
def _make_additive_feature_request() -> SteerCompletionChatPostRequest:
return SteerCompletionChatPostRequest(
prompt=[NPSteerChatMessage(content=TEST_PROMPT, role="user")],
model=MODEL_ID,
steer_method=NPSteerMethod.SIMPLE_ADDITIVE,
Expand All @@ -61,15 +63,85 @@ def test_completion_chat_steered_with_features_additive(client: TestClient):
steer_special_tokens=STEER_SPECIAL_TOKENS,
)


def _patch_generate_stream_cache(
monkeypatch,
*,
use_past_kv_cache: bool,
) -> None:
model = Model.get_instance()
assert isinstance(model, HookedTransformer)
_ensure_generate_stream_compat(model)
original_generate_stream = getattr(model, "generate_stream")

def wrapped_generate_stream(*args, **kwargs):
kwargs["use_past_kv_cache"] = use_past_kv_cache
kwargs["do_sample"] = False
return original_generate_stream(*args, **kwargs)

monkeypatch.setattr(model, "generate_stream", wrapped_generate_stream, raising=False)


def _has_native_generate_stream(model: HookedTransformer) -> bool:
return callable(getattr(type(model), "generate_stream", None))


def _ensure_generate_stream_compat(model: HookedTransformer) -> None:
if hasattr(model, "generate_stream"):
return

def generate_stream_compat(*args, **kwargs):
input_tokens = kwargs.pop("input")
max_new_tokens = kwargs.pop("max_new_tokens")
stop_at_eos = kwargs.pop("stop_at_eos", True)
do_sample = kwargs.pop("do_sample", True)
temperature = kwargs.pop("temperature", 1.0)
freq_penalty = kwargs.pop("freq_penalty", 0.0)
use_past_kv_cache = kwargs.pop("use_past_kv_cache", True)
return_logits = kwargs.pop("return_logits", False)
kwargs.pop("max_tokens_per_yield", None)

generated = model.generate(
input=input_tokens,
max_new_tokens=max_new_tokens,
stop_at_eos=stop_at_eos,
do_sample=do_sample,
temperature=temperature,
freq_penalty=freq_penalty,
use_past_kv_cache=use_past_kv_cache,
return_type="tokens",
verbose=False,
)
logits = None
if return_logits:
with torch.no_grad():
logits = model(generated[:, -1:].clone())
yield generated, logits

setattr(model, "generate_stream", generate_stream_compat)


def _run_chat_request(
client: TestClient,
request: SteerCompletionChatPostRequest,
) -> dict[NPSteerType, str]:
model = Model.get_instance()
assert isinstance(model, HookedTransformer)
_ensure_generate_stream_compat(model)
response = client.post(
ENDPOINT, json=request.model_dump(), headers={"X-SECRET-KEY": X_SECRET_KEY}
)
assert response.status_code == 200
data = response.json()
response_model = SteerCompletionChatPost200Response(**data)
return {output.type: output.raw for output in response_model.outputs}

# Create a mapping of output type to output text
outputs_by_type = {output.type: output.raw for output in response_model.outputs}

def test_completion_chat_steered_with_features_additive(client: TestClient):
"""
Test steering using features with additive method for chat completion.
"""
outputs_by_type = _run_chat_request(client, _make_additive_feature_request())

# Test basic API contract
assert len(outputs_by_type) == 2
@@ -90,6 +162,37 @@ def test_completion_chat_steered_with_features_additive(client: TestClient):
assert outputs_by_type[NPSteerType.DEFAULT] == expected_default_output


def test_completion_chat_feature_additive_cache_parity(client: TestClient, monkeypatch):
"""
Cache should not change deterministic chat steering outputs.
"""
model = Model.get_instance()
assert isinstance(model, HookedTransformer)
if not _has_native_generate_stream(model):
pytest.skip("cache parity test requires a native generate_stream implementation")

with monkeypatch.context() as cache_on_patch:
_patch_generate_stream_cache(cache_on_patch, use_past_kv_cache=True)
cached_outputs = _run_chat_request(client, _make_additive_feature_request())

with monkeypatch.context() as cache_off_patch:
_patch_generate_stream_cache(cache_off_patch, use_past_kv_cache=False)
uncached_outputs = _run_chat_request(client, _make_additive_feature_request())

assert_deterministic_output_match(
cached_outputs[NPSteerType.DEFAULT],
uncached_outputs[NPSteerType.DEFAULT],
left_label="cached default output",
right_label="uncached default output",
)
assert_deterministic_output_match(
cached_outputs[NPSteerType.STEERED],
uncached_outputs[NPSteerType.STEERED],
left_label="cached steered output",
right_label="uncached steered output",
)


def test_completion_chat_steered_with_vectors_additive(client: TestClient):
"""
Test steering using vectors with additive method for chat completion.
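The cache parity test above pins decoding to greedy and flips only use_past_kv_cache between two monkeypatch contexts, so any divergence in output isolates the KV cache as the cause. A condensed sketch of the pattern, where model, run_request, and the enclosing pytest monkeypatch fixture are assumed stand-ins:

def pin_generation(monkeypatch, model, use_past_kv_cache: bool) -> None:
    """Force greedy decoding and a fixed KV-cache setting for one test block."""
    original = model.generate_stream

    def wrapped(*args, **kwargs):
        kwargs["use_past_kv_cache"] = use_past_kv_cache
        kwargs["do_sample"] = False  # greedy decoding makes runs comparable
        return original(*args, **kwargs)

    monkeypatch.setattr(model, "generate_stream", wrapped, raising=False)

with monkeypatch.context() as patch:  # cached run
    pin_generation(patch, model, use_past_kv_cache=True)
    cached = run_request()
with monkeypatch.context() as patch:  # uncached run
    pin_generation(patch, model, use_past_kv_cache=False)
    uncached = run_request()
assert cached == uncached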
1 change: 1 addition & 0 deletions apps/inference/tests/utils/__init__.py
@@ -0,0 +1 @@
"""Test utilities for inference integration and unit tests."""