feat: Add some preliminary unit testing support #328

Merged · 10 commits · Aug 15, 2025
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -90,6 +90,9 @@ dev-dependencies = [
"hatch>=1.14.1",
"ruff>=0.12.1",
"pyright>=1.1.403",
"pytest>=8.4.1",
"nbval>=0.11.0",
"pytest-xdist>=3.8.0",
]

[tool.uv.sources]
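The three new dev dependencies wire up the test workflow: pytest is the runner, nbval executes and validates notebooks as tests, and pytest-xdist enables parallel runs. A minimal usage sketch (the -n auto flag is pytest-xdist's worker option and an assumption here; run_checks.sh below does not use it):

    # Run the unit tests the same way scripts/run_checks.sh does
    uv run pytest --nbval tests/unit

    # Hypothetical: fan the run out across CPU cores via pytest-xdist
    uv run pytest --nbval -n auto tests/unit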
71 changes: 69 additions & 2 deletions scripts/run_checks.sh
@@ -27,6 +27,7 @@ fi
# Track if any checks fail
CHECKS_PASSED=true
TYPECHECK_FAILED=false
TESTS_FAILED=false

# Run format check
echo "📝 Checking code formatting..."
@@ -75,7 +76,7 @@ echo
# Run type checking (Pyright)
echo "🧠 Running type checking..."
TMP_PYRIGHT_JSON=$(mktemp)
echo " Running: uv run pyright --outputjson src"
echo " Running: uv run pyright --outputjson src tests"
# Capture JSON output quietly regardless of success/failure
if uv run pyright --outputjson src > "$TMP_PYRIGHT_JSON" 2>/dev/null; then
: # success, continue
@@ -125,6 +126,66 @@ fi
rm -f "$TMP_PYRIGHT_JSON"
echo

# Run tests
echo "🧪 Running unit tests..."
echo " Running: uv run pytest --nbval tests/unit"

# Capture pytest output quietly to parse the summary
PYTEST_OUTPUT=$(mktemp)
if uv run pytest --nbval --tb=short tests/unit > "$PYTEST_OUTPUT" 2>&1; then
TEST_EXIT_CODE=0
else
TEST_EXIT_CODE=$?
fi

# Extract the test summary line (e.g., "===== 5 passed, 2 failed, 1 skipped in 3.45s =====")
# This regex captures various pytest summary formats
TEST_SUMMARY=$(grep -E "^=+ .*(passed|failed|error|skipped|xfailed|xpassed|warning).*=+$" "$PYTEST_OUTPUT" | tail -1)

if [[ -n "$TEST_SUMMARY" ]]; then
# Parse the summary to extract counts
PASSED=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ passed" | grep -oE "[0-9]+" || echo "0")
FAILED=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ failed" | grep -oE "[0-9]+" || echo "0")
ERRORS=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ error" | grep -oE "[0-9]+" || echo "0")
SKIPPED=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ skipped" | grep -oE "[0-9]+" || echo "0")
WARNINGS=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ warning" | grep -oE "[0-9]+" || echo "0")
XFAILED=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ xfailed" | grep -oE "[0-9]+" || echo "0")
XPASSED=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ xpassed" | grep -oE "[0-9]+" || echo "0")

# Build detailed summary
DETAILS=""
[[ "$PASSED" != "0" ]] && DETAILS="${DETAILS}Passed: $PASSED"
[[ "$FAILED" != "0" ]] && DETAILS="${DETAILS:+$DETAILS, }Failed: $FAILED"
[[ "$ERRORS" != "0" ]] && DETAILS="${DETAILS:+$DETAILS, }Errors: $ERRORS"
[[ "$SKIPPED" != "0" ]] && DETAILS="${DETAILS:+$DETAILS, }Skipped: $SKIPPED"
[[ "$XFAILED" != "0" ]] && DETAILS="${DETAILS:+$DETAILS, }XFailed: $XFAILED"
[[ "$XPASSED" != "0" ]] && DETAILS="${DETAILS:+$DETAILS, }XPassed: $XPASSED"
[[ "$WARNINGS" != "0" ]] && DETAILS="${DETAILS:+$DETAILS, }Warnings: $WARNINGS"

# Check if there were any failures or errors
if [[ "$FAILED" == "0" && "$ERRORS" == "0" && $TEST_EXIT_CODE -eq 0 ]]; then
echo -e "${GREEN}✅ All tests passed${NC}"
[[ -n "$DETAILS" ]] && echo " $DETAILS"
else
echo -e "${RED}❌ Tests failed${NC}"
[[ -n "$DETAILS" ]] && echo " $DETAILS"
CHECKS_PASSED=false
TESTS_FAILED=true
fi
else
# Fallback if we can't parse the summary
if [[ $TEST_EXIT_CODE -eq 0 ]]; then
echo -e "${GREEN}✅ All unit tests passed${NC}"
else
echo -e "${RED}❌ Some unit tests failed${NC}"
CHECKS_PASSED=false
TESTS_FAILED=true
fi
fi

rm -f "$PYTEST_OUTPUT"
echo

# Check if uv.lock is in sync with pyproject.toml
echo "🔒 Checking if uv.lock is up to date..."
PRIMARY_EXTRAS=(--all-extras)
@@ -201,9 +262,15 @@ if $CHECKS_PASSED; then
else
echo -e "${RED}❌ Some checks failed${NC}"
if [[ -z "$FIX_FLAG" ]]; then
# Show tips for each type of failure
if $TYPECHECK_FAILED; then
echo -e "💡 Tip: Type errors can't be auto-fixed by --fix. Re-run ${YELLOW}uv run pyright src${NC} to see full diagnostics."
else
fi
if $TESTS_FAILED; then
echo -e "💡 Tip: Test failures can't be auto-fixed by --fix. Re-run ${YELLOW}uv run pytest --nbval tests/unit${NC} to see full test output."
fi
# Show general fix tip if there are failures but not type/test specific ones
if ! $TYPECHECK_FAILED && ! $TESTS_FAILED; then
echo -e "💡 Tip: Run ${YELLOW}./scripts/run_checks.sh --fix${NC} to automatically fix some issues"
fi
fi
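The block above parses pytest's final summary line rather than relying on the exit code alone, so the script can report per-category counts. A standalone sketch of the extraction, with a made-up sample line:

    # Sample pytest summary line (fabricated for illustration)
    TEST_SUMMARY="===== 5 passed, 2 failed, 1 skipped in 3.45s ====="
    # First grep isolates e.g. "5 passed"; second keeps only the number
    PASSED=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ passed" | grep -oE "[0-9]+" || echo "0")
    FAILED=$(echo "$TEST_SUMMARY" | grep -oE "[0-9]+ failed" | grep -oE "[0-9]+" || echo "0")
    echo "Passed: $PASSED, Failed: $FAILED"  # prints: Passed: 5, Failed: 2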
14 changes: 7 additions & 7 deletions src/art/local/backend.py
@@ -42,20 +42,20 @@
from .. import dev
from ..backend import Backend
from ..model import Model, TrainableModel
from ..preprocessing.pack import (
PackedTensors,
packed_tensors_from_tokenized_results,
packed_tensors_to_dir,
plot_packed_tensors,
)
from ..preprocessing.tokenize import tokenize_trajectory_groups
from ..trajectories import Trajectory, TrajectoryGroup
from ..types import Message, TrainConfig
from ..utils import format_message, get_model_step
from .checkpoints import (
delete_checkpoints,
)
from .pack import (
PackedTensors,
packed_tensors_from_tokenized_results,
packed_tensors_to_dir,
plot_packed_tensors,
)
from .service import ModelService
from .tokenize import tokenize_trajectory_groups


class LocalBackend(Backend):
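This import reshuffle, repeated in the files below, is the mechanical fallout of moving pack.py and tokenize.py from art.local into the new art.preprocessing package. Downstream code importing these helpers directly would update the same way (hypothetical user-side import):

    # Before this PR:
    # from art.local.pack import PackedTensors
    # After this PR:
    from art.preprocessing.pack import PackedTensors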
2 changes: 1 addition & 1 deletion src/art/local/service.py
@@ -1,7 +1,7 @@
from typing import AsyncIterator, Protocol, runtime_checkable

from .. import dev, types
from .pack import DiskPackedTensors
from ..preprocessing.pack import DiskPackedTensors


@runtime_checkable
Empty file.
File renamed without changes.
79 changes: 43 additions & 36 deletions src/art/local/tokenize.py → src/art/preprocessing/tokenize.py
@@ -146,66 +146,73 @@ def tokenize_trajectory(
tools=history.tools, # type: ignore
),
)
sentinal_token_id = max(
sentinal_start_token_id = max(
set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
)
sentinal_token = tokenizer.decode(sentinal_token_id)
result = cast(
dict,
sentinal_end_token_id = max(
set(range(cast(int, tokenizer.vocab_size)))
- set(original_token_ids)
- {sentinal_start_token_id}
)
sentinal_start_token = tokenizer.decode(sentinal_start_token_id)
sentinal_end_token = tokenizer.decode(sentinal_end_token_id)
token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(
list[dict],
[
(
message_or_choice
if isinstance(message_or_choice, dict)
and not message_or_choice["role"] == "assistant"
else {
"role": "assistant",
"content": sentinal_token,
"content": f"{sentinal_start_token}{message_or_choice.get('content', None) if isinstance(message_or_choice, dict) else message_or_choice.message.content or ''}{sentinal_end_token}",
}
)
for message_or_choice in messages_and_choices
],
),
tools=history.tools, # type: ignore
return_dict=True,
return_assistant_token_mask=allow_training_without_logprobs,
),
)
token_ids: list[int] = result["input_ids"]
assistant_mask: list[int] = (
result["attention_mask"]
if allow_training_without_logprobs
else [0] * len(token_ids)
)
assistant_mask: list[int] = [0] * len(token_ids)
logprobs = [float("nan")] * len(token_ids)
for message_or_choice in messages_and_choices:
if isinstance(message_or_choice, dict):
continue
choice = message_or_choice
assert choice.logprobs or allow_training_without_logprobs, (
"Chat completion choices must have logprobs"
)
if not choice.logprobs:
continue
token_logprobs = choice.logprobs.content or choice.logprobs.refusal or []
sentinal_index = token_ids.index(sentinal_token_id)
if (
bytes(token_logprobs[0].bytes or []).decode("utf-8")
== "<think>"
== tokenizer.decode(token_ids[sentinal_index - 4])
isinstance(message_or_choice, dict)
and not message_or_choice["role"] == "assistant"
):
start = sentinal_index - 4
continue
start = token_ids.index(sentinal_start_token_id)
end = token_ids.index(sentinal_end_token_id) + 1
if isinstance(message_or_choice, dict):
token_ids[start:end] = token_ids[start + 1 : end - 1]
logprobs[start:end] = [float("nan")] * (end - start - 2)
assistant_mask[start:end] = [1] * (end - start - 2)
else:
start = sentinal_index
end = sentinal_index + 1
token_ids[start:end] = (
int(token_logprob.token.split(":")[1]) for token_logprob in token_logprobs
)
logprobs[start:end] = (
token_logprob.logprob for token_logprob in token_logprobs
)
assistant_mask[start:end] = [1] * len(token_logprobs)
choice = message_or_choice
assert choice.logprobs or allow_training_without_logprobs, (
"Chat completion choices must have logprobs"
)
if not choice.logprobs:
continue
token_logprobs = choice.logprobs.content or choice.logprobs.refusal or []
if (
bytes(token_logprobs[0].bytes or []).decode("utf-8")
== "<think>"
== tokenizer.decode(token_ids[start - 4])
):
start -= 4
token_ids[start:end] = (
int(token_logprob.token.split(":")[1])
for token_logprob in token_logprobs
)
logprobs[start:end] = (
token_logprob.logprob for token_logprob in token_logprobs
)
assistant_mask[start:end] = [1] * len(token_logprobs)
return TokenizedResult(
advantage=advantage,
chat=chat,
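The substantive change in this file swaps the single sentinel token for a start/end sentinel pair, so each assistant message's span can be located unambiguously in the templated sequence before its real tokens, logprobs, and assistant-mask entries are spliced in. A simplified sketch of the two-sentinel idea, with hypothetical helper names (the real logic lives in src/art/preprocessing/tokenize.py):

    def pick_sentinels(vocab_size: int, original_token_ids: list[int]) -> tuple[int, int]:
        # Choose two token ids guaranteed absent from the original tokenization.
        unused = set(range(vocab_size)) - set(original_token_ids)
        start_id = max(unused)
        end_id = max(unused - {start_id})
        return start_id, end_id

    def splice_span(token_ids: list[int], start_id: int, end_id: int, replacement: list[int]) -> None:
        # Replace the sentinel-delimited span (sentinels included) with the
        # assistant message's actual token ids.
        start = token_ids.index(start_id)
        end = token_ids.index(end_id) + 1
        token_ids[start:end] = replacement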
2 changes: 1 addition & 1 deletion src/art/torchtune/batch.py
@@ -1,7 +1,7 @@
from pydantic import BaseModel

from .. import dev, types
from ..local.pack import DiskPackedTensors
from ..preprocessing.pack import DiskPackedTensors


class Batch(BaseModel):
2 changes: 1 addition & 1 deletion src/art/torchtune/recipe.py
@@ -47,7 +47,7 @@
from tqdm import tqdm

from .. import dev, types
from ..local.pack import PackedTensors, packed_tensors_from_dir
from ..preprocessing.pack import PackedTensors, packed_tensors_from_dir
from .batch import Batch
from .config import (
CompileConfig,
2 changes: 1 addition & 1 deletion src/art/torchtune/service.py
@@ -16,7 +16,7 @@
from vllm.v1.engine.async_llm import AsyncLLM

from .. import dev, types
from ..local.pack import DiskPackedTensors
from ..preprocessing.pack import DiskPackedTensors
from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
from .batch import Batch

6 changes: 5 additions & 1 deletion src/art/unsloth/decoupled_service.py
@@ -22,7 +22,11 @@

from .. import dev, types
from ..local.checkpoints import get_last_checkpoint_dir
from ..local.pack import DiskPackedTensors, PackedTensors, packed_tensors_from_dir
from ..preprocessing.pack import (
DiskPackedTensors,
PackedTensors,
packed_tensors_from_dir,
)
from ..utils.get_model_step import get_step_from_dir
from ..utils.output_dirs import get_step_checkpoint_dir
from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
6 changes: 5 additions & 1 deletion src/art/unsloth/service.py
@@ -10,7 +10,11 @@

from .. import dev, types
from ..local.checkpoints import get_last_checkpoint_dir
from ..local.pack import DiskPackedTensors, PackedTensors, packed_tensors_from_dir
from ..preprocessing.pack import (
DiskPackedTensors,
PackedTensors,
packed_tensors_from_dir,
)
from .train import train

if TYPE_CHECKING: