import { useRouter } from "next/router";
import { useCallback, useEffect, useState } from "react";

import {
  AppState,
  InferenceInput,
  InferenceResult,
  PromptInput,
} from "../types";

/**
 * Minimum number of prompt inputs required before a request can be built:
 * the start prompt is read from the third-to-last entry and the end prompt
 * from the second-to-last entry.
 */
const MIN_PROMPT_INPUTS = 3;

/** Upper bound on how many requests may be issued ahead of the playing result. */
const MAX_REQUESTS_AHEAD = 3;

interface ModelInferenceProps {
  /** Sent to the server as the `alpha` field of the inference request. */
  alpha: number;
  /** Seed for the start prompt; the end prompt uses `seed` or `seed + 1`. */
  seed: number;
  /** Current app state; requests are withheld while it is UNINITIALIZED. */
  appState: AppState;
  /** Prompt list; the third-to-last and second-to-last entries are used. */
  promptInputs: PromptInput[];
  /** Result currently playing; its `counter` drives the buffering decision. */
  nowPlayingResult: InferenceResult;
  /** Invoked with the request input and the parsed result of each success. */
  newResultCallback: (input: InferenceInput, result: InferenceResult) => void;
  /** When true, requests go to `/api/baseten` instead of `/api/server`. */
  useBaseten: boolean;
}

/**
 * Calls the server to run model inference.
 *
 * Renders nothing. It reads optional request parameters from the URL query
 * string, issues inference requests one at a time, and hands every successful
 * result to `newResultCallback`. A new request is only started once the
 * previous one has finished (successfully or not) and fewer than
 * `MAX_REQUESTS_AHEAD` requests have been made beyond the counter of the
 * result that is currently playing.
 */
export default function ModelInference({
  alpha,
  seed,
  appState,
  promptInputs,
  nowPlayingResult,
  newResultCallback,
  useBaseten,
}: ModelInferenceProps) {
  // Create parameters for the inference request
  const [denoising, setDenoising] = useState(0.75);
  const [guidance, setGuidance] = useState(7.0);
  const [numInferenceSteps, setNumInferenceSteps] = useState(50);
  const [seedImageId, setSeedImageId] = useState("og_beat");
  // Explicit type argument: `useState(null)` alone would infer the type `null`
  // and reject the string values assigned below.
  const [maskImageId, setMaskImageId] = useState<string | null>(null);

  const [initializedUrlParams, setInitializedUrlParams] = useState(false);
  const [numRequestsMade, setNumRequestsMade] = useState(0);
  const [numResponsesReceived, setNumResponsesReceived] = useState(0);

  // Set initial params from URL query strings
  // NOTE(review): this runs on the first render even if the router has not yet
  // populated `query`; consider gating on `router.isReady` if the Next.js
  // version in use supports it - confirm before changing.
  const router = useRouter();
  useEffect(() => {
    const query = router.query;

    // Numeric params are only applied when they parse to a real number, so a
    // malformed query string keeps the default instead of storing NaN.
    if (query.denoising) {
      const parsed = parseFloat(query.denoising as string);
      if (!Number.isNaN(parsed)) {
        setDenoising(parsed);
      }
    }

    if (query.guidance) {
      const parsed = parseFloat(query.guidance as string);
      if (!Number.isNaN(parsed)) {
        setGuidance(parsed);
      }
    }

    if (query.numInferenceSteps) {
      const parsed = parseInt(query.numInferenceSteps as string, 10);
      if (!Number.isNaN(parsed)) {
        setNumInferenceSteps(parsed);
      }
    }

    if (query.seedImageId) {
      setSeedImageId(query.seedImageId as string);
    }

    if (query.maskImageId) {
      // The literal "none" maps to an empty mask id.
      if (query.maskImageId === "none") {
        setMaskImageId("");
      } else {
        setMaskImageId(query.maskImageId as string);
      }
    }

    setInitializedUrlParams(true);
  }, [router.query]);

  // Memoized function to kick off an inference request
  const runInference = useCallback(
    async (
      alpha: number,
      seed: number,
      appState: AppState,
      promptInputs: PromptInput[]
    ) => {
      // Callers must supply at least MIN_PROMPT_INPUTS entries.
      const startPrompt = promptInputs[promptInputs.length - 3].prompt;
      const endPrompt = promptInputs[promptInputs.length - 2].prompt;

      const transitioning = appState === AppState.TRANSITION;

      // When not transitioning, both ends use the start prompt and the end
      // seed is offset by one.
      const inferenceInput = {
        alpha: alpha,
        num_inference_steps: numInferenceSteps,
        seed_image_id: seedImageId,
        mask_image_id: maskImageId,
        start: {
          prompt: startPrompt,
          seed: seed,
          denoising: denoising,
          guidance: guidance,
        },
        end: {
          prompt: transitioning ? endPrompt : startPrompt,
          seed: transitioning ? seed : seed + 1,
          denoising: denoising,
          guidance: guidance,
        },
      };

      console.log(`Inference #${numRequestsMade}: `, {
        alpha: alpha,
        prompt_a: inferenceInput.start.prompt,
        seed_a: inferenceInput.start.seed,
        prompt_b: inferenceInput.end.prompt,
        seed_b: inferenceInput.end.seed,
        appState: appState,
      });

      setNumRequestsMade((n) => n + 1);

      // Customize for baseten
      const apiHandler = useBaseten ? "/api/baseten" : "/api/server";
      const payload = useBaseten
        ? { worklet_input: inferenceInput }
        : inferenceInput;

      try {
        const response = await fetch(apiHandler, {
          method: "POST",
          body: JSON.stringify(payload),
        });

        const data = await response.json();

        console.log(`Got result #${numResponsesReceived}`);

        if (useBaseten) {
          if (data?.worklet_output?.model_output) {
            newResultCallback(
              inferenceInput,
              JSON.parse(data.worklet_output.model_output)
            );
          }
          // Note, data is currently wrapped in a data field
          else if (data?.data?.worklet_output?.model_output) {
            newResultCallback(
              inferenceInput,
              JSON.parse(data.data.worklet_output.model_output)
            );
          } else {
            console.error("Baseten call failed: ", data);
          }
        } else if (data?.data != null) {
          // Note, data is currently wrapped in a data field
          newResultCallback(inferenceInput, data.data);
        } else {
          console.error("Server call failed: ", data);
        }
      } catch (error) {
        // Network failures, non-JSON bodies and callback errors all end up
        // here instead of surfacing as an unhandled promise rejection.
        console.error("Inference request failed: ", error);
      } finally {
        // Always count the request as finished. Otherwise a single failure
        // would leave the request/response counters unequal forever and no
        // further request would ever be started.
        setNumResponsesReceived((n) => n + 1);
      }
    },
    [
      denoising,
      guidance,
      maskImageId,
      numInferenceSteps,
      seedImageId,
      newResultCallback,
      numRequestsMade,
      numResponsesReceived,
      useBaseten,
    ]
  );

  // Kick off inference requests
  useEffect(() => {
    // Make sure things are initialized properly. The prompt list must be long
    // enough for runInference to read its start and end prompts.
    if (
      !initializedUrlParams ||
      appState === AppState.UNINITIALIZED ||
      promptInputs.length < MIN_PROMPT_INPUTS
    ) {
      return;
    }

    if (numRequestsMade === 0) {
      // Kick off the first request
      runInference(alpha, seed, appState, promptInputs);
    } else if (numRequestsMade === numResponsesReceived) {
      // Otherwise buffer ahead a few from where the audio player currently is
      // TODO(hayk): Replace this with better buffer management

      const nowPlayingCounter = nowPlayingResult ? nowPlayingResult.counter : 0;
      const numAhead = numRequestsMade - nowPlayingCounter;

      if (numAhead < MAX_REQUESTS_AHEAD) {
        runInference(alpha, seed, appState, promptInputs);
      }
    }
  }, [
    initializedUrlParams,
    alpha,
    seed,
    appState,
    promptInputs,
    nowPlayingResult,
    numRequestsMade,
    numResponsesReceived,
    runInference,
  ]);

  return null;
}