15 changes: 7 additions & 8 deletions .pre-commit-config.yaml
@@ -57,18 +57,17 @@ repos:
     hooks:
       - id: uv-lock
 
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.16.1
+  - repo: local
     hooks:
       - id: mypy
+        name: mypy
         additional_dependencies:
-          - uv==0.6.2
-          - mypy
-          - pytest
-          - rich
-          - types-requests
-          - pydantic
+          - uv==0.7.8
+        entry: uv run --group dev --group type_checking mypy
+        language: python
+        types: [python]
+        pass_filenames: false
+        require_serial: true
 
 # - repo: https://github.com/tcort/markdown-link-check
 #   rev: v3.11.2
39 changes: 36 additions & 3 deletions pyproject.toml
@@ -72,15 +72,38 @@ dev = [
"black",
"ruff",
"mypy",
"pre-commit",
"ruamel.yaml", # needed for openapi generator
]
# Type checking dependencies - includes type stubs and optional runtime dependencies
# needed for complete mypy coverage across all optional features
type_checking = [
"types-requests",
"types-setuptools",
"types-jsonschema",
"pandas-stubs",
"types-psutil",
"types-tqdm",
"boto3-stubs[s3]",
"pre-commit",
"ruamel.yaml", # needed for openapi generator
"streamlit",
"streamlit-option-menu",
"pandas",
"anthropic",
"databricks-sdk",
"fairscale",
"torchtune",
"trl",
"peft",
"datasets",
"together",
"nest-asyncio",
"pymongo",
"torchvision",
"sqlite-vec",
"faiss-cpu",
"lm-format-enforcer",
"mcp",
"ollama",
]
# These are the dependencies required for running unit tests.
unit = [
@@ -322,7 +345,17 @@ exclude = [

 [[tool.mypy.overrides]]
 # packages that lack typing annotations, do not have stubs, or are unavailable.
-module = ["yaml", "fire"]
+module = [
+    "yaml",
+    "fire",
+    "torchtune.*",
+    "fairscale.*",
+    "torchvision.*",
+    "datasets",
+    "nest_asyncio",
+    "streamlit_option_menu",
+    "lmformatenforcer.*",
+]
 ignore_missing_imports = true
 
 [tool.pydantic-mypy]
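For context, the expanded override list above is the config-level counterpart of the inline `# type: ignore[import-untyped]` comments added elsewhere in this PR: imports matching the listed module patterns are typed as `Any` instead of failing with a missing-stubs error. A rough sketch of the effect, assuming mypy reads the `[tool.mypy]` tables from pyproject.toml (the `greet` function is invented for illustration):

```python
import fire  # matched by the override above, so no import-untyped error;
             # mypy treats the whole module as Any


def greet(name: str) -> str:
    return f"Hello, {name}!"


if __name__ == "__main__":
    fire.Fire(greet)  # calls on an Any-typed module attribute go unchecked
```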
@@ -14,7 +14,6 @@
 from datasets import Dataset
 from peft import LoraConfig
 from transformers import (
-    AutoModelForCausalLM,
     AutoTokenizer,
 )
 from trl import SFTConfig, SFTTrainer
@@ -32,6 +31,7 @@

 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
+    HFAutoModel,
     calculate_training_steps,
     create_checkpoints,
     get_memory_stats,
@@ -338,7 +338,7 @@ def setup_training_args(

     def save_model(
         self,
-        model_obj: AutoModelForCausalLM,
+        model_obj: HFAutoModel,
         trainer: SFTTrainer,
         peft_config: LoraConfig | None,
         output_dir_path: Path,
@@ -350,14 +350,22 @@ def save_model(
             peft_config: Optional LoRA configuration
             output_dir_path: Path to save the model
         """
+        from typing import cast
+
         logger.info("Saving final model")
         model_obj.config.use_cache = True
 
         if peft_config:
             logger.info("Merging LoRA weights with base model")
-            model_obj = trainer.model.merge_and_unload()
+            # TRL's merge_and_unload returns a HuggingFace model
+            # Both cast() and type: ignore are needed here:
+            # - cast() tells mypy the return type is HFAutoModel for downstream code
+            # - type: ignore suppresses errors on the merge_and_unload() call itself,
+            #   which mypy can't type-check due to TRL library's incomplete type stubs
+            model_obj = cast(HFAutoModel, trainer.model.merge_and_unload())  # type: ignore[union-attr,operator]
         else:
-            model_obj = trainer.model
+            # trainer.model is the trained HuggingFace model
+            model_obj = cast(HFAutoModel, trainer.model)
 
         save_path = output_dir_path / "merged_model"
         logger.info(f"Saving model to {save_path}")
@@ -411,7 +419,7 @@ async def _run_training(
         # Initialize trainer
         logger.info("Initializing SFTTrainer")
         trainer = SFTTrainer(
-            model=model_obj,
+            model=model_obj,  # type: ignore[arg-type]
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             peft_config=peft_config,
@@ -309,7 +309,7 @@ def setup_training_args(
             save_total_limit=provider_config.save_total_limit,
             # DPO specific parameters
             beta=dpo_config.beta,
-            loss_type=provider_config.dpo_loss_type,
+            loss_type=provider_config.dpo_loss_type,  # type: ignore[arg-type]
         )
 
     def save_model(
@@ -381,13 +381,16 @@ async def _run_training(

         # Initialize DPO trainer
         logger.info("Initializing DPOTrainer")
+        # TRL library has incomplete type stubs - use Any to bypass
+        from typing import Any, cast
+
         trainer = DPOTrainer(
-            model=model_obj,
-            ref_model=ref_model,
+            model=cast(Any, model_obj),  # HFAutoModel satisfies PreTrainedModel protocol
+            ref_model=cast(Any, ref_model),
             args=training_args,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
-            processing_class=tokenizer,
+            processing_class=cast(Any, tokenizer),  # AutoTokenizer satisfies interface
         )
 
         try:
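The DPO changes use the opposite direction: casting an argument to `Any` to get past a mismatched parameter type at a single call site while the rest of the function stays checked. A small sketch under invented types (`takes_float` is illustrative, not a TRL API):

```python
from typing import Any, cast


def takes_float(x: float) -> float:
    return x * 2


value: int | str = 3  # mypy sees a union; at runtime this holds an int

# takes_float(value) would be rejected with [arg-type] because str is not
# compatible with float; cast(Any, ...) silences exactly this call site.
result = takes_float(cast(Any, value))
print(result)  # 6
```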
@@ -9,13 +9,31 @@
 import sys
 from datetime import UTC, datetime
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any, Protocol
 
 import psutil
 import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM
 
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+
+class HFAutoModel(Protocol):
+    """Protocol describing HuggingFace AutoModel interface.
+
+    This protocol defines the common interface for HuggingFace AutoModelForCausalLM
+    and similar models, providing type safety without requiring type stubs.
+    """
+
+    config: PretrainedConfig
+    device: torch.device
+
+    def to(self, device: torch.device) -> "HFAutoModel": ...
+    def save_pretrained(self, save_directory: str | Path) -> None: ...
+
+
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.post_training import Checkpoint, TrainingConfig
 from llama_stack.log import get_logger
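The `HFAutoModel` Protocol above relies on structural typing: any object whose attributes and methods match satisfies it, with no inheritance and no transformers stubs required. A minimal sketch of the mechanism (the `ToyModel` class is invented for illustration):

```python
from pathlib import Path
from typing import Protocol


class Saveable(Protocol):
    def save_pretrained(self, save_directory: str | Path) -> None: ...


class ToyModel:  # note: does not inherit from Saveable
    def save_pretrained(self, save_directory: str | Path) -> None:
        print(f"saving to {save_directory}")


def persist(model: Saveable) -> None:
    model.save_pretrained("out/")  # mypy accepts any structural match


persist(ToyModel())  # type-checks and runs: a structural, not nominal, match
```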
@@ -132,7 +150,7 @@ def load_model(
     model: str,
     device: torch.device,
     provider_config: HuggingFacePostTrainingConfig,
-) -> AutoModelForCausalLM:
+) -> HFAutoModel:
     """Load and initialize the model for training.
     Args:
         model: The model identifier to load
@@ -143,6 +161,8 @@
     Raises:
         RuntimeError: If model loading fails
     """
+    from typing import cast
+
     logger.info("Loading the base model")
     try:
         model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
@@ -154,9 +174,10 @@
             **provider_config.model_specific_config,
         )
         # Always move model to specified device
-        model_obj = model_obj.to(device)
+        model_obj = model_obj.to(device)  # type: ignore[arg-type]
         logger.info(f"Model loaded and moved to device: {model_obj.device}")
-        return model_obj
+        # Cast to HFAutoModel protocol - transformers models satisfy this interface
+        return cast(HFAutoModel, model_obj)
     except Exception as e:
         raise RuntimeError(f"Failed to load model: {str(e)}") from e

@@ -193,7 +193,7 @@ async def setup(self) -> None:
log.info("Optimizer is initialized.")

self._loss_fn = CEWithChunkedOutputLoss()
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) # type: ignore[operator]
log.info("Loss is initialized.")

assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
@@ -284,7 +284,7 @@ async def _setup_model(
         if self._is_dora:
             for m in model.modules():
                 if hasattr(m, "initialize_dora_magnitude"):
-                    m.initialize_dora_magnitude()
+                    m.initialize_dora_magnitude()  # type: ignore[operator]
         if lora_weights_state_dict:
             lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
@@ -353,7 +353,7 @@ async def fetch_rows(dataset_id: str):
             dataset_type=self._data_format.value,
         )
 
-        sampler = DistributedSampler(
+        sampler: DistributedSampler = DistributedSampler(
             ds,
             num_replicas=1,
             rank=0,
@@ -389,7 +389,7 @@ async def _setup_lr_scheduler(
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
-        return lr_scheduler
+        return lr_scheduler  # type: ignore[no-any-return]
 
     async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}
@@ -447,7 +447,7 @@ async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         # free logits otherwise it peaks backward memory
         del logits
 
-        return loss
+        return loss  # type: ignore[no-any-return]
 
     async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
         """
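The torchtune hunks use error-code-scoped ignores rather than bare `# type: ignore`, so only the named diagnostic is suppressed and anything else on the line still surfaces. A compact sketch of the `[no-any-return]` case, with `untyped_call` standing in for a call into an unstubbed library:

```python
from typing import Any


def untyped_call() -> Any:  # stand-in for an unstubbed library function
    return 0.5


def loss() -> float:
    # Under mypy's warn_return_any, returning an Any-typed value from a
    # function declared to return float reports [no-any-return]; the scoped
    # ignore suppresses that code only, unlike a blanket "# type: ignore".
    return untyped_call()  # type: ignore[no-any-return]


print(loss())
```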
@@ -10,7 +10,7 @@
 import json
 from typing import Any
 
-import faiss
+import faiss  # type: ignore[import-untyped]
 import numpy as np
 from numpy.typing import NDArray

@@ -11,7 +11,7 @@
 from typing import Any
 
 import numpy as np
-import sqlite_vec
+import sqlite_vec  # type: ignore[import-untyped]
 from numpy.typing import NDArray
 
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
@@ -32,8 +32,9 @@ def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"

async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names
return [
endpoint.name
endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
@@ -8,8 +8,8 @@
 from collections.abc import Iterable
 from typing import Any, cast
 
-from together import AsyncTogether
-from together.constants import BASE_URL
+from together import AsyncTogether  # type: ignore[import-untyped]
+from together.constants import BASE_URL  # type: ignore[import-untyped]
 
 from llama_stack.apis.inference import (
     OpenAIEmbeddingsRequestWithExtraBody,
6 changes: 5 additions & 1 deletion src/llama_stack/testing/api_recorder.py
@@ -599,7 +599,11 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     if endpoint == "/api/tags":
         from ollama import ListResponse
 
-        body = ListResponse(models=ordered)
+        # Both cast(Any, ...) and type: ignore are needed here:
+        # - cast(Any, ...) attempts to bypass type checking on the argument
+        # - type: ignore is still needed because mypy checks the call site independently
+        #   and reports arg-type mismatch even after casting
+        body = ListResponse(models=cast(Any, ordered))  # type: ignore[arg-type]
         return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}

