Fix RMSNorm fp32 input mutation #205

Open
JxKim wants to merge 1 commit into GeeeekExplorer:main from JxKim:fix-rmsnorm-fp32-input-from-2f21442

@JxKim JxKim commented Apr 14, 2026

Fixes #170.

RMSNorm.rms_forward applied an in-place mul_ to the result of x.float(). When the input is already fp32, x.float() returns the original tensor rather than a copy, so the in-place normalization overwrites the caller's hidden_states and corrupts the residual path that still references it.
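
For illustration, a minimal sketch of the aliasing described above. This is not the repository's code: `rms_scale_inplace` is a hypothetical stand-in that only reproduces the `x.float()` followed by in-place `mul_` pattern, and the input values are arbitrary.

```python
import torch


def rms_scale_inplace(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical stand-in for the buggy pattern, not the actual rms_forward.
    h = x.float()  # no copy is made when x is already float32
    var = h.pow(2).mean(dim=-1, keepdim=True)
    h.mul_(torch.rsqrt(var + eps))  # in place: writes into x's storage for fp32 input
    return h


# fp32 input: x.float() shares storage with x, so the caller's tensor changes.
x32 = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float32)
before32 = x32.clone()
shares_storage = x32.float().data_ptr() == x32.data_ptr()
rms_scale_inplace(x32)
fp32_mutated = not torch.equal(x32, before32)

# fp16 input: x.float() allocates a new fp32 tensor, so the caller's tensor is untouched.
x16 = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float16)
before16 = x16.clone()
rms_scale_inplace(x16)
fp16_mutated = not torch.equal(x16, before16)

print(shares_storage, fp32_mutated, fp16_mutated)
```

In this sketch only the fp32 input is modified, which is consistent with the bug being invisible when activations are in half precision.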

This changes the scaling step to allocate a normalized tensor instead of mutating the input, and adds a regression test for fp32 inputs.
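
A rough sketch of the out-of-place variant and the kind of regression check described above. The names `rms_scale` and `check_fp32_input_not_mutated` are hypothetical and the shapes and tolerance are arbitrary; the actual change and test live in the PR's diff.

```python
import torch


def rms_scale(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical out-of-place version: the input is never written to.
    orig_dtype = x.dtype
    h = x.float()  # may alias x when x is already float32
    var = h.pow(2).mean(dim=-1, keepdim=True)
    normed = h * torch.rsqrt(var + eps)  # allocates a new tensor instead of mul_
    return normed.to(orig_dtype) * weight


def check_fp32_input_not_mutated() -> bool:
    # Regression-style check: an fp32 input must be left unchanged and the
    # output must match a reference computed from a snapshot of the input.
    torch.manual_seed(0)
    x = torch.randn(2, 8, dtype=torch.float32)
    snapshot = x.clone()
    out = rms_scale(x, torch.ones(8))
    unchanged = torch.equal(x, snapshot)
    ref = snapshot / snapshot.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
    return unchanged and torch.allclose(out, ref, atol=1e-6)


print(check_fp32_input_not_mutated())
```

The extra allocation costs one temporary the size of the activations, in exchange for leaving the tensor that the residual connection still references intact.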

Testing:

  • Attempted: uv run python -m pytest test/test_layernorm.py
  • Not run to completion. Blocked locally by, in order: an incompatible triton wheel on Windows, flash-attn build isolation under WSL, and a timeout while downloading the CPU torch dependency.
