// riffusion-app/pages/index.tsx

import Head from "next/head";
2022-11-24 22:24:37 -07:00
import { useEffect, useState } from "react";
import { useInterval } from "usehooks-ts";
2022-11-24 21:22:03 -07:00
import ThreeCanvas from "../components/ThreeCanvas";
2022-11-23 22:47:18 -07:00
import PromptPanel from "../components/PromptPanel";
2022-11-24 17:54:00 -07:00
import Info from "../components/Info";
2022-11-23 23:08:00 -07:00
import Pause from "../components/Pause";
2022-11-23 23:08:00 -07:00
import { InferenceResult, PromptInput } from "../types";
import * as Tone from "tone";
2022-11-24 21:22:03 -07:00
// Endpoint that inference requests are POSTed to as JSON.
// NOTE(review): hard-coded plain-HTTP address; consider moving it to configuration.
const SERVER_URL = "http://129.146.52.68:3013/run_inference/";

// Prompts the page starts with. Home reads the last three entries as the
// start, end and "up next" prompts, so the two trailing empty strings mean
// that no end / up-next prompt has been entered yet.
const defaultPromptInputs = [
  { prompt: "A jazz pianist playing a classical concerto" },
  { prompt: "Country singer and a techno DJ" },
  { prompt: "A typewriter in the style of K-Pop" },
  { prompt: "lo-fi beat for the holidays" },
  { prompt: "" },
  { prompt: "" },
];
2022-11-23 22:46:32 -07:00
2022-11-24 23:58:33 -07:00
/**
 * Which kind of interpolation the page is currently requesting.
 *
 * - `SamePrompt`: both ends of the interpolation use the start prompt, with
 *   the end using `seed + 1`.
 * - `Transition`: the end of the interpolation uses the end prompt, with the
 *   same seed as the start.
 */
enum AppState {
  SamePrompt = 0,
  Transition = 1,
}
2022-11-24 21:22:03 -07:00
/**
 * Download a resource and encode it as a base64 `data:` URL.
 *
 * @param url - Location of the resource to fetch.
 * @returns The resource contents as a data URL string.
 * @throws If the HTTP response is not OK, or the downloaded blob cannot be
 *   read as a data URL.
 */
const urlToBase64 = async (url: string): Promise<string> => {
  const response = await fetch(url);
  if (!response.ok) {
    // Without this check an error page would be silently encoded as content.
    throw new Error(`Failed to fetch ${url}: HTTP ${response.status}`);
  }
  const blob = await response.blob();

  return new Promise<string>((resolve, reject) => {
    const reader = new FileReader();
    // Attach the handlers before starting the read so no event can be missed.
    reader.onload = () => {
      const encoded = reader.result;
      if (typeof encoded === "string") {
        resolve(encoded);
      } else {
        reject(new Error(`Unexpected non-string result reading ${url}`));
      }
    };
    reader.onerror = () => {
      reject(reader.error ?? new Error(`Failed to read ${url} as a data URL`));
    };
    reader.readAsDataURL(blob);
  });
};
2022-11-25 00:26:44 -07:00
/**
 * Draw a pseudo-random integer from `Math.random()`.
 *
 * @param max - Exclusive upper bound of the range.
 * @returns `Math.random() * max`, rounded down to an integer.
 */
function getRandomInt(max: number) {
  const scaled = Math.random() * max;
  return Math.floor(scaled);
}
// TODO(hayk): Do this as soon as sample comes back
/** Period, in milliseconds, of the timer that triggers inference requests. */
const timeout = 5000;

/** Cap on stored inference results; the oldest one is dropped beyond this. */
const maxLength = 10;

/** Amount added to alpha after each inference result is stored. */
const alphaVelocity = 0.25;

/** The timer only requests inference while fewer results than this are held. */
const maxNumInferenceResults = 15;
2022-11-23 23:08:00 -07:00
/**
 * Main page component.
 *
 * Periodically POSTs an inference request to SERVER_URL built from the last
 * three prompt inputs, stores the returned image/audio pairs, and plays the
 * audio clips back to back through a Tone.js player synced to the transport.
 * Renders the canvas, the prompt panel, the info box and the pause control.
 */
export default function Home() {
  // Playback starts paused; the <Pause> control toggles this.
  const [paused, setPaused] = useState(true);

  const [promptInputs, setPromptInputs] =
    useState<PromptInput[]>(defaultPromptInputs);

  const [inferenceResults, setInferenceResults] = useState<InferenceResult[]>(
    []
  );

  // /////////////

  // The player does not exist until the first inference result has loaded.
  const [tonePlayer, setTonePlayer] = useState<Tone.Player | null>(null);

  // Number of clip boundaries the transport has crossed, and the value that
  // was last acted upon by the clip-advance effect below.
  const [numClipsPlayed, setNumClipsPlayed] = useState(0);
  const [prevNumClipsPlayed, setPrevNumClipsPlayed] = useState(0);

  // Counter of the inference result that should be loaded next.
  const [resultCounter, setResultCounter] = useState(0);

  const [alpha, setAlpha] = useState(0.0);

  // Lazy initializer so the random seed is only drawn on the first render.
  const [seed, setSeed] = useState(() => getRandomInt(1000000));

  const [appState, setAppState] = useState<AppState>(AppState.SamePrompt);

  // On load, populate an initial result from checked-in sample files.
  useEffect(() => {
    // NOTE(hayk): not currently populating initial prompts.
    if (true) {
      return;
    }

    if (inferenceResults.length > 0) {
      return;
    }

    const populateDefaults = async () => {
      const result1 = {
        input: {
          alpha: 0.0,
          start: defaultPromptInputs[0],
          end: defaultPromptInputs[1],
        },
        image: (await urlToBase64("rap_sample.jpg")) as string,
        audio: (await urlToBase64("rap_sample.mp3")) as string,
        counter: 0,
      };

      console.log(result1);
      setInferenceResults([...inferenceResults, result1]);
    };

    populateDefaults();
  }, [inferenceResults]);

  // On load, create a player synced to the tone transport
  useEffect(() => {
    if (tonePlayer) {
      return;
    }

    if (inferenceResults.length === 0) {
      return;
    }

    const audioUrl = inferenceResults[0].audio;

    const player = new Tone.Player(audioUrl, () => {
      console.log("Created player.");

      player.loop = true;
      player.sync().start(0);

      // Set up a callback to increment numClipsPlayed at the edge of each clip
      const bufferLength = player.sampleTime * player.buffer.length;
      Tone.Transport.scheduleRepeat((time) => {
        console.log(
          "Edge of clip, t = ",
          Tone.Transport.getSecondsAtTime(time),
          bufferLength
        );
        setNumClipsPlayed((n) => n + 1);
      }, bufferLength);

      setTonePlayer(player);

      // Make further load callbacks do nothing.
      player.buffer.onload = () => {};
    }).toDestination();
  }, [tonePlayer, inferenceResults]);

  // On play/pause button, play/pause the audio with the tone transport
  useEffect(() => {
    if (!paused) {
      console.log("Play");

      if (Tone.context.state === "suspended") {
        Tone.context.resume();
      }

      if (tonePlayer) {
        Tone.Transport.start();
      }
    } else {
      console.log("Pause");

      if (tonePlayer) {
        Tone.Transport.pause();
      }
    }
  }, [paused, tonePlayer]);

  // Each time a clip boundary has been crossed, load the next result's audio.
  useEffect(() => {
    if (numClipsPlayed === prevNumClipsPlayed) {
      return;
    }

    if (!tonePlayer) {
      return;
    }

    const maxResultCounter = Math.max(
      ...inferenceResults.map((r) => r.counter)
    );

    // Take the oldest stored result whose counter is at or past the one we
    // want. The exact counter may be gone because stored results are pruned
    // to maxLength entries, in which case playback skips ahead.
    const nextResult = inferenceResults.reduce<InferenceResult | undefined>(
      (best, r) =>
        r.counter >= resultCounter &&
        (best === undefined || r.counter < best.counter)
          ? r
          : best,
      undefined
    );

    if (nextResult === undefined) {
      console.info(
        "not picking a new result, none available",
        resultCounter,
        maxResultCounter
      );
      // prevNumClipsPlayed is left alone so this retries when results change.
      return;
    }

    console.log("Incrementing result counter ", resultCounter);
    setResultCounter(nextResult.counter + 1);

    tonePlayer
      .load(nextResult.audio)
      .then(() => {
        console.log("Loaded new audio");

        // Re-jigger the transport so it stops playing old buffers. It seems like this doesn't
        // introduce a gap, but watch out for that.
        Tone.Transport.pause();
        if (!paused) {
          Tone.Transport.start();
        }
      })
      .catch((error: unknown) => {
        console.error("Failed to load audio for result ", nextResult.counter, error);
      });

    setPrevNumClipsPlayed(numClipsPlayed);
  }, [
    numClipsPlayed,
    prevNumClipsPlayed,
    resultCounter,
    inferenceResults,
    paused,
    tonePlayer,
  ]);

  // /////////////

  // Set the app state based on the prompt inputs array
  useEffect(() => {
    // Only act once alpha has passed the end of the current interpolation.
    if (alpha <= 1) {
      return;
    }

    const upNextPrompt = promptInputs[promptInputs.length - 1].prompt;
    const endPrompt = promptInputs[promptInputs.length - 2].prompt;

    if (appState === AppState.SamePrompt) {
      if (endPrompt) {
        setAppState(AppState.Transition);
      }
      // Functional update so a stale `seed` from an older render is never used.
      setSeed((currentSeed) => currentSeed + 1);
    } else if (appState === AppState.Transition) {
      // Append an empty slot, which shifts every prompt role forward by one.
      setPromptInputs((currentInputs) => [...currentInputs, { prompt: "" }]);
      setAppState(upNextPrompt ? AppState.Transition : AppState.SamePrompt);
    }

    setAlpha(alpha - 1);
  }, [promptInputs, alpha, appState]);

  // On any app state change, reset alpha
  useEffect(() => {
    console.log("App State: ", appState);
    setAlpha(0.25);
  }, [appState]);

  /**
   * Request one inference from the server and append the result to state.
   *
   * The start prompt is the third-from-last prompt input and the end prompt is
   * the second-from-last. Outside of a transition both ends use the start
   * prompt and the end uses `seed + 1`. After a result is stored, alpha is
   * advanced by alphaVelocity. Failures are logged and leave state untouched.
   */
  const runInference = async (
    alpha: number,
    seed: number,
    appState: AppState,
    promptInputs: PromptInput[]
  ) => {
    if (promptInputs.length < 3) {
      console.warn("Skipping inference, need at least three prompt inputs");
      return;
    }

    const startPrompt = promptInputs[promptInputs.length - 3].prompt;
    const endPrompt = promptInputs[promptInputs.length - 2].prompt;

    const transitioning = appState === AppState.Transition;

    const denoising = 0.85;
    const guidance = 7.0;
    const numInferenceSteps = 50;
    const seedImageId = 0;
    const maskImageId = null;

    const inferenceInput = {
      alpha: alpha,
      num_inference_steps: numInferenceSteps,
      seed_image_id: seedImageId,
      mask_image_id: maskImageId,
      start: {
        prompt: startPrompt,
        seed: seed,
        denoising: denoising,
        guidance: guidance,
      },
      end: {
        prompt: transitioning ? endPrompt : startPrompt,
        seed: transitioning ? seed : seed + 1,
        denoising: denoising,
        guidance: guidance,
      },
    };

    console.log("Running for input: ", inferenceInput);

    let data: { image: string; audio: string };
    try {
      const response = await fetch(SERVER_URL, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          // NOTE(review): this is normally a response header; sending it on a
          // request looks unnecessary. Confirm the server does not need it.
          "Access-Control-Allow-Origin": "*",
        },
        body: JSON.stringify(inferenceInput),
      });
      if (!response.ok) {
        throw new Error(
          `Inference server responded with HTTP ${response.status}`
        );
      }
      data = await response.json();
    } catch (error: unknown) {
      // Alpha is not advanced, so the next attempt retries the same input.
      console.error("Inference request failed: ", error);
      return;
    }

    setInferenceResults((prevResults) => {
      // Counters keep increasing even after older results have been pruned.
      const newCounter =
        prevResults.length > 0
          ? Math.max(...prevResults.map((r) => r.counter)) + 1
          : 0;

      const newResult = {
        input: inferenceInput,
        // TODO(hayk): Swap for JPG?
        image: "data:image/png;base64," + data.image,
        audio: "data:audio/mpeg;base64," + data.audio,
        counter: newCounter,
      };

      const results = [...prevResults, newResult];

      // TODO(hayk): Move this somewhere more reasonable to prune.
      return results.length > maxLength ? results.slice(1) : results;
    });

    // Kept outside the updater above so that the updater stays side-effect free.
    setAlpha(alpha + alphaVelocity);
  };

  // Run inference on a timer.
  // TODO(hayk): Improve the strategy here.
  // NOTE(review): results are pruned to maxLength (10) entries, so the check
  // against maxNumInferenceResults (15) appears to always pass. Confirm intent.
  useInterval(() => {
    console.log(inferenceResults);
    if (inferenceResults.length < maxNumInferenceResults) {
      runInference(alpha, seed, appState, promptInputs);
    }
  }, timeout);

  // Kick off the first inference right away instead of waiting for the timer.
  // TODO(hayk): Fix warning about effects.
  useEffect(() => {
    runInference(alpha, seed, appState, promptInputs);
  }, []);

  return (
    <>
      <Head>
        <title>Riffusion</title>
        <meta
          name="description"
          content="My name is Riffusion, and I write music."
        />
        <link rel="icon" href="/favicon.ico" />
      </Head>

      <div className="bg-[#0A2342] flex flex-row min-h-screen text-white">
        <div className="w-1/3 min-h-screen">
          <ThreeCanvas
            paused={paused}
            getTime={() => Tone.Transport.seconds}
            audioLength={
              tonePlayer ? tonePlayer.sampleTime * tonePlayer.buffer.length : 0
            }
            inferenceResults={inferenceResults}
          />
        </div>

        <PromptPanel
          prompts={promptInputs}
          addPrompt={(prompt: string) => {
            setPromptInputs((currentInputs) => [
              ...currentInputs,
              { prompt: prompt },
            ]);
          }}
          changePrompt={(prompt: string, index: number) => {
            // Build a new object for the edited entry; mutating it in place
            // would also modify the shared defaultPromptInputs objects.
            setPromptInputs((currentInputs) =>
              currentInputs.map((input, i) =>
                i === index ? { ...input, prompt: prompt } : input
              )
            );
          }}
        />

        <Info />

        <Pause paused={paused} setPaused={setPaused} />
      </div>
    </>
  );
}