Skip to content

Commit

Permalink
add batching to magnet (#283)
Browse files Browse the repository at this point in the history
* add batching to magnet

* readme
  • Loading branch information
rsxdalv authored Mar 10, 2024
1 parent 4a95eb6 commit d04bafe
Show file tree
Hide file tree
Showing 7 changed files with 302 additions and 154 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ https://rsxdalv.github.io/bark-speaker-directory/
https://github.com/rsxdalv/tts-generation-webui/discussions/186#discussioncomment-7291274

## Changelog
Mar 10:
* Add Batching to React UI Magnet (#283)

Mar 5:
* Add Batching to React UI MusicGen (#281), thanks to https://github.com/Aamir3d for requesting this and providing feedback

Expand Down
64 changes: 64 additions & 0 deletions react-ui/src/components/HyperParameters.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import React from "react";
import { Progress } from "./Progress";

export const HyperParameters = <
T extends {
iterations: number;
splitByLines: boolean;
}
>({
params,
setParams,
progress,
progressMax,
isInterrupted: interrupted,
interrupt,
}: {
params: T;
setParams: React.Dispatch<React.SetStateAction<T>>;
progress: number;
progressMax: number;
isInterrupted: boolean;
interrupt: () => void;
}) => (
<div className="flex flex-col gap-y-2 border border-gray-300 p-2 rounded">
<label className="text-sm">Hyperparameters:</label>
<div className="flex gap-x-2 items-center">
<label className="text-sm">Iterations:</label>
<input
type="number"
name="iterations"
value={params.iterations}
onChange={(event) => {
setParams({
...params,
iterations: Number(event.target.value),
});
}}
className="border border-gray-300 p-2 rounded"
min="1"
max="10"
step="1"
/>
</div>
<div className="flex gap-x-2 items-center">
<div className="text-sm">Each line as a separate prompt:</div>
<input
type="checkbox"
name="splitByLines"
checked={params.splitByLines}
onChange={(event) => {
setParams({
...params,
splitByLines: event.target.checked,
});
}}
className="border border-gray-300 p-2 rounded"
/>
</div>
<Progress progress={progress} progressMax={progressMax} />
<button className="border border-gray-300 p-2 rounded" onClick={interrupt}>
{interrupted ? "Interrupted..." : "Interrupt"}
</button>
</div>
);
37 changes: 37 additions & 0 deletions react-ui/src/components/Progress.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import React from "react";

export const Progress = ({
progress,
progressMax,
}: {
progress: number;
progressMax: number;
}) => (
<div className="flex gap-x-2 items-center">
<label className="text-sm">Progress:</label>
<progress
value={progress}
max={progressMax}
className="[&::-webkit-progress-bar]:rounded [&::-webkit-progress-value]:rounded [&::-webkit-progress-bar]:bg-slate-300 [&::-webkit-progress-value]:bg-orange-400 [&::-moz-progress-bar]:bg-orange-400 [&::-webkit-progress-value]:transition-all [&::-webkit-progress-value]:duration-200"
/>
{progress}/{progressMax}
</div>
);

export const manageProgress = async (
callback: (args: { incrementProgress: () => void }) => Promise<void>,
max: number,
setProgress: React.Dispatch<
React.SetStateAction<{ current: number; max: number }>
>
) => {
setProgress({ current: 0, max });
await callback({
incrementProgress: () =>
setProgress(({ current, max }) => ({
current: current + 1,
max,
})),
});
setProgress({ current: 0, max: 0 });
};
15 changes: 15 additions & 0 deletions react-ui/src/data/hyperParamsUtils.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export const extractTexts = (
text: string,
params: { iterations?: number; splitByLines: any }
) => (params.splitByLines ? text.split("\n") : [text]);

export const incrementNonRandomSeed = (seed: number, iteration: number) =>
seed === -1 ? -1 : seed + iteration;

export const getMax = (texts: string[], iterations: number) =>
texts.length * iterations;

export const initialHyperParams = {
iterations: 1,
splitByLines: false,
};
15 changes: 15 additions & 0 deletions react-ui/src/hooks/useInterrupt.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { useRef } from "react";

export const useInterrupt = () => {
const interrupted = useRef(false);

const resetInterrupt = (callback: () => Promise<void>) => async () => {
interrupted.current = false;
await callback();
interrupted.current = false;
};

const interrupt = () => (interrupted.current = true);

return { interrupted, resetInterrupt, interrupt };
};
130 changes: 107 additions & 23 deletions react-ui/src/pages/magnet.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ import {
magnetId,
} from "../tabs/MagnetParams";
import { GradioFile } from "../types/GradioFile";
import { HyperParameters } from "../components/HyperParameters";
import {
extractTexts,
getMax,
incrementNonRandomSeed,
initialHyperParams,
} from "../data/hyperParamsUtils";
import { useInterrupt } from "../hooks/useInterrupt";
import { manageProgress } from "../components/Progress";

type AudioOutput = {
name: string;
Expand Down Expand Up @@ -343,19 +352,64 @@ const MagnetPage = () => {
magnetId,
initialMagnetParams
);
const [hyperParams, setHyperParams] = useLocalStorage<
typeof initialHyperParams
>("magnetHyperParams", initialHyperParams);
const [showLast, setShowLast] = useLocalStorage<number>(
"magnetShowLast",
10
);

async function magnet() {
const body = JSON.stringify({ ...magnetParams });
const response = await fetch("/api/gradio/magnet", {
method: "POST",
body,
});
const { interrupted, resetInterrupt, interrupt } = useInterrupt();
const [progress, setProgress] = React.useState({ current: 0, max: 0 });

function magnetWithProgress() {
const texts = extractTexts(magnetParams.text, hyperParams);
const { iterations } = hyperParams;

return manageProgress(
({ incrementProgress }) =>
magnetConsumer(
magnetGenerator(texts, iterations, magnetParams),
incrementProgress
),
getMax(texts, iterations),
setProgress
);
}

const result: Result = await response.json();
setData(result);
setHistoryData((x) => [result, ...x]);
async function* magnetGenerator(
texts: string[],
iterations: number,
magnetParams: MagnetParams
) {
for (let iteration = 0; iteration < iterations; iteration++) {
for (const text of texts) {
if (interrupted.current) {
return;
}
yield magnetGenerate({
...magnetParams,
text,
seed: incrementNonRandomSeed(magnetParams.seed, iteration),
});
}
}
}

async function magnetConsumer(
generator: AsyncGenerator<Result, void, unknown>,
callback: (result: Result) => void
) {
for await (const result of generator) {
setData(result);
setHistoryData((x) => [result, ...x]);
callback(result);
}
}

const magnet = resetInterrupt(magnetWithProgress);

const handleChange = (
event:
| React.ChangeEvent<HTMLInputElement>
Expand Down Expand Up @@ -416,21 +470,30 @@ const MagnetPage = () => {
useSeed,
useParameters,
};

const clearHistory = () => setHistoryData([]);
return (
<Template>
<Head>
<title>Magnet - TTS Generation Webui</title>
</Head>
<div className="p-4 flex w-full flex-col">
<div className="gap-y-4 p-4 flex w-full flex-col">
<MagnetInputs
magnetParams={magnetParams}
handleChange={handleChange}
setMagnetParams={setMagnetParams}
data={data}
/>

<div className="my-4 flex flex-col gap-y-2">
<HyperParameters
params={hyperParams}
setParams={setHyperParams}
interrupt={interrupt}
isInterrupted={interrupted.current}
progress={progress.current}
progressMax={progress.max}
/>

<div className="flex flex-col gap-y-2">
<button
className="border border-gray-300 p-2 rounded"
onClick={magnet}
Expand All @@ -448,27 +511,38 @@ const MagnetPage = () => {

<div className="flex flex-col gap-y-2 border border-gray-300 p-2 rounded">
<label className="text-sm">History:</label>
{/* Clear history */}
<button
className="border border-gray-300 p-2 rounded"
onClick={() => {
setHistoryData([]);
}}
>
Clear History
</button>
<div className="flex gap-x-2 items-center">
<button
className="border border-gray-300 p-2 px-40 rounded"
onClick={clearHistory}
>
Clear History
</button>
<div className="flex gap-x-2 items-center">
<label className="text-sm">Show Last X entries:</label>
<input
type="number"
value={showLast}
onChange={(event) => setShowLast(Number(event.target.value))}
className="border border-gray-300 p-2 rounded"
min="0"
max="100"
step="1"
/>
</div>
</div>
<div className="flex flex-col gap-y-2">
{historyData &&
historyData
.slice(1, 6)
.slice(1, showLast + 1)
.map((item, index) => (
<AudioOutput
key={index}
audioOutput={item.audio}
metadata={item}
label={item.history_bundle_name_data}
funcs={funcs}
filter={["sendToMagnet"]}
filter={["sendToMusicgen"]}
/>
))}
</div>
Expand Down Expand Up @@ -677,3 +751,13 @@ const MagnetInputs = ({
</div>
);
};

async function magnetGenerate(magnetParams: MagnetParams) {
const body = JSON.stringify({ ...magnetParams });
const response = await fetch("/api/gradio/magnet", {
method: "POST",
body,
});

return (await response.json()) as Result;
}
Loading

0 comments on commit d04bafe

Please sign in to comment.