diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py index 998d11646..489afa787 100644 --- a/nanovllm/layers/rotary_embedding.py +++ b/nanovllm/layers/rotary_embedding.py @@ -48,14 +48,40 @@ def forward( return query, key +def _normalize_rope_scaling( + rope_scaling: dict | None, +) -> None: + if rope_scaling is None: + return None + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type")) + if rope_type == "default": + return None + if rope_type is None and set(rope_scaling).issubset({"rope_theta"}): + return None + raise NotImplementedError( + f"Unsupported rope_scaling={rope_scaling!r}. " + "nano-vllm only supports default RoPE without scaling." + ) + + @lru_cache(1) -def get_rope( +def _get_rope( head_size: int, rotary_dim: int, max_position: int, base: float, - rope_scaling: dict | None = None, + rope_scaling: None = None, ): - assert rope_scaling is None rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base) return rotary_emb + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: float, + rope_scaling: dict | None = None, +): + rope_scaling = _normalize_rope_scaling(rope_scaling) + return _get_rope(head_size, rotary_dim, max_position, base, rope_scaling) diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py index 5d39e0b90..71c3b6e5d 100755 --- a/nanovllm/models/qwen3.py +++ b/nanovllm/models/qwen3.py @@ -23,7 +23,7 @@ def __init__( rms_norm_eps: float = 1e-06, qkv_bias: bool = False, rope_theta: float = 10000, - rope_scaling: tuple | None = None, + rope_scaling: dict | None = None, ) -> None: super().__init__() tp_size = dist.get_world_size()