Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions miles/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,20 @@ def reset_for_retry(self) -> None:

@property
def oldest_weight_version(self) -> int | None:
"""Minimum weight version across all turns (generation calls) for this trajectory."""
"""Minimum weight version across turns.

Non-numeric versions are ignored.
"""
if not self.weight_versions:
return None
return min(int(v) for v in self.weight_versions)

versions = []
for version in self.weight_versions:
try:
versions.append(int(version))
except (TypeError, ValueError):
continue
return min(versions) if versions else None

def update_from_meta_info(self, args, meta_info: dict):
"""
Expand Down
10 changes: 10 additions & 0 deletions tests/fast/utils/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,13 @@ def test_strip_negative_is_noop(self, tokenizer):
original_tokens = list(s.tokens)
s.strip_last_output_tokens(-1, tokenizer)
assert s.tokens == original_tokens


class TestOldestWeightVersion:
def test_ignores_non_numeric_versions(self):
s = Sample(weight_versions=["default", "3", "x", "10"])
assert s.oldest_weight_version == 3

def test_all_non_numeric_versions_return_none(self):
s = Sample(weight_versions=["default", "latest"])
assert s.oldest_weight_version is None
Loading