Fix alpha > 1 issue
This commit is contained in:
parent
d2aea7f2e3
commit
37bdf4fe61
|
@ -10,6 +10,7 @@ import {
|
|||
|
||||
interface ModelInferenceProps {
|
||||
alpha: number;
|
||||
alphaRollover: boolean;
|
||||
seed: number;
|
||||
appState: AppState;
|
||||
promptInputs: PromptInput[];
|
||||
|
@ -27,6 +28,7 @@ interface ModelInferenceProps {
|
|||
*/
|
||||
export default function ModelInference({
|
||||
alpha,
|
||||
alphaRollover,
|
||||
seed,
|
||||
appState,
|
||||
promptInputs,
|
||||
|
@ -176,6 +178,11 @@ export default function ModelInference({
|
|||
return;
|
||||
}
|
||||
|
||||
// Wait for alpha rollover to resolve.
|
||||
if (alphaRollover) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (numRequestsMade == 0) {
|
||||
// Kick off the first request
|
||||
runInference(alpha, seed, appState, promptInputs);
|
||||
|
@ -193,6 +200,7 @@ export default function ModelInference({
|
|||
}, [
|
||||
initializedUrlParams,
|
||||
alpha,
|
||||
alphaRollover,
|
||||
seed,
|
||||
appState,
|
||||
promptInputs,
|
||||
|
|
|
@ -395,15 +395,15 @@ export default function Home() {
|
|||
</p>
|
||||
<p className="mt-3">
|
||||
The app communicates over an API to run the inference calls on a GPU
|
||||
server. We used{" "}
|
||||
<a href="https://truss.baseten.co">Truss</a>{" "}
|
||||
to package the model and test it locally before
|
||||
deploying it to Baseten which provided GPU-backed inference, auto-scaling,
|
||||
and observability. We used NVIDIA A10Gs in production.
|
||||
server. We used <a href="https://truss.baseten.co">Truss</a> to
|
||||
package the model and test it locally before deploying it to Baseten
|
||||
which provided GPU-backed inference, auto-scaling, and
|
||||
observability. We used NVIDIA A10Gs in production.
|
||||
</p>
|
||||
<p className="mt-3">
|
||||
If you have a GPU powerful enough to generate stable diffusion
|
||||
results in under five seconds, you can run the experience locally.
|
||||
results in under five seconds, you can run the experience locally
|
||||
using our test flask server.
|
||||
</p>
|
||||
<br />
|
||||
<h2 className="pt-10 pb-5 text-3xl font-bold">Code</h2>
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { useRouter } from "next/router";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import * as Tone from "tone";
|
||||
|
||||
import AudioPlayer from "../components/AudioPlayer";
|
||||
|
@ -41,6 +41,7 @@ export default function Home() {
|
|||
|
||||
// Current interpolation parameters
|
||||
const [alpha, setAlpha] = useState(0.0);
|
||||
const [alphaRollover, setAlphaRollover] = useState(false);
|
||||
const [alphaVelocity, setAlphaVelocity] = useState(0.25);
|
||||
const [seed, setSeed] = useState(getRandomInt(1000000));
|
||||
|
||||
|
@ -95,9 +96,10 @@ export default function Home() {
|
|||
|
||||
// Set the app state based on the prompt inputs array
|
||||
useEffect(() => {
|
||||
if (alpha <= 1) {
|
||||
if (!alphaRollover) {
|
||||
return;
|
||||
}
|
||||
setAlphaRollover(false);
|
||||
|
||||
const upNextPrompt = promptInputs[promptInputs.length - 1].prompt;
|
||||
const endPrompt = promptInputs[promptInputs.length - 2].prompt;
|
||||
|
@ -134,9 +136,7 @@ export default function Home() {
|
|||
if (newAppState != appState) {
|
||||
setAppState(newAppState);
|
||||
}
|
||||
|
||||
setAlpha(alpha - 1);
|
||||
}, [promptInputs, alpha, appState, seed]);
|
||||
}, [promptInputs, alpha, alphaRollover, appState, seed]);
|
||||
|
||||
// On any app state change, reset alpha
|
||||
useEffect(() => {
|
||||
|
@ -150,14 +150,14 @@ export default function Home() {
|
|||
}, [appState]);
|
||||
|
||||
// What to do when a new inference result is available
|
||||
const newResultCallback = (
|
||||
input: InferenceInput,
|
||||
result: InferenceResult
|
||||
) => {
|
||||
const newResultCallback = useCallback(
|
||||
(input: InferenceInput, result: InferenceResult) => {
|
||||
setInferenceResults((prevResults: InferenceResult[]) => {
|
||||
const maxResultCounter = Math.max(...prevResults.map((r) => r.counter));
|
||||
|
||||
const lastResult = prevResults.find((r) => r.counter == maxResultCounter);
|
||||
const lastResult = prevResults.find(
|
||||
(r) => r.counter == maxResultCounter
|
||||
);
|
||||
|
||||
const newCounter = lastResult ? lastResult.counter + 1 : 0;
|
||||
|
||||
|
@ -165,11 +165,18 @@ export default function Home() {
|
|||
result.input = input;
|
||||
result.played = false;
|
||||
|
||||
setAlpha(alpha + alphaVelocity);
|
||||
let newAlpha = alpha + alphaVelocity;
|
||||
if (newAlpha > 1 + 1e-3) {
|
||||
newAlpha = newAlpha - 1;
|
||||
setAlphaRollover(true);
|
||||
}
|
||||
setAlpha(newAlpha);
|
||||
|
||||
return [...prevResults, result];
|
||||
});
|
||||
};
|
||||
},
|
||||
[alpha, alphaVelocity]
|
||||
);
|
||||
|
||||
const nowPlayingCallback = (result: InferenceResult, playerTime: number) => {
|
||||
console.log(
|
||||
|
@ -240,11 +247,10 @@ export default function Home() {
|
|||
</div>
|
||||
|
||||
<div className="bg-[#0A2342] flex flex-row min-h-screen text-white">
|
||||
|
||||
<div className="absolute w-full md:w-1/3">
|
||||
<div className="absolute top-4 md:top-6 left-0 right-0 flex justify-center">
|
||||
<div
|
||||
className="text-3xl font-bold font-mono text-transparent bg-clip-text bg-gradient-to-t from-white/80 to-white/20 z-20 cursor-pointer"
|
||||
className="text-3xl font-bold font-mono text-transparent bg-clip-text bg-gradient-to-t from-white/80 to-white/70 z-20 cursor-pointer"
|
||||
onClick={() => window.open("/about", "_blank")}
|
||||
>
|
||||
[RIFFUSION]
|
||||
|
@ -262,6 +268,7 @@ export default function Home() {
|
|||
|
||||
<ModelInference
|
||||
alpha={alpha}
|
||||
alphaRollover={alphaRollover}
|
||||
seed={seed}
|
||||
appState={appState}
|
||||
promptInputs={promptInputs}
|
||||
|
|
Loading…
Reference in New Issue