diff --git a/bark/generation.py b/bark/generation.py index 54f98709..3a15b8b2 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -92,7 +92,7 @@ def _cast_bool_env_var(s): USE_SMALL_MODELS = _cast_bool_env_var(os.environ.get("SUNO_USE_SMALL_MODELS", "False")) GLOBAL_ENABLE_MPS = _cast_bool_env_var(os.environ.get("SUNO_ENABLE_MPS", "False")) OFFLOAD_CPU = _cast_bool_env_var(os.environ.get("SUNO_OFFLOAD_CPU", "False")) - +DISABLE_COMPILE = _cast_bool_env_var(os.environ.get("SUNO_DISABLE_COMPILE", "False")) REMOTE_MODEL_PATHS = { "text_small": { @@ -254,6 +254,11 @@ def _load_codec_model(device): model.set_target_bandwidth(6.0) model.eval() model.to(device) + if callable(getattr(torch, "compile")) and not DISABLE_COMPILE: + logger.info("torch.compile available, compiling codec model.") + model = torch.compile(model) + else: + logger.info("torch.compile *not* available, you will get better performance if you use pytorch >= 2.0.") _clear_cuda_cache() return model