
Commit 778c20e

[cm] Adding music gen example
1 parent 7c39dbf commit 778c20e

File tree

1 file changed: +142 -0 lines changed

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
import logging
import os
from typing import Dict, List

import torch as tr
from torch import Tensor

from neutone_sdk import NeutoneParameter, TextNeutoneParameter, \
    CategoricalNeutoneParameter
from neutone_sdk.non_realtime_wrapper import NonRealtimeBase

logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))


class MusicGenModelWrapper(NonRealtimeBase):
    def get_model_name(self) -> str:
        return "MusicGen.example"

    def get_model_authors(self) -> List[str]:
        return ["Naotake Masuda"]

    def get_model_short_description(self) -> str:
        return "MusicGen model."

    def get_model_long_description(self) -> str:
        return "MusicGen model."

    def get_technical_description(self) -> str:
        return "MusicGen model."

    def get_technical_links(self) -> Dict[str, str]:
        return {
            "Paper": "https://arxiv.org/abs/2306.05284",
            "Code": "https://github.com/facebookresearch/audiocraft/"
        }

    def get_tags(self) -> List[str]:
        return ["musicgen"]

    def get_model_version(self) -> str:
        return "1.0.0"

    def is_experimental(self) -> bool:
        return False

    def get_neutone_parameters(self) -> List[NeutoneParameter]:
        return [
            TextNeutoneParameter(name="prompt",
                                 description="text prompt for generation",
                                 max_n_chars=256,
                                 default_value="techno kick drum"),
            CategoricalNeutoneParameter(name="duration",
                                        description="how many seconds to generate",
                                        n_values=8,
                                        default_value=0,
                                        labels=[str(idx) for idx in range(1, 9)]),
        ]

    @tr.jit.export
    def get_audio_in_channels(self) -> List[int]:
        return []  # Does not take audio input

    @tr.jit.export
    def get_audio_out_channels(self) -> List[int]:
        return [1]  # Mono output

    @tr.jit.export
    def get_native_sample_rates(self) -> List[int]:
        return [32000]

    @tr.jit.export
    def get_native_buffer_sizes(self) -> List[int]:
        return []  # One-shot model so buffer size does not matter

    @tr.jit.export
    def is_one_shot_model(self) -> bool:
        return True

    def do_forward_pass(self,
                        curr_block_idx: int,
                        audio_in: List[Tensor],
                        knob_params: Dict[str, Tensor],
                        text_params: List[str]) -> List[Tensor]:
        # The extra cast to int is needed for TorchScript
        n_seconds = int(knob_params["duration"].item()) + 1
        # Convert duration to number of tokens (MusicGen generates roughly 50
        # audio tokens per second; the extra steps cover the codebook delay pattern)
        n_tokens = (n_seconds * 50) + 4
        if self.use_debug_mode:
            assert len(text_params) == 1
            # TorchScript does not support logging statements
            print("Preprocessing...")
        # Preprocess
        input_ids, encoder_outputs, delay_pattern_mask, encoder_attention_mask = (
            self.model.preprocess(text_params, n_tokens)
        )
        # Generate
        for idx in range(n_tokens - 1):
            if self.should_cancel_forward_pass():
                return []
            input_ids = self.model.sample_step(input_ids,
                                               encoder_outputs,
                                               delay_pattern_mask,
                                               encoder_attention_mask)
            percentage_progress = int((idx + 1) / n_tokens * 100)
            self.set_progress_percentage(percentage_progress)
            if self.use_debug_mode:
                # TorchScript does not support logging statements
                print(f"Generating token {idx + 1}/{n_tokens}...")
                print(f"Progress: {self.get_progress_percentage()}%")
        if self.use_debug_mode:
            # TorchScript does not support logging statements
            print("Postprocessing...")
        # Postprocess
        audio_out = self.model.postprocess(input_ids, delay_pattern_mask, text_params)
        # Remove batch dimension
        audio_out = audio_out.squeeze(0)
        return [audio_out]


if __name__ == "__main__":
    import torchtext  # This is needed for loading the TorchScript model

    # model_path = "../../out/musicgen.ts"
    model_path = "/Users/puntland/local_christhetree/qosmo/neutone_sdk/out/musicgen.ts"
    model = tr.jit.load(model_path)
    wrapper = MusicGenModelWrapper(model)

    # TODO(cm): write export method for nonrealtime models
    # wrapper.prepare_for_inference()
    ts = tr.jit.script(wrapper)

    audio_out = wrapper.forward(curr_block_idx=0,
                                audio_in=[],
                                numerical_params=tr.tensor([0.0]).unsqueeze(1),
                                text_params=["testing"])
    log.info(audio_out[0].shape)
    audio_out = ts.forward(curr_block_idx=0,
                           audio_in=[],
                           numerical_params=tr.tensor([0.0]).unsqueeze(1),
                           text_params=["testing"])
    log.info(audio_out[0].shape)

0 commit comments
