diff --git a/.gitignore b/.gitignore index 2e61f4c4..79b1d750 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ config.json /voices-tortoise/ # Ignore model checkpoints -/data +/data/models # Ignore temporary files temp/ diff --git a/README.md b/README.md index da4b42d4..ddfe29c6 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,11 @@ https://rsxdalv.github.io/bark-speaker-directory/ ## Info about managing models, caches and system space for AI projects https://github.com/rsxdalv/tts-generation-webui/discussions/186#discussioncomment-7291274 -## Changelo +## Changelog +Jan 16: +* Upgrade MusicGen, adding support for stereo and large melody models +* Add MAGNeT + Jan 15: * Upgraded Gradio to 3.48.0 * Several visual bugs have appeared, if they are critical, please report them or downgrade gradio. diff --git a/data/models/magnet/.gitkeep b/data/models/magnet/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/react-ui/package.json b/react-ui/package.json index 52741c61..ae7ebd6d 100644 --- a/react-ui/package.json +++ b/react-ui/package.json @@ -3,7 +3,7 @@ "version": "0.1.0", "private": true, "scripts": { - "dev": "next dev", + "dev": "next dev -p 3001", "build": "next build", "start": "next start", "lint": "next lint", diff --git a/react-ui/src/components/Header.tsx b/react-ui/src/components/Header.tsx index f50c3412..091566d3 100644 --- a/react-ui/src/components/Header.tsx +++ b/react-ui/src/components/Header.tsx @@ -70,7 +70,11 @@ const routes: Route[] = [ }, { href: "/musicgen", - text: "Musicgen", + text: "MusicGen", + }, + { + href: "/magnet", + text: "MAGNeT", }, { href: "/demucs", diff --git a/react-ui/src/pages/api/gradio/[name].tsx b/react-ui/src/pages/api/gradio/[name].tsx index f3e550ac..6cdb5983 100644 --- a/react-ui/src/pages/api/gradio/[name].tsx +++ b/react-ui/src/pages/api/gradio/[name].tsx @@ -221,7 +221,7 @@ async function reload_old_generation_dropdown() { ]; }; - return result?.data[0].choices; + return result?.data[0].choices.map(x => x[0]); } async function bark_favorite({ history_bundle_name_data }) { @@ -315,7 +315,7 @@ async function tortoise_refresh_models() { ]; }; - return result?.data[0].choices; + return result?.data[0].choices.map(x => x[0]) } async function tortoise_refresh_voices() { @@ -330,7 +330,7 @@ async function tortoise_refresh_voices() { ]; }; - return result?.data[0].choices; + return result?.data[0].choices.map(x => x[0]) } async function tortoise_open_models() { @@ -440,7 +440,7 @@ async function rvc_model_reload() { ]; }; - return result?.data[0].choices; + return result?.data[0].choices.map(x => x[0]); } async function rvc_index_reload() { @@ -455,7 +455,7 @@ async function rvc_index_reload() { ]; }; - return result?.data[0].choices; + return result?.data[0].choices.map(x => x[0]); } // rvc_model_open @@ -607,9 +607,91 @@ async function get_config_bark() { }; } +async function magnet({ + model, + text, + seed, + use_sampling, + top_k, + top_p, + temperature, + max_cfg_coef, + min_cfg_coef, + decoding_steps_1, + decoding_steps_2, + decoding_steps_3, + decoding_steps_4, + span_arrangement, +}) { + const app = await getClient(); + + const result = (await app.predict("/magnet", [ + model, + text, + seed, + use_sampling, + top_k, + top_p, + temperature, + max_cfg_coef, + min_cfg_coef, + decoding_steps_1, + decoding_steps_2, + decoding_steps_3, + decoding_steps_4, + span_arrangement, + ])) as { + data: [ + GradioFile, // output + string, // history_bundle_name_data + string, // image + null, // seed_cache + Object // result_json + ]; + }; + + const [audio, history_bundle_name_data, , , json] = result?.data; + return { + audio, + history_bundle_name_data, + json, + }; +} + +// magnet_get_models + +async function magnet_get_models() { + const app = await getClient(); + + const result = (await app.predict("/magnet_get_models")) as { + data: [ + { + choices: string[]; + __type__: "update"; + } + ]; + }; + + return result?.data[0].choices.map(x => x[0]); +} + +// magnet_open_model_dir + +async function magnet_open_model_dir() { + const app = await getClient(); + + const result = (await app.predict("/magnet_open_model_dir")) as {}; + + return result; +} + const endpoints = { demucs, musicgen, + magnet, + magnet_get_models, + magnet_open_model_dir, + vocos_wav, vocos_npz, encodec_decode, @@ -619,12 +701,14 @@ const endpoints = { reload_old_generation_dropdown, bark_favorite, delete_generation, + tortoise, tortoise_refresh_models, tortoise_refresh_voices, tortoise_open_models, tortoise_open_voices, tortoise_apply_model_settings, + rvc, rvc_model_reload, rvc_index_reload, diff --git a/react-ui/src/pages/magnet.tsx b/react-ui/src/pages/magnet.tsx new file mode 100644 index 00000000..74502a36 --- /dev/null +++ b/react-ui/src/pages/magnet.tsx @@ -0,0 +1,679 @@ +import React from "react"; +import { Template } from "../components/Template"; +import Head from "next/head"; +import useLocalStorage from "../hooks/useLocalStorage"; +import { AudioInput, AudioOutput } from "../components/AudioComponents"; +import { + MagnetParams, + initialMagnetParams, + magnetId, +} from "../tabs/MagnetParams"; +import { GradioFile } from "../types/GradioFile"; + +type AudioOutput = { + name: string; + data: string; + size?: number; + is_file?: boolean; + orig_name?: string; + type_name?: string; +}; + +type Result = { + audio: GradioFile; + history_bundle_name_data: string; + json: { + _version: string; + _hash_version: string; + _type: string; + _audiocraft_version: string; + models: {}; + prompt: string; + hash: string; + date: string; + model: string; + text: string; + seed: string; + use_sampling: boolean; + top_k: number; + top_p: number; + temperature: number; + max_cfg_coef: number; + min_cfg_coef: number; + decoding_steps: number[]; + span_arrangement: string; + }; +}; + +const modelMap = { + Small: { size: "small", 10: true, 30: true }, + Medium: { size: "medium", 10: true, 30: true }, + Audio: { size: "audio", 10: true, 30: false }, +}; + +const canUseDuration = (type: string, isAudio: boolean, duration: string) => { + const subType = isAudio ? "Audio" : type; + const { [duration]: canUse } = modelMap[subType]; + return canUse; +}; + +const modelToType = { + "facebook/magnet-small-10secs": "Small", + "facebook/magnet-medium-10secs": "Medium", + "facebook/magnet-small-30secs": "Small", + "facebook/magnet-medium-30secs": "Medium", + "facebook/audio-magnet-small": "Small", + "facebook/audio-magnet-medium": "Medium", +}; + +const computeModel = (type: string, isAudio: boolean, duration: number) => { + const lowerType = type.toLowerCase(); + const durationSuffix = duration === 30 ? "-30secs" : "-10secs"; + + return isAudio + ? `facebook/audio-magnet-${lowerType}` + : `facebook/magnet-${lowerType}${durationSuffix}`; +}; + +const getType = (model: string) => { + return modelToType[model] || "Small"; +}; + +const decomputeModel = ( + model: string +): { type: string; isAudio: boolean; duration: number } => { + const type = getType(model); + const duration = model.includes("-30secs") ? 30 : 10; + const isAudio = model.includes("audio"); + return { type, isAudio, duration }; +}; + +const ModelSelector = ({ + magnetParams, + setMagnetParams, +}: { + magnetParams: MagnetParams; + setMagnetParams: React.Dispatch>; +}) => { + const { + type: modelType, + isAudio, + duration, + } = decomputeModel(magnetParams.model); + + return ( +
+
Model:
+ + setMagnetParams({ + ...magnetParams, + model: event.target.value, + }) + } + /> +
+ +
+ {["Small", "Medium"].map((type) => ( +
+ + setMagnetParams({ + ...magnetParams, + model: computeModel(event.target.value, isAudio, duration), + }) + } + className="border border-gray-300 p-2 rounded" + /> + +
+ ))} +
+
+ {/*
+ + + setMagnetParams({ + ...magnetParams, + model: computeModel(modelType, event.target.checked, duration), + }) + } + className="border border-gray-300 p-2 rounded" + /> +
*/} + {/* Instead of a checkbox make it a radio between Audio and Music */} +
+ +
+ {["Music", "Audio"].map((type) => ( +
+ + setMagnetParams({ + ...magnetParams, + model: computeModel( + modelType, + event.target.value === "Audio", + duration + ), + }) + } + className="border border-gray-300 p-2 rounded" + /> + +
+ ))} +
+
+
+ +
+ {["10", "30"].map((d) => ( +
+ + setMagnetParams({ + ...magnetParams, + model: computeModel( + modelType, + isAudio, + Number(event.target.value) + ), + }) + } + className="border border-gray-300 p-2 rounded" + disabled={!canUseDuration(modelType, isAudio, d)} + /> + +
+ ))} +
+
+
+ ); +}; + +const Model = ({ + params, + handleChange, +}: { + params: MagnetParams; + handleChange: ( + event: + | React.ChangeEvent + | React.ChangeEvent + | React.ChangeEvent + ) => void; +}) => { + const [options, setOptions] = React.useState([]); + const [loading, setLoading] = React.useState(false); + + const fetchOptions = async () => { + setLoading(true); + const response = await fetch("/api/gradio/magnet_get_models", { + method: "POST", + }); + + const result = await response.json(); + setOptions(result); + setLoading(false); + }; + + const openModels = async () => { + await fetch("/api/gradio/magnet_open_model_dir", { + method: "POST", + }); + }; + + React.useEffect(() => { + fetchOptions(); + }, []); + + const selected = params?.model; + return ( +
+
+ + + +
+
+ ); +}; + +const SeedInput = ({ + magnetParams, + handleChange, + setMagnetParams, + seed, +}: { + magnetParams: MagnetParams; + handleChange: (event: React.ChangeEvent) => void; + setMagnetParams: React.Dispatch>; + seed: number | string | undefined; +}) => ( + <> + + + + +); + +const initialHistory = []; // prevent infinite loop +const MagnetPage = () => { + const [data, setData] = useLocalStorage( + "magnetGenerationOutput", + null + ); + const [historyData, setHistoryData] = useLocalStorage( + "magnetHistory", + initialHistory + ); + const [magnetParams, setMagnetParams] = useLocalStorage( + magnetId, + initialMagnetParams + ); + + async function magnet() { + const body = JSON.stringify({ ...magnetParams }); + const response = await fetch("/api/gradio/magnet", { + method: "POST", + body, + }); + + const result: Result = await response.json(); + setData(result); + setHistoryData((x) => [result, ...x]); + } + + const handleChange = ( + event: + | React.ChangeEvent + | React.ChangeEvent + | React.ChangeEvent + ) => { + const { name, value, type } = event.target; + setMagnetParams({ + ...magnetParams, + [name]: + type === "number" || type === "range" + ? Number(value) + : type === "checkbox" + ? (event.target as HTMLInputElement).checked // type assertion + : value, + }); + }; + + const favorite = async (_url: string, data?: Result) => { + const history_bundle_name_data = data?.history_bundle_name_data; + if (!history_bundle_name_data) return; + const response = await fetch("/api/gradio/bark_favorite", { + method: "POST", + body: JSON.stringify({ + history_bundle_name_data, + }), + }); + const result = await response.json(); + return result; + }; + + const useSeed = (_url: string, data?: Result) => { + const seed = data?.json.seed; + if (!seed) return; + setMagnetParams({ + ...magnetParams, + seed: Number(seed), + }); + }; + + const useParameters = (_url: string, data?: Result) => { + const params = data?.json; + if (!params) return; + setMagnetParams({ + ...magnetParams, + ...params, + seed: Number(params.seed), + model: params.model || initialMagnetParams.model, + decoding_steps_1: params.decoding_steps[0], + decoding_steps_2: params.decoding_steps[1], + decoding_steps_3: params.decoding_steps[2], + decoding_steps_4: params.decoding_steps[3], + }); + }; + + const funcs = { + favorite, + useSeed, + useParameters, + }; + + return ( + + ); +}; + +export default MagnetPage; + +const MagnetInputs = ({ + magnetParams, + handleChange, + setMagnetParams, + data, +}: { + magnetParams: MagnetParams; + handleChange: ( + event: + | React.ChangeEvent + | React.ChangeEvent + | React.ChangeEvent + ) => void; + setMagnetParams: React.Dispatch>; + data: Result | null; +}) => { + return ( +
+
+ +