Hi authors, thank you for the great paper and code!
I understand that the training-inference mismatch stems from BF16 rounding errors and kernel discrepancies. However, I am trying to understand why this error grows so dramatically over the course of training. As shown in Fig. 3 of the paper, the KL mismatch explodes from a negligible initial value to a much larger one as training progresses.
I checked the paper and didn't find a section that discusses this phenomenon. Could you clarify whether the increase is primarily driven by:
1. **Sequence length:** the model learns to generate longer chains of thought, giving per-token errors more steps over which to accumulate?
2. **Peakiness:** the policy becomes more confident (lower entropy) later in training, making the KL divergence hypersensitive to small rounding errors?
3. Or other reasons? (A toy sketch of hypotheses 1 and 2 follows below.)
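To make hypotheses 1 and 2 concrete, here is a toy numpy sketch (my own illustration, not code from the paper or this repo). The per-token KL value and the fixed absolute-error noise model `delta` are assumptions chosen purely for demonstration, not the paper's actual error model:

```python
import numpy as np

# Hypothesis 1: length. If every token contributes a roughly constant
# per-token KL mismatch, the sequence-level mismatch (a sum over tokens)
# grows linearly with generation length.
per_token_kl = 1e-5  # hypothetical per-token mismatch, for illustration only
for T in (512, 2048, 8192):
    print(f"length {T:>4}: sequence-level KL ~ {T * per_token_kl:.3f}")

# Hypothesis 2: peakiness. As a crude stand-in for rounding / kernel noise,
# apply a fixed absolute error `delta` to a token probability. The induced
# log-prob error is negligible for moderate probabilities but blows up for
# the vanishing tail probabilities of a very confident (low-entropy) policy.
delta = 1e-7
for p in (1e-2, 1e-4, 1e-6, 1e-8):
    err = abs(np.log((p + delta) / p))
    print(f"token prob {p:.0e}: |Δ log p| ~ {err:.3f}")
```

Under this toy model the length effect only scales the mismatch linearly, whereas the peakiness effect can grow by orders of magnitude once tail probabilities become tiny, so I'd be curious whether either matches what you observed in Fig. 3.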
Any insight into these dynamics would be very helpful. Thanks!