diff --git a/nanovllm/layers/layernorm.py b/nanovllm/layers/layernorm.py
index 71bf4198f..a969fd903 100755
--- a/nanovllm/layers/layernorm.py
+++ b/nanovllm/layers/layernorm.py
@@ -21,7 +21,7 @@ def rms_forward(
         orig_dtype = x.dtype
         x = x.float()
         var = x.pow(2).mean(dim=-1, keepdim=True)
-        x.mul_(torch.rsqrt(var + self.eps))
+        x = x * torch.rsqrt(var + self.eps)
         x = x.to(orig_dtype).mul_(self.weight)
         return x
 
diff --git a/test/test_layernorm.py b/test/test_layernorm.py
new file mode 100644
index 000000000..c4eeba986
--- /dev/null
+++ b/test/test_layernorm.py
@@ -0,0 +1,27 @@
+import importlib.util
+from pathlib import Path
+
+import torch
+
+
+# Load layernorm.py straight from its file path instead of importing it
+# through the package.
+spec = importlib.util.spec_from_file_location(
+    "layernorm",
+    Path(__file__).parents[1] / "nanovllm" / "layers" / "layernorm.py",
+)
+layernorm = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(layernorm)
+RMSNorm = layernorm.RMSNorm
+
+
+def test_rms_norm_does_not_modify_fp32_input():
+    """Check that calling RMSNorm leaves a float32 input tensor untouched."""
+    norm = RMSNorm(4)
+    x = torch.randn(2, 4, dtype=torch.float32)
+    expected = x.clone()
+
+    norm(x)
+
+    # "Not modified" means bit-identical, so disable the default tolerances.
+    torch.testing.assert_close(x, expected, rtol=0, atol=0)