[MASK] is All You Need

This repository is the official implementation of the paper "[MASK] is All You Need".

Vincent Tao Hu, Björn Ommer

TLDR

We present Discrete Interpolants, which bridge Diffusion Models and Masked Generative Models in discrete state spaces, and we scale them up in the vision domain.

🎓 Citation

Please cite our paper:

@InProceedings{hu2024mask,
      title={[MASK] is All You Need},
      author={Vincent Tao Hu and Björn Ommer},
      booktitle = {Arxiv},
      year={2024}
}

✅ Updates

  • Feb. 4th, 2025: Training code released.
  • Dec. 10th, 2024: arXiv paper released.

📦 Training

COCO training (DeepSpeed)

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --num_processes 4 --num_machines 1 --main_process_ip 127.0.0.1 --main_process_port 8868 train_ds_vq.py model=uvit_s2deep_it data=coco14_cond_indices dynamic=linear dynamic.mask_ce=1 input_tensor_type=bwh tokenizer=sd_vq_f8 optim.wd=0.00 "optim.betas=[0.9, 0.9]" data.train_steps=1_000_000 ckpt_every=20_000 data.sample_fid_every=100_000 data.sample_fid_n=20_000 data.batch_size=64 optim.name=adam optim.lr=2e-4 lrschedule.warmup_steps=5000 dstep_num=500 mixed_precision=bf16 accum=4

ImageNet training (Accelerate, batch size 256)

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --num_processes 4 --num_machines 1 --main_process_ip 127.0.0.1 --main_process_port 8868 train_acc_vq.py model=uvit_h2_it dynamic=linear input_tensor_type=bwh tokenizer=sd_vq_f8 data=imagenet256_cond_indices data.batch_size=64 data.sample_vis_n=16 data.sample_fid_every=50_000 ckpt_every=20_000 data.train_steps=1500_000 data.sample_fid_n=5_000 optim.name=adamw optim.lr=1e-4 optim.wd=0.0 lrschedule.warmup_steps=1 mixed_precision=bf16 accum=1
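
When adapting the two training commands above to a different number of GPUs, it can help to check the effective batch size first. The sketch below assumes that `data.batch_size` is the per-process batch size and that `accum` is the number of gradient-accumulation steps; the ImageNet heading's "256" is consistent with that reading, but the repository's config is the authority. The helper `effective_batch_size` is hypothetical and not part of this codebase.

```python
def effective_batch_size(num_processes: int, per_process_batch: int, accum: int) -> int:
    """Samples contributing to one optimizer step.

    Assumption: data.batch_size is per process and accum counts
    gradient-accumulation steps (not confirmed by the README).
    """
    return num_processes * per_process_batch * accum


# Values copied from the two training commands.
coco = effective_batch_size(num_processes=4, per_process_batch=64, accum=4)
imagenet = effective_batch_size(num_processes=4, per_process_batch=64, accum=1)

print(f"COCO effective batch size: {coco}")
print(f"ImageNet effective batch size: {imagenet}")
```

Under these assumptions, halving `--num_processes` while doubling `accum` keeps the effective batch size unchanged.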

Evaluation

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --num_processes 4 --num_machines 1 --main_process_ip 127.0.0.1 --main_process_port 8868 sample_ds_vq.py model=dit_xl2_it dynamic=linear input_tensor_type=bwh tokenizer=sd_vq_f8 data=imagenet256_cond_indices data.batch_size=64 data.sample_vis_n=16 data.sample_fid_every=40_000 data.sample_fid_n=5_000 optim.name=adamw optim.lr=1e-4 optim.wd=0.0 lrschedule.warmup_steps=0 data.train_steps=1_400_000 ckpt_every=20_000 mixed_precision=bf16 accum=1 ckpt="./outputs/v1.3_vqacc_note_bf16_imagenet256_cond_indices_dit_xl2_it_linear_sd_vq_f8_bs64acc1_wd0.0_gc1.0_4g_mcml-hgx-h100-008_4980788/2024-12-08_11-58-30/checkpoints/1100000.pt" num_fid_samples=50000 offline.lbs=100 dynamic.disint.scheduler=linear dynamic.disint.sampler=maskgit maskgit_randomize=linear top_k=0 top_p=0  offline.save_samples_to_disk=1 sm_t=1.3  use_cfg=1 cfg_scale=2 dstep_num=20

You should get an FID around 8.26.
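
To give a rough feel for what `dstep_num=20` with a `linear` scheduler could mean during sampling, here is a minimal sketch of a linear unmasking schedule. It is an illustration under stated assumptions, not the repository's sampler: it assumes the masked fraction falls linearly from 1 to 0 over the sampling steps, and that a 256×256 image encoded by an 8× downsampling tokenizer yields 32×32 = 1024 tokens. The function `linear_mask_counts` is hypothetical.

```python
def linear_mask_counts(num_tokens: int, num_steps: int) -> list[int]:
    """Number of tokens still masked after each sampling step.

    Assumption: the masked fraction decreases linearly from 1 to 0.
    Integer arithmetic avoids floating-point rounding surprises.
    """
    return [
        num_tokens * (num_steps - step) // num_steps
        for step in range(1, num_steps + 1)
    ]


# Assumed token grid: 256 / 8 = 32 per side, so 32 * 32 tokens.
num_tokens = (256 // 8) ** 2
counts = linear_mask_counts(num_tokens, num_steps=20)

print("masked after first step:", counts[0])
print("masked after last step:", counts[-1])
```

With fewer steps, more tokens are revealed per step, which generally trades sample quality for speed.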

🎫 License

This work is licensed under the Apache License, Version 2.0 (as defined in the LICENSE file).

By downloading and using the code and model, you agree to the terms in the LICENSE file.