Fix alpha > 1 issue

This commit is contained in:
Hayk Martiros 2022-12-13 21:10:43 -08:00
parent d2aea7f2e3
commit 37bdf4fe61
4 changed files with 48 additions and 33 deletions

View File

@ -10,6 +10,7 @@ import {
interface ModelInferenceProps {
alpha: number;
alphaRollover: boolean;
seed: number;
appState: AppState;
promptInputs: PromptInput[];
@ -27,6 +28,7 @@ interface ModelInferenceProps {
*/
export default function ModelInference({
alpha,
alphaRollover,
seed,
appState,
promptInputs,
@ -176,6 +178,11 @@ export default function ModelInference({
return;
}
// Wait for alpha rollover to resolve.
if (alphaRollover) {
return;
}
if (numRequestsMade == 0) {
// Kick off the first request
runInference(alpha, seed, appState, promptInputs);
@ -193,6 +200,7 @@ export default function ModelInference({
}, [
initializedUrlParams,
alpha,
alphaRollover,
seed,
appState,
promptInputs,

View File

@ -395,15 +395,15 @@ export default function Home() {
</p>
<p className="mt-3">
The app communicates over an API to run the inference calls on a GPU
server. We used{" "}
<a href="https://truss.baseten.co">Truss</a>{" "}
to package the model and test it locally before
deploying it to Baseten which provided GPU-backed inference, auto-scaling,
and observability. We used NVIDIA A10Gs in production.
server. We used <a href="https://truss.baseten.co">Truss</a> to
package the model and test it locally before deploying it to Baseten
which provided GPU-backed inference, auto-scaling, and
observability. We used NVIDIA A10Gs in production.
</p>
<p className="mt-3">
If you have a GPU powerful enough to generate stable diffusion
results in under five seconds, you can run the experience locally.
results in under five seconds, you can run the experience locally
using our test flask server.
</p>
<br />
<h2 className="pt-10 pb-5 text-3xl font-bold">Code</h2>

View File

@ -1,5 +1,5 @@
import { useRouter } from "next/router";
import { useEffect, useState } from "react";
import { useCallback, useEffect, useState } from "react";
import * as Tone from "tone";
import AudioPlayer from "../components/AudioPlayer";
@ -41,6 +41,7 @@ export default function Home() {
// Current interpolation parameters
const [alpha, setAlpha] = useState(0.0);
const [alphaRollover, setAlphaRollover] = useState(false);
const [alphaVelocity, setAlphaVelocity] = useState(0.25);
const [seed, setSeed] = useState(getRandomInt(1000000));
@ -95,9 +96,10 @@ export default function Home() {
// Set the app state based on the prompt inputs array
useEffect(() => {
if (alpha <= 1) {
if (!alphaRollover) {
return;
}
setAlphaRollover(false);
const upNextPrompt = promptInputs[promptInputs.length - 1].prompt;
const endPrompt = promptInputs[promptInputs.length - 2].prompt;
@ -134,9 +136,7 @@ export default function Home() {
if (newAppState != appState) {
setAppState(newAppState);
}
setAlpha(alpha - 1);
}, [promptInputs, alpha, appState, seed]);
}, [promptInputs, alpha, alphaRollover, appState, seed]);
// On any app state change, reset alpha
useEffect(() => {
@ -150,14 +150,14 @@ export default function Home() {
}, [appState]);
// What to do when a new inference result is available
const newResultCallback = (
input: InferenceInput,
result: InferenceResult
) => {
const newResultCallback = useCallback(
(input: InferenceInput, result: InferenceResult) => {
setInferenceResults((prevResults: InferenceResult[]) => {
const maxResultCounter = Math.max(...prevResults.map((r) => r.counter));
const lastResult = prevResults.find((r) => r.counter == maxResultCounter);
const lastResult = prevResults.find(
(r) => r.counter == maxResultCounter
);
const newCounter = lastResult ? lastResult.counter + 1 : 0;
@ -165,11 +165,18 @@ export default function Home() {
result.input = input;
result.played = false;
setAlpha(alpha + alphaVelocity);
let newAlpha = alpha + alphaVelocity;
if (newAlpha > 1 + 1e-3) {
newAlpha = newAlpha - 1;
setAlphaRollover(true);
}
setAlpha(newAlpha);
return [...prevResults, result];
});
};
},
[alpha, alphaVelocity]
);
const nowPlayingCallback = (result: InferenceResult, playerTime: number) => {
console.log(
@ -240,11 +247,10 @@ export default function Home() {
</div>
<div className="bg-[#0A2342] flex flex-row min-h-screen text-white">
<div className="absolute w-full md:w-1/3">
<div className="absolute top-4 md:top-6 left-0 right-0 flex justify-center">
<div
className="text-3xl font-bold font-mono text-transparent bg-clip-text bg-gradient-to-t from-white/80 to-white/20 z-20 cursor-pointer"
className="text-3xl font-bold font-mono text-transparent bg-clip-text bg-gradient-to-t from-white/80 to-white/70 z-20 cursor-pointer"
onClick={() => window.open("/about", "_blank")}
>
[RIFFUSION]
@ -262,6 +268,7 @@ export default function Home() {
<ModelInference
alpha={alpha}
alphaRollover={alphaRollover}
seed={seed}
appState={appState}
promptInputs={promptInputs}