import Head from "next/head";
import { useEffect, useState } from "react";
import { useInterval } from "usehooks-ts";

import ThreeCanvas from "../components/ThreeCanvas";
import PromptPanel from "../components/PromptPanel";
import Info from "../components/Info";
import Pause from "../components/Pause";

import { InferenceResult, PromptInput } from "../types";

import * as Tone from "tone";

const SERVER_URL = "http://129.146.52.68:3013/run_inference/";

const defaultPromptInputs = [
  { prompt: "A jazz pianist playing a classical concerto" },
  { prompt: "Country singer and a techno DJ" },
  { prompt: "A typewriter in the style of K-Pop" },
  { prompt: "lo-fi beat for the holidays" },
  { prompt: "" },
  { prompt: "" },
];

enum AppState {
  SamePrompt,
  Transition,
}

// Fetch a URL and return its contents as a base64-encoded data URI.
const urlToBase64 = async (url: string) => {
  const data = await fetch(url);
  const blob = await data.blob();
  return new Promise((resolve) => {
    const reader = new FileReader();
    reader.readAsDataURL(blob);
    reader.onloadend = () => {
      const base64data = reader.result;
      resolve(base64data);
    };
  });
};

// TODO(hayk): Do this as soon as sample comes back
const timeout = 5000;
const maxLength = 10;
const alphaVelocity = 0.25;
const maxNumInferenceResults = 15;

export default function Home() {
  const [paused, setPaused] = useState(true);

  const [promptInputs, setPromptInputs] = useState(defaultPromptInputs);

  const [inferenceResults, setInferenceResults] = useState<InferenceResult[]>(
    []
  );

  // /////////////

  const [tonePlayer, setTonePlayer] = useState<Tone.Player>(null);

  const [numClipsPlayed, setNumClipsPlayed] = useState(0);
  const [prevNumClipsPlayed, setPrevNumClipsPlayed] = useState(0);

  const [resultCounter, setResultCounter] = useState(0);

  const [alpha, setAlpha] = useState(0.0);
  const [seed, setSeed] = useState(0);

  const [appState, setAppState] = useState(AppState.SamePrompt);

  // On load, populate the first two prompts from checked-in URLs
  useEffect(() => {
    // NOTE(hayk): not currently populating initial prompts.
    if (true) {
      return;
    }

    if (inferenceResults.length > 0) {
      return;
    }

    const populateDefaults = async () => {
      const result1 = {
        input: {
          alpha: 0.0,
          start: defaultPromptInputs[0],
          end: defaultPromptInputs[1],
        },
        image: (await urlToBase64("rap_sample.jpg")) as string,
        audio: (await urlToBase64("rap_sample.mp3")) as string,
        counter: 0,
      };

      const result2 = {
        input: {
          alpha: 0.0,
          start: defaultPromptInputs[0],
          end: defaultPromptInputs[1],
        },
        image: (await urlToBase64("pop_sample.jpg")) as string,
        audio: (await urlToBase64("pop_sample.mp3")) as string,
        counter: 1,
      };

      console.log(result1);

      setInferenceResults([...inferenceResults, result1]);
    };

    populateDefaults();
  }, [inferenceResults]);

  // On load, create a player synced to the tone transport
  useEffect(() => {
    if (tonePlayer) {
      return;
    }

    if (inferenceResults.length === 0) {
      return;
    }

    const audioUrl = inferenceResults[0].audio;

    const player = new Tone.Player(audioUrl, () => {
      console.log("Created player.");

      player.loop = true;
      player.sync().start(0);

      // Set up a callback to increment numClipsPlayed at the edge of each clip
      const bufferLength = player.sampleTime * player.buffer.length;
      Tone.Transport.scheduleRepeat((time) => {
        console.log(
          "Edge of clip, t = ",
          Tone.Transport.getSecondsAtTime(time)
        );
        setNumClipsPlayed((n) => n + 1);
      }, bufferLength);

      setTonePlayer(player);

      // Make further load callbacks do nothing.
      player.buffer.onload = () => {};
    }).toDestination();
  }, [tonePlayer, inferenceResults]);
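  // NOTE: Tone.Transport.scheduleRepeat fires every `bufferLength` seconds of
  // transport time, i.e. once per loop of the clip, so numClipsPlayed ticks up
  // exactly at clip boundaries. The effect below watches that counter to swap
  // in the next inference result without cutting a clip short.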
  // On play/pause button, play/pause the audio with the tone transport
  useEffect(() => {
    if (!paused) {
      console.log("Play");

      if (Tone.context.state == "suspended") {
        Tone.context.resume();
      }

      if (tonePlayer) {
        Tone.Transport.start();
      }
    } else {
      console.log("Pause");

      if (tonePlayer) {
        Tone.Transport.pause();
      }
    }
  }, [paused, tonePlayer]);

  useEffect(() => {
    if (numClipsPlayed == prevNumClipsPlayed) {
      return;
    }

    const maxResultCounter = Math.max(
      ...inferenceResults.map((r) => r.counter)
    );

    if (maxResultCounter < resultCounter) {
      console.info(
        "not picking a new result, none available",
        resultCounter,
        maxResultCounter
      );
      return;
    }

    const result = inferenceResults.find(
      (r: InferenceResult) => r.counter == resultCounter
    );

    console.log("Incrementing result counter ", resultCounter);
    setResultCounter((c) => c + 1);

    tonePlayer.load(result.audio).then(() => {
      console.log("Loaded new audio");

      // Re-jigger the transport so it stops playing old buffers. It seems like
      // this doesn't introduce a gap, but watch out for that.
      Tone.Transport.pause();
      if (!paused) {
        Tone.Transport.start();
      }
    });

    setPrevNumClipsPlayed(numClipsPlayed);
  }, [
    numClipsPlayed,
    prevNumClipsPlayed,
    resultCounter,
    inferenceResults,
    paused,
    tonePlayer,
  ]);

  // /////////////

  // Set the app state based on the prompt inputs array
  useEffect(() => {
    if (alpha <= 1) {
      return;
    }

    const upNextPrompt = promptInputs[promptInputs.length - 1].prompt;
    const endPrompt = promptInputs[promptInputs.length - 2].prompt;

    if (appState == AppState.SamePrompt) {
      if (endPrompt) {
        setAppState(AppState.Transition);
      }
      setSeed(seed + 1);
    } else if (appState == AppState.Transition) {
      setPromptInputs([...promptInputs, { prompt: "" }]);

      if (upNextPrompt) {
        setAppState(AppState.Transition);
      } else {
        setAppState(AppState.SamePrompt);
      }
    }

    setAlpha(alpha - 1);
  }, [promptInputs, alpha]);

  // On any app state change, reset alpha
  useEffect(() => {
    console.log("App State: ", appState);
    setAlpha(0.25);
  }, [appState]);

  const runInference = async (
    alpha: number,
    seed: number,
    appState: AppState,
    promptInputs: PromptInput[]
  ) => {
    const startPrompt = promptInputs[promptInputs.length - 3].prompt;
    const endPrompt = promptInputs[promptInputs.length - 2].prompt;

    const transitioning = appState == AppState.Transition;

    const inferenceInput = {
      alpha: alpha,
      start: {
        prompt: startPrompt,
        seed: seed,
      },
      end: {
        prompt: transitioning ? endPrompt : startPrompt,
        seed: transitioning ? seed : seed + 1,
      },
    };

    console.log("Running for input: ", inferenceInput);

    const response = await fetch(SERVER_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Access-Control-Allow-Origin": "*",
      },
      body: JSON.stringify(inferenceInput),
    });

    const data = await response.json();

    setInferenceResults((prevResults) => {
      const maxResultCounter = Math.max(...prevResults.map((r) => r.counter));
      const lastResult = prevResults.find(
        (r) => r.counter == maxResultCounter
      );
      const newCounter = lastResult ? lastResult.counter + 1 : 0;

      const newResult = {
        input: inferenceInput,
        // TODO(hayk): Swap for JPG?
        image: "data:image/png;base64," + data.image,
        audio: "data:audio/mpeg;base64," + data.audio,
        counter: newCounter,
      };

      // TODO(hayk): Fix up
      // if (alpha > 1.0) {
      //   setAlpha(alpha - 0.75);
      //   setSeed(seed + 1);
      // } else {
      //   setAlpha(inferenceInput.alpha + alphaVelocity);
      // }
      setAlpha(alpha + alphaVelocity);

      let results = [...prevResults, newResult];

      // TODO(hayk): Move this somewhere more reasonable to prune.
      if (results.length > maxLength) {
        results = results.slice(1);
      }

      return results;
    });
  };
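  // Shape of the exchange with the inference server, as used above (the
  // response fields are assumptions based on how `data` is unpacked):
  //
  //   POST { alpha, start: { prompt, seed }, end: { prompt, seed } }
  //   -> { image: <base64 PNG>, audio: <base64 MP3> }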
  useInterval(() => {
    console.log(inferenceResults);
    if (inferenceResults.length < maxNumInferenceResults) {
      runInference(alpha, seed, appState, promptInputs);
    }
  }, timeout);

  // Run inference on a timer.
  // TODO(hayk): Improve the timing here.
  // TODO(hayk): Fix warning about effects.
  useEffect(() => {
    runInference(alpha, seed, appState, promptInputs);
  }, []);

  return (
    <>
      <Head>
        <title>Riffusion</title>
      </Head>

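      {/* Visualizer canvas; assumed to draw the stream of spectrogram results
          in time with the transport clock provided via getTime. */}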
      <ThreeCanvas
        getTime={() => Tone.Transport.seconds}
        audioLength={
          tonePlayer ? tonePlayer.sampleTime * tonePlayer.buffer.length : 5
        }
        inferenceResults={inferenceResults}
      />
      <PromptPanel
        prompts={promptInputs}
        addPrompt={(prompt: string) => {
          setPromptInputs([...promptInputs, { prompt: prompt }]);
        }}
        changePrompt={(prompt: string, index: number) => {
          const newPromptInputs = [...promptInputs];
          newPromptInputs[index].prompt = prompt;
          setPromptInputs(newPromptInputs);
        }}
      />
      <Info />

      <Pause paused={paused} setPaused={setPaused} />
    </>
  );
}