
changing params to go fast on H200s#64

Merged
daanelson merged 6 commits into main from smart-checkpoint on Jan 23, 2025

Conversation

@daanelson
Contributor

@daanelson daanelson commented Jan 23, 2025

To speed up training (losslessly!) in the default configuration on Replicate, I'm trading increased memory usage for increased speed (a rough sketch of the resulting logic follows the summary below). Specifically:

  • turning off gradient checkpointing
  • not quantizing the flux DiT

Important

Optimize training speed on H200s by disabling gradient checkpointing and model quantization under specific conditions in train.py.

  • Behavior:
    • Disable gradient checkpointing by default in train() to increase training speed.
    • Automatically enable gradient checkpointing if GPU memory < 100GB, batch size > 1, or resolution > 1024.
    • Disable model quantization by default; enable if GPU memory < 100GB.
  • Parameters:
    • Add gradient_checkpointing parameter to train() with default False.
    • Set quantize to False by default in train().
  • Misc:
    • Refactor resolution parsing in train() to use resolutions list.

This description was created by Ellipsis for db81504. It will automatically update as commits are pushed.
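For context, here is a minimal sketch of what these auto-tuning heuristics could look like in train.py. The helper name `resolve_memory_settings`, the threshold constant, and the default resolutions list are illustrative assumptions, not the actual implementation:

```python
import torch

# Assumed cutoff; the PR describes 100GB as the point below which
# memory-saving fallbacks are re-enabled.
GPU_MEMORY_THRESHOLD_GB = 100


def resolve_memory_settings(
    gradient_checkpointing: bool = False,  # new train() param, default False
    quantize: bool = False,                # quantization now off by default
    batch_size: int = 1,
    resolutions: list[int] | None = None,
) -> tuple[bool, bool]:
    """Return (gradient_checkpointing, quantize) after applying the auto-enable rules."""
    resolutions = resolutions or [512, 768, 1024]  # hypothetical defaults
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

    # Re-enable gradient checkpointing on smaller GPUs, larger batches,
    # or resolutions above 1024 to avoid running out of memory.
    if (
        gpu_memory_gb < GPU_MEMORY_THRESHOLD_GB
        or batch_size > 1
        or max(resolutions) > 1024
    ):
        gradient_checkpointing = True

    # Quantize the flux DiT only when GPU memory is limited.
    if gpu_memory_gb < GPU_MEMORY_THRESHOLD_GB:
        quantize = True

    return gradient_checkpointing, quantize
```

On an H200 (141 GB of HBM) with the default batch size and resolutions, both flags stay off, so training runs without checkpointing overhead and with the unquantized DiT.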

@daanelson daanelson requested a review from a team January 23, 2025 04:48
@daanelson daanelson merged commit 9e68dd2 into main Jan 23, 2025
2 checks passed
@daanelson daanelson deleted the smart-checkpoint branch January 23, 2025 17:29