From c1ae0bb1c65728445d2878bb2d38193b889c3f3e Mon Sep 17 00:00:00 2001 From: Artem Bolgar Date: Tue, 1 Oct 2024 16:58:21 -0700 Subject: [PATCH] Making DAC decoder torch.compileable The original torch.nn.utils.weight_norm() function is deprecated and it prevents the DAC decoder model from being torch.compiled. The newer torch.nn.utils.parametrizations.weight_norm() works the same way as the old one and plus it is compatible with torch.compile (with CUDA graphs). --- dac/model/discriminator.py | 2 +- dac/nn/layers.py | 2 +- dac/nn/quantize.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dac/model/discriminator.py b/dac/model/discriminator.py index 09c79d1..329a420 100644 --- a/dac/model/discriminator.py +++ b/dac/model/discriminator.py @@ -5,7 +5,7 @@ from audiotools import ml from audiotools import STFTParams from einops import rearrange -from torch.nn.utils import weight_norm +from torch.nn.utils.parametrizations import weight_norm def WNConv1d(*args, **kwargs): diff --git a/dac/nn/layers.py b/dac/nn/layers.py index 44fbc29..80ef3fa 100644 --- a/dac/nn/layers.py +++ b/dac/nn/layers.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from torch.nn.utils import weight_norm +from torch.nn.utils.parametrizations import weight_norm def WNConv1d(*args, **kwargs): diff --git a/dac/nn/quantize.py b/dac/nn/quantize.py index b17ff4a..cc42a6a 100644 --- a/dac/nn/quantize.py +++ b/dac/nn/quantize.py @@ -5,7 +5,6 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from torch.nn.utils import weight_norm from dac.nn.layers import WNConv1d