Initial IMM training loss #4
Conversation
Hi, thanks for your effort! I am curious why you use EMA weights for $y_r$. Have you run experiments and found it works better?
@XinYu-Andy thank you! Using EMA weights for $y_r$ … I updated the PR with this change for now!
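For reference, here is a minimal sketch (not this PR's code) of the point under discussion: evaluating the target branch $y_r$ with EMA weights rather than the online weights. The network and names (`net`, `ema_net`, `update_ema`) are hypothetical stand-ins.

```python
import copy
import torch
import torch.nn as nn

net = nn.Linear(4, 4)         # stand-in for the DDPM++ UNet f_theta
ema_net = copy.deepcopy(net)  # frozen EMA copy used for the target branch
for p in ema_net.parameters():
    p.requires_grad_(False)

@torch.no_grad()
def update_ema(decay: float = 0.9999) -> None:
    # Standard exponential moving average of the online parameters.
    for p_ema, p in zip(ema_net.parameters(), net.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)

x_t, x_r = torch.randn(2, 4), torch.randn(2, 4)
y_t = net(x_t)                # gradient flows through the online branch
with torch.no_grad():
    y_r = ema_net(x_r)        # target branch y_r uses EMA weights
loss = (y_t - y_r).pow(2).mean()
loss.backward()
update_ema()                  # refresh the EMA copy after the optimizer step
```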
Are you doing experiments on CIFAR-10? I have been running the experiment for a few weeks but still have not been able to reproduce the results reported in the paper...
Yes, with the DDPM++ UNet, using all the same reported hyperparams. I just started today and haven't experimented extensively yet, but I would like to see a stable loss before going any further.
Sounds good! 👍
@jiamings @alexzhou907 @karanganesan @manskx can we get an update on the training code for IMM? I've had 5-10 different people/labs message me personally asking about it and stating that this work is not reproducible. This PR is an effort toward reproducing the work, but I, along with the rest of the research community, must be missing something.
Hi @stockeh. The plan is to release the code around ICML time, or slightly sooner. We'd appreciate your patience. For reproduction, there are many more implementation details in the appendix of the latest version that you can check out. Some important ones: keep precision in TF32 or FP16; we discourage BF16 due to the closeness between r and t. Another typo I made initially in the paper: the kernel weighting should be $1/|c_{\text{out}}|$ instead of $1/c_{\text{out}}$.
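A short sketch of these two details, under assumed PyTorch semantics; the function name `kernel_weight` is illustrative, not from the released code.

```python
import torch

# Keep matmuls in TF32 (or use FP16 autocast) rather than BF16: r and t can
# be very close, and BF16's coarse mantissa can lose the difference r - t.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def kernel_weight(c_out: torch.Tensor) -> torch.Tensor:
    # Corrected kernel weighting: 1 / |c_out|, not 1 / c_out
    # (c_out may be negative, so the absolute value matters).
    return 1.0 / c_out.abs()
```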

Initial attempt at the training loss defined in the paper (assumes a constant decrement in $\eta(t)$ for the mapping function).
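A hedged sketch of that mapping-function assumption: step $\eta(t)$ down by a constant decrement and invert, clamping so $r$ never falls below $s$. The schedule $\eta(t) = t/(1-t)$ and the clamp are assumptions for illustration, not this PR's exact code.

```python
import torch

def eta(t: torch.Tensor) -> torch.Tensor:
    # Assumed noise-to-signal schedule eta(t) = t / (1 - t).
    return t / (1.0 - t)

def eta_inv(e: torch.Tensor) -> torch.Tensor:
    # Inverse of the schedule above: eta_inv(eta(t)) == t.
    return e / (1.0 + e)

def mapping_fn(t: torch.Tensor, s: torch.Tensor, delta: float = 0.1) -> torch.Tensor:
    # r(s, t): constant decrement delta in eta-space, clamped so r >= s.
    return torch.maximum(eta_inv(eta(t) - delta), s)

# Example: map t = 0.8 with s = 0.1 to an intermediate time r in (s, t).
r = mapping_fn(torch.tensor([0.8]), torch.tensor([0.1]))
```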