|
| 1 | +import logging |
| 2 | +import os |
| 3 | +from typing import Dict, List |
| 4 | + |
| 5 | +import torch as tr |
| 6 | +from torch import Tensor |
| 7 | + |
| 8 | +from neutone_sdk import NeutoneParameter, TextNeutoneParameter, \ |
| 9 | + CategoricalNeutoneParameter |
| 10 | +from neutone_sdk.non_realtime_wrapper import NonRealtimeBase |
| 11 | + |
# Module-level logger. Its verbosity is read from the LOGLEVEL environment
# variable and falls back to INFO when that variable is not set.
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(os.getenv("LOGLEVEL", "INFO"))
| 15 | + |
| 16 | + |
class MusicGenModelWrapper(NonRealtimeBase):
    """Non-realtime Neutone wrapper around a MusicGen model.

    The wrapper exposes a text prompt and a duration selector, takes no audio
    input, and reports a single mono output at a native rate of 32000 Hz.
    """

    def get_model_name(self) -> str:
        """Return the model's display name."""
        return "MusicGen.example"

    def get_model_authors(self) -> List[str]:
        """Return the list of model authors."""
        return ["Naotake Masuda"]

    def get_model_short_description(self) -> str:
        """Return the short description of the model."""
        return "MusicGen model."

    def get_model_long_description(self) -> str:
        """Return the long description of the model."""
        return "MusicGen model."

    def get_technical_description(self) -> str:
        """Return the technical description of the model."""
        return "MusicGen model."

    def get_technical_links(self) -> Dict[str, str]:
        """Return reference links (paper and code) keyed by link title."""
        links = {
            "Paper": "https://arxiv.org/abs/2306.05284",
            "Code": "https://github.com/facebookresearch/audiocraft/"
        }
        return links

    def get_tags(self) -> List[str]:
        """Return the tags associated with the model."""
        return ["musicgen"]

    def get_model_version(self) -> str:
        """Return the model version string."""
        return "1.0.0"

    def is_experimental(self) -> bool:
        """Return False: the model is not flagged as experimental."""
        return False

    def get_neutone_parameters(self) -> List[NeutoneParameter]:
        """Return the parameters of the model: a prompt and a duration.

        The duration is a categorical choice whose labels are the strings
        "1" through "8"; the first choice is the default.
        """
        n_choices = 8
        prompt_param = TextNeutoneParameter(
            name="prompt",
            description="text prompt for generation",
            max_n_chars=256,
            default_value="techno kick drum",
        )
        duration_param = CategoricalNeutoneParameter(
            name="duration",
            description="how many seconds to generate",
            n_values=n_choices,
            default_value=0,
            labels=[str(n_sec) for n_sec in range(1, n_choices + 1)],
        )
        return [prompt_param, duration_param]

    @tr.jit.export
    def get_audio_in_channels(self) -> List[int]:
        """Return an empty list because the model takes no audio input."""
        return []

    @tr.jit.export
    def get_audio_out_channels(self) -> List[int]:
        """Return the output channel configuration: mono only."""
        return [1]

    @tr.jit.export
    def get_native_sample_rates(self) -> List[int]:
        """Return the native sample rates of the model."""
        return [32000]

    @tr.jit.export
    def get_native_buffer_sizes(self) -> List[int]:
        """Return an empty list.

        This is a one-shot model, so the buffer size does not matter.
        """
        return []

    @tr.jit.export
    def is_one_shot_model(self) -> bool:
        """Return True: this is a one-shot model."""
        return True

    def do_forward_pass(self,
                        curr_block_idx: int,
                        audio_in: List[Tensor],
                        knob_params: Dict[str, Tensor],
                        text_params: List[str]) -> List[Tensor]:
        """Generate audio from the text prompt, one sampling step at a time.

        Args:
            curr_block_idx: Not read by this method.
            audio_in: Not read by this method.
            knob_params: Must hold a "duration" entry from which the number
                of seconds to generate is derived.
            text_params: Prompt strings handed to the model; in debug mode
                exactly one entry is asserted.

        Returns:
            A single-element list with the generated audio after squeezing
            dimension 0, or an empty list if cancellation was requested
            during generation.
        """
        # TorchScript needs the explicit int() around .item(). The +1 lines
        # the value up with the "1".."8" labels (assumes the knob carries the
        # zero-based choice index - TODO confirm against NonRealtimeBase).
        duration_sec = int(knob_params["duration"].item()) + 1
        # 50 tokens per second plus 4 extra tokens.
        # NOTE(review): the 4 presumably covers the codebook delay pattern.
        total_tokens = (duration_sec * 50) + 4
        if self.use_debug_mode:
            assert len(text_params) == 1
            # print is used because TorchScript cannot compile logging calls
            print("Preprocessing...")

        # Preprocessing step: build the model inputs for generation
        token_ids, enc_out, delay_mask, enc_attn_mask = (
            self.model.preprocess(text_params, total_tokens)
        )

        # Generation loop: one sample_step call per iteration
        n_steps = total_tokens - 1
        for step in range(n_steps):
            # Give the host a chance to abort between sampling steps
            if self.should_cancel_forward_pass():
                return []
            token_ids = self.model.sample_step(token_ids,
                                               enc_out,
                                               delay_mask,
                                               enc_attn_mask)
            step_no = step + 1
            self.set_progress_percentage(int(step_no / total_tokens * 100))
            if self.use_debug_mode:
                # print is used because TorchScript cannot compile logging
                print(f"Generating token {step_no}/{total_tokens}...")
                print(f"Progress: {self.get_progress_percentage()}%")

        if self.use_debug_mode:
            # print is used because TorchScript cannot compile logging calls
            print("Postprocessing...")
        # Postprocessing step: turn the generated tokens into audio
        batched_audio = self.model.postprocess(token_ids, delay_mask, text_params)
        # Drop the leading batch dimension before returning
        return [batched_audio.squeeze(0)]
| 120 | + |
| 121 | + |
if __name__ == "__main__":
    # Smoke test: load the exported MusicGen TorchScript model, wrap it, and
    # run one forward pass through both the eager and the scripted wrapper.
    import sys

    import torchtext  # This is needed for loading the TorchScript model

    # Use the path given as the first command-line argument if present.
    # Otherwise fall back to "../../out/musicgen.ts" resolved against this
    # file's directory, so the script does not depend on a developer-specific
    # absolute path or on the current working directory.
    if len(sys.argv) > 1:
        model_path = sys.argv[1]
    else:
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                  "..", "..", "out", "musicgen.ts")
    model = tr.jit.load(model_path)
    wrapper = MusicGenModelWrapper(model)

    # TODO(cm): write export method for nonrealtime models
    # wrapper.prepare_for_inference()
    ts = tr.jit.script(wrapper)

    # Exercise the eager wrapper first, then its scripted counterpart, with
    # identical inputs, and log the shape of the first returned tensor.
    for module in (wrapper, ts):
        audio_out = module.forward(curr_block_idx=0,
                                   audio_in=[],
                                   numerical_params=tr.tensor([0.0]).unsqueeze(1),
                                   text_params=["testing"])
        log.info(audio_out[0].shape)
0 commit comments