This is a clean, unofficial implementation of Inductive Moment Matching (IMM) [1] using PyTorch and Lightning ⚡. It contains most (if not all) of the features described in the paper.
I wrote it when the paper came out in March and thought I should share it. The implementation is efficient and works for data of various dimensions. train.py includes an example that trains a small (5.2M-parameter) model on MNIST.
I haven't really updated the code since then. The first version of the paper contained one or two minor typos, so some comments in the code might be out of date. Similarly, the architecture I used is the so-called SongUnet, which I quickly modified to take the second time parameter; you should probably use the architectures from the official IMM repository instead.
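The IMM network is conditioned on two times, the current time t and the target time s, whereas a standard diffusion U-Net such as SongUnet takes a single noise level. A minimal sketch of one common way to feed a second time into the conditioning path is shown below: embed each time with sinusoidal features and concatenate the two embeddings. This is an assumption for illustration, written in NumPy for brevity; `sinusoidal_embedding` and `two_time_embedding` are hypothetical helpers, not the modified SongUnet code in this repository.

```python
import numpy as np


def sinusoidal_embedding(t, dim, max_period=10000.0):
    """Map a batch of scalar times t (shape [B]) to [B, dim] cos/sin features.

    Assumes dim is even; the first half holds cosines, the second half sines.
    """
    t = np.asarray(t, dtype=np.float64)
    half = dim // 2
    # Geometrically spaced frequencies from 1 down toward 1 / max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=1)


def two_time_embedding(t, s, dim):
    """Hypothetical helper: embed the current time t and target time s separately.

    Returns [B, 2 * dim] features; a network would typically pass these through
    an MLP before injecting them into each residual block.
    """
    return np.concatenate(
        [sinusoidal_embedding(t, dim), sinusoidal_embedding(s, dim)], axis=1
    )


# Usage: a batch of two (t, s) pairs with 64-dimensional embeddings per time.
emb = two_time_embedding(np.array([0.9, 0.5]), np.array([0.2, 0.0]), 64)
```

Concatenating keeps the t and s features in separate channels, so the network can tell the two times apart; summing two separately learned embeddings would be an equally plausible alternative.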
Enjoy! :)
You can install the requirements via pip with

pip install -r requirements.txt

Other versions of the dependencies should also work, but these are the ones the code was tested with.
To run the MNIST example, simply run

python train.py

The run is logged to Weights & Biases (wandb).
- [1]: Linqi Zhou, Stefano Ermon, Jiaming Song, Inductive Moment Matching (2025). arXiv: arxiv.org/abs/2503.07565.