Skip to content

Commit

Permalink
Fix MusicGen Stereo MultiBandDiffusion (#276)
Browse files Browse the repository at this point in the history
* add MBD stereo handler
* fix flac plugins
* upgrade nodejs on colab
  • Loading branch information
rsxdalv authored Feb 7, 2024
1 parent fd08af9 commit 13e1a95
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ https://rsxdalv.github.io/bark-speaker-directory/
https://github.com/rsxdalv/tts-generation-webui/discussions/186#discussioncomment-7291274

## Changelog
Feb 8:
* Fix MultiBandDiffusion for MusicGen's stereo models, thank you https://github.com/mykeehu
* Fix Node.js installation steps on Google Colab, code by https://github.com/miaohf

Feb 6:
* Add FLAC file generation extension by https://github.com/JoaCHIP

Expand Down
16 changes: 15 additions & 1 deletion notebooks/google_collab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@
"os.chdir(\"./tts-generation-webui\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1c0ce06",
"metadata": {},
"outputs": [],
"source": [
"# Get latest Node.js\n",
"!wget https://nodejs.org/dist/v21.6.0/node-v21.6.0-linux-x64.tar.gz \n",
"!tar xvfz node-v21.6.0-linux-x64.tar.gz \n",
"!cp -r node-v21.6.0-linux-x64/* /usr/local/"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -82,7 +95,8 @@
},
"outputs": [],
"source": [
"!python server.py --share"
"!python server.py --share\n",
"# Note - Node.js/React UI works but isn't accessible by default on Google Colab, only gradio is easy to open."
]
}
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def double_escape_backslash(prompt):
metadata_str = json.dumps(metadata, ensure_ascii=False)

pipe_input = ffmpeg.input("pipe:", format="f32le", ar=str(SAMPLE_RATE))
metadata_filename = files.get("flac") + ".ffmetadata.ini" # type: ignore
metadata_filename = filename + ".ffmetadata.ini" # type: ignore
with open(metadata_filename, "w", encoding="utf-8") as f:
f.write(
f""";FFMETADATA1
Expand Down Expand Up @@ -99,9 +99,9 @@ def remove_map_1(args: List[str]) -> List[str]:
# print(p.returncode)
# Show if success
if p.returncode == 0:
print("Saved generation to", files.get("flac"))
print("Saved generation to", filename)
else:
print("Failed to save generation to", files.get("flac"))
print("Failed to save generation to", filename)
print("ffmpeg args:", args)
print(output_data[0])
# print(output_data[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def callback_save_generation_musicgen(
channels = audio_array.shape[1] if len(audio_array.shape) > 1 else 1
pipe_input = ffmpeg.input("pipe:", format="f32le", ar=str(SAMPLE_RATE), ac=channels)
# TODO: test with Tempfile
metadata_filename = files.get("flac") + ".ffmetadata.ini" # type: ignore
metadata_filename = filename + ".ffmetadata.ini" # type: ignore
with open(metadata_filename, "w", encoding="utf-8") as f:
f.write(
f""";FFMETADATA1
Expand Down Expand Up @@ -74,9 +74,9 @@ def remove_map_1(args: List[str]) -> List[str]:
# print(p.returncode)
# Show if success
if p.returncode == 0:
print("Saved generation to", files.get("flac"))
print("Saved generation to", filename)
else:
print("Failed to save generation to", files.get("flac"))
print("Failed to save generation to", filename)
print("ffmpeg args:", args)
print(output_data[0])
# print(output_data[1])
Expand Down
12 changes: 10 additions & 2 deletions src/musicgen/musicgen_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import gradio as gr
from audiocraft.models.musicgen import MusicGen
from audiocraft.models.audiogen import AudioGen
from audiocraft.models.encodec import InterleaveStereoCompressionModel
from einops import rearrange
from typing import Optional, Tuple, TypedDict
import numpy as np
import os
Expand Down Expand Up @@ -207,8 +209,14 @@ def generate(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarr
from audiocraft.models.multibanddiffusion import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_musicgen()
wav_diffusion = mbd.tokens_to_wav(tokens)
output = wav_diffusion.detach().cpu().numpy().squeeze()
if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
left, right = MODEL.compression_model.get_left_right_codes(tokens)
tokens = torch.cat([left, right])
outputs_diffusion = mbd.tokens_to_wav(tokens)
if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
assert outputs_diffusion.shape[1] == 1 # output is mono
outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
output = outputs_diffusion.detach().cpu().numpy().squeeze()
else:
print("NOTICE: Multi-band diffusion is not supported for AudioGen")
params["use_multi_band_diffusion"] = False
Expand Down

0 comments on commit 13e1a95

Please sign in to comment.