// riffusion-app/components/ModelInference.tsx
//
// 211 lines
// 5.6 KiB
// TypeScript

import { useRouter } from "next/router";
import { useCallback, useEffect, useState } from "react";
import {
AppState,
InferenceInput,
InferenceResult,
PromptInput,
} from "../types";
interface ModelInferenceProps {
  /** Interpolation weight sent as `alpha` in every inference request. */
  alpha: number;
  /** Seed for the start prompt; the end prompt reuses it while transitioning, otherwise uses `seed + 1`. */
  seed: number;
  /** Prompt history; the third-to-last entry is the start prompt and the second-to-last the end prompt. */
  promptInputs: PromptInput[];
  /** Current app state; `TRANSITION` targets the end prompt and `UNINITIALIZED` suppresses requests. */
  appState: AppState;
  /** Result currently playing; its `counter` limits how far ahead requests are buffered (treated as 0 when falsy). */
  nowPlayingResult: InferenceResult;
  /** Called with a request's input and the result parsed from the server response. */
  newResultCallback: (input: InferenceInput, result: InferenceResult) => void;
  /** When true, requests go to `/api/baseten` wrapped in `worklet_input`; otherwise to `/api/server`. */
  useBaseten: boolean;
}
/**
 * Minimum number of prompt entries required before a request can be built:
 * the third-to-last entry is the start prompt and the second-to-last entry is
 * the end prompt.
 */
const MIN_PROMPT_INPUTS = 3;

/**
 * New requests are only issued while fewer than this many requests have been
 * made beyond the `counter` of the result that is currently playing.
 */
const MAX_REQUESTS_AHEAD = 3;

/**
 * The Baseten response fields this component reads. The JSON-encoded model
 * output may appear at the top level or wrapped in an extra `data` field.
 */
interface BasetenResponse {
  worklet_output?: { model_output?: string };
  data?: { worklet_output?: { model_output?: string } };
}

/**
 * Returns the JSON-encoded model output from a Baseten response, checking the
 * top-level location first and then the `data`-wrapped one. Returns undefined
 * when neither holds a non-empty value.
 */
function basetenModelOutput(
  data: BasetenResponse | null | undefined
): string | undefined {
  // `||` rather than `??` so an empty string falls through to the next location.
  return (
    data?.worklet_output?.model_output ||
    data?.data?.worklet_output?.model_output ||
    undefined
  );
}

/**
 * Returns a single value for a Next.js query parameter, which is undefined
 * when absent and an array when the key is repeated (the first one wins).
 */
function firstQueryValue(
  value: string | string[] | undefined
): string | undefined {
  return Array.isArray(value) ? value[0] : value;
}

/**
 * Parses a numeric query parameter with `parse`. Returns undefined when the
 * parameter is absent, empty, or does not parse to a number, so the caller
 * keeps its default instead of storing NaN.
 */
function parseNumericQuery(
  value: string | string[] | undefined,
  parse: (raw: string) => number
): number | undefined {
  const raw = firstQueryValue(value);
  if (!raw) {
    return undefined;
  }
  const parsed = parse(raw);
  return Number.isNaN(parsed) ? undefined : parsed;
}

/**
 * Headless component that calls the server to run model inference and reports
 * each result through `newResultCallback`. Renders nothing.
 *
 * Request parameters (denoising, guidance, numInferenceSteps, seedImageId,
 * maskImageId) start from fixed defaults and can be overridden by URL query
 * parameters of the same name; `maskImageId=none` selects an empty mask id.
 *
 * The first request is sent once the URL parameters have been applied, the app
 * is initialized, and at least `MIN_PROMPT_INPUTS` prompts exist. After that,
 * a new request is sent only when every earlier request has finished
 * (successfully or not) and fewer than `MAX_REQUESTS_AHEAD` requests have been
 * made beyond the currently playing result.
 */
export default function ModelInference({
  alpha,
  seed,
  appState,
  promptInputs,
  nowPlayingResult,
  newResultCallback,
  useBaseten,
}: ModelInferenceProps) {
  // Parameters for the inference request
  const [denoising, setDenoising] = useState(0.75);
  const [guidance, setGuidance] = useState(7.0);
  const [numInferenceSteps, setNumInferenceSteps] = useState(50);
  const [seedImageId, setSeedImageId] = useState("og_beat");
  // Explicit type: inferring from `null` would reject the string values set below.
  const [maskImageId, setMaskImageId] = useState<string | null>(null);

  const [initializedUrlParams, setInitializedUrlParams] = useState(false);
  const [numRequestsMade, setNumRequestsMade] = useState(0);
  const [numResponsesReceived, setNumResponsesReceived] = useState(0);

  // Set initial params from URL query strings
  const router = useRouter();
  useEffect(() => {
    // Statically optimized Next.js pages first render with an empty
    // `router.query`; wait until the router has parsed the URL so the first
    // request does not go out with the default parameters.
    if (!router.isReady) {
      return;
    }

    const denoisingParam = parseNumericQuery(router.query.denoising, parseFloat);
    if (denoisingParam !== undefined) {
      setDenoising(denoisingParam);
    }

    const guidanceParam = parseNumericQuery(router.query.guidance, parseFloat);
    if (guidanceParam !== undefined) {
      setGuidance(guidanceParam);
    }

    const stepsParam = parseNumericQuery(router.query.numInferenceSteps, (raw) =>
      parseInt(raw, 10)
    );
    if (stepsParam !== undefined) {
      setNumInferenceSteps(stepsParam);
    }

    const seedImageParam = firstQueryValue(router.query.seedImageId);
    if (seedImageParam) {
      setSeedImageId(seedImageParam);
    }

    const maskImageParam = firstQueryValue(router.query.maskImageId);
    if (maskImageParam) {
      // "none" maps to an empty mask id rather than leaving it null.
      setMaskImageId(maskImageParam === "none" ? "" : maskImageParam);
    }

    setInitializedUrlParams(true);
  }, [router.isReady, router.query]);

  // Memoized function to kick off an inference request
  const runInference = useCallback(
    async (
      alpha: number,
      seed: number,
      appState: AppState,
      promptInputs: PromptInput[]
    ) => {
      const startPrompt = promptInputs[promptInputs.length - 3].prompt;
      const endPrompt = promptInputs[promptInputs.length - 2].prompt;
      const transitioning = appState === AppState.TRANSITION;

      const inferenceInput = {
        alpha: alpha,
        num_inference_steps: numInferenceSteps,
        seed_image_id: seedImageId,
        mask_image_id: maskImageId,
        start: {
          prompt: startPrompt,
          seed: seed,
          denoising: denoising,
          guidance: guidance,
        },
        // When not transitioning, stay on the start prompt with the next seed.
        end: {
          prompt: transitioning ? endPrompt : startPrompt,
          seed: transitioning ? seed : seed + 1,
          denoising: denoising,
          guidance: guidance,
        },
      };

      const requestIndex = numRequestsMade;
      console.log(`Inference #${requestIndex}: `, {
        alpha: alpha,
        prompt_a: inferenceInput.start.prompt,
        seed_a: inferenceInput.start.seed,
        prompt_b: inferenceInput.end.prompt,
        seed_b: inferenceInput.end.seed,
        appState: appState,
      });
      setNumRequestsMade((n) => n + 1);

      // Customize for baseten
      const apiHandler = useBaseten ? "/api/baseten" : "/api/server";
      const payload = useBaseten
        ? { worklet_input: inferenceInput }
        : inferenceInput;

      try {
        const response = await fetch(apiHandler, {
          method: "POST",
          body: JSON.stringify(payload),
        });
        const data = await response.json();
        console.log(`Got result #${numResponsesReceived}`);

        if (useBaseten) {
          const modelOutput = basetenModelOutput(data);
          if (modelOutput) {
            newResultCallback(inferenceInput, JSON.parse(modelOutput));
          } else {
            console.error("Baseten call failed: ", data);
          }
        } else if (data?.data) {
          // Note, data is currently wrapped in a data field
          newResultCallback(inferenceInput, data.data);
        } else {
          console.error(`Server call failed (HTTP ${response.status}): `, data);
        }
      } catch (error: unknown) {
        console.error(`Inference #${requestIndex} failed: `, error);
      } finally {
        // Always count the response, even a failed one: new requests are only
        // sent when requests made == responses received, so skipping this
        // would stop inference for good after a single network/parse error.
        setNumResponsesReceived((n) => n + 1);
      }
    },
    [
      denoising,
      guidance,
      maskImageId,
      numInferenceSteps,
      seedImageId,
      newResultCallback,
      numRequestsMade,
      numResponsesReceived,
      useBaseten,
    ]
  );

  // Kick off inference requests
  useEffect(() => {
    // Make sure things are initialized properly; runInference indexes the
    // third-to-last prompt, so fewer than MIN_PROMPT_INPUTS entries would throw.
    if (
      !initializedUrlParams ||
      appState === AppState.UNINITIALIZED ||
      promptInputs.length < MIN_PROMPT_INPUTS
    ) {
      return;
    }

    if (numRequestsMade === 0) {
      // Kick off the first request
      void runInference(alpha, seed, appState, promptInputs);
    } else if (numRequestsMade === numResponsesReceived) {
      // Otherwise buffer ahead a few from where the audio player currently is
      // TODO(hayk): Replace this with better buffer management
      const nowPlayingCounter = nowPlayingResult ? nowPlayingResult.counter : 0;
      const numAhead = numRequestsMade - nowPlayingCounter;
      if (numAhead < MAX_REQUESTS_AHEAD) {
        void runInference(alpha, seed, appState, promptInputs);
      }
    }
  }, [
    initializedUrlParams,
    alpha,
    seed,
    appState,
    promptInputs,
    nowPlayingResult,
    numRequestsMade,
    numResponsesReceived,
    runInference,
  ]);

  return null;
}