[WIP] Layerwise dynamic upcasting to Diffusers Models #9177
Changes from all commits
be55fa6
6b9fd09
1fdae85
b366b22
9b411e5
f1fa123
0d1a1f8
c64fa22
51a855c
@@ -330,7 +330,7 @@ def decode(
                Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.

        """
-       z = (z * self.config.scaling_factor - self.means) / self.stds
+       z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)

Review comment: Hopefully this doesn't lead to performance regressions.

        scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
        z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
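The explicit casts matter because PyTorch's type promotion rules otherwise decide the result dtype: once the `means`/`stds` buffers sit in a different dtype than the incoming latents (which layerwise upcasting makes more likely), mixed-dtype arithmetic either silently promotes the result or, for storage dtypes like float8, may not be supported at all. A minimal sketch of the promotion behaviour (shapes and dtypes are illustrative):

```python
import torch

z = torch.randn(1, 4, 8, 8, dtype=torch.float16)       # latents in the compute dtype
means = torch.zeros(1, 4, 1, 1, dtype=torch.float32)   # buffer left in a different dtype

print((z - means).dtype)              # torch.float32 -- silently promoted to the buffer dtype
print((z - means.to(z.dtype)).dtype)  # torch.float16 -- stays in the latent dtype
```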
@@ -263,6 +263,80 @@ def disable_xformers_memory_efficient_attention(self) -> None:
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def enable_layerwise_upcasting(self, upcast_dtype=None):

Review comment: I am okay with the name because it conveys the idea easily. Maybe we could note down the implementation detail that it is applied to the leaf nodes in the compute graph in the docstring?

        r"""
        Enable layerwise dynamic upcasting. This allows models to be loaded onto the GPU in a low memory dtype, e.g.
        torch.float8_e4m3fn, while performing inference in a dtype supported by the GPU, by upcasting the individual
        modules in the model to the appropriate dtype right before the forward pass.

        Each module is then moved back to the low memory dtype after the forward pass.
        """

        upcast_dtype = upcast_dtype or torch.float32
        original_dtype = self.dtype

        def upcast_dtype_hook_fn(module, *args, **kwargs):
            module = module.to(upcast_dtype)

        def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
            module = module.to(original_dtype)

        def fn_recursive_upcast(module):
            """In certain cases modules will apply casting internally or reference the dtype of internal blocks.

            e.g.

            ```
            class MyModel(nn.Module):
                def forward(self, x):
                    dtype = next(iter(self.blocks.parameters())).dtype
                    x = self.blocks(x) + torch.ones(x.size()).to(dtype)
            ```

            Layerwise upcasting will not work here, since the internal blocks remain in the low memory dtype until
            their `forward` method is called. We need to add the upcast hook on the entire module in order for the
            operation to work.

            The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
            entirely, rather than layerwise.
            """
            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
                # Upcast the entire module and exit the recursion
                module.register_forward_pre_hook(upcast_dtype_hook_fn)
                module.register_forward_hook(cast_to_original_dtype_hook_fn)

                return

            has_children = list(module.children())
            if not has_children:
                # Leaf module: upcast right before its forward pass and downcast right after
                module.register_forward_pre_hook(upcast_dtype_hook_fn)
                module.register_forward_hook(cast_to_original_dtype_hook_fn)

            for child in module.children():
                fn_recursive_upcast(child)

        for module in self.children():
            fn_recursive_upcast(module)

    def disable_layerwise_upcasting(self):
        def fn_recursive_upcast(module):
            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
                # Drop the hooks from the entire module and exit the recursion
                module._forward_pre_hooks = OrderedDict()
                module._forward_hooks = OrderedDict()

                return

            has_children = list(module.children())
            if not has_children:
                module._forward_pre_hooks = OrderedDict()
                module._forward_hooks = OrderedDict()

            for child in module.children():
                fn_recursive_upcast(child)

        for module in self.children():
            fn_recursive_upcast(module)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
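As the review comment near the top of this hunk notes, the mechanism here is PyTorch's standard forward pre-/post-hooks, registered on the leaf modules of the model (plus any `_always_upcast_modules` as a whole). A standalone sketch of that pattern on a toy model, independent of the PR (module sizes and dtypes are illustrative):

```python
import torch
import torch.nn as nn

storage_dtype, compute_dtype = torch.float16, torch.float32

def upcast_before_forward(module, args):
    module.to(compute_dtype)           # upcast the weights just in time

def downcast_after_forward(module, args, output):
    module.to(storage_dtype)           # return the weights to the low memory dtype

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).to(storage_dtype)

for submodule in model.modules():
    if len(list(submodule.children())) == 0:   # leaf modules only
        submodule.register_forward_pre_hook(upcast_before_forward)
        submodule.register_forward_hook(downcast_after_forward)

out = model(torch.randn(1, 8, dtype=compute_dtype))
print(out.dtype)              # torch.float32
print(model[0].weight.dtype)  # torch.float16 -- weights are back in the storage dtype
```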
Review comment on the diff above: We should add `_always_upcast_modules` to `ModelMixin`?
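For orientation, this is roughly how the API added in this PR would be used; a hypothetical sketch, assuming float8 storage is available in the installed PyTorch build and using an illustrative checkpoint id:

```python
import torch
from diffusers import UNet2DConditionModel

# Illustrative checkpoint; any ModelMixin subclass should behave the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
).to("cuda", dtype=torch.float8_e4m3fn)   # store the weights in a low memory dtype

# Upcast each leaf module to float16 right before its forward pass,
# and cast it back to float8 immediately afterwards.
unet.enable_layerwise_upcasting(upcast_dtype=torch.float16)

# ... run inference through the model or a pipeline ...

unet.disable_layerwise_upcasting()        # remove the hooks again
```

Whether loading and `.to()`-casting into float8 work end-to-end is exactly the kind of detail still being settled in this WIP PR, so this is a sketch rather than a guaranteed recipe.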