Do a much better job of buffer management

This commit is contained in:
Hayk Martiros 2022-11-27 16:46:10 -08:00
parent 9c644b9223
commit 5cf6c476ab
4 changed files with 127 additions and 84 deletions

View File

@@ -7,6 +7,7 @@ import { InferenceResult } from "../types";
interface AudioPlayerProps {
paused: boolean;
inferenceResults: InferenceResult[];
nowPlayingCallback: (result: InferenceResult, playerTime: number) => void;
}
/**
@@ -17,6 +18,7 @@ interface AudioPlayerProps {
export default function AudioPlayer({
paused,
inferenceResults,
nowPlayingCallback,
}: AudioPlayerProps) {
// TODO(hayk): Rename
const [tonePlayer, setTonePlayer] = useState<Tone.Player>(null);
@@ -37,27 +39,27 @@ export default function AudioPlayer({
return;
}
const audioUrl = inferenceResults[0].audio;
const result = inferenceResults[0];
const player = new Tone.Player(audioUrl, () => {
const player = new Tone.Player(result.audio, () => {
player.loop = true;
player.sync().start(0);
// Set up a callback to increment numClipsPlayed at the edge of each clip
const bufferLength = player.sampleTime * player.buffer.length;
console.log(bufferLength, inferenceResults[0].duration_s);
// console.log(bufferLength, result.duration_s);
// TODO(hayk): Set this callback up to vary each time using duration_s
Tone.Transport.scheduleRepeat((time) => {
// TODO(hayk): Edge of clip callback
console.log(
"Edge of clip, t = ",
Tone.Transport.getSecondsAtTime(time),
", bufferLength = ",
bufferLength,
", tone transport seconds = ",
Tone.Transport.seconds
);
// console.log(
// "Edge of clip, t = ",
// Tone.Transport.getSecondsAtTime(time),
// ", bufferLength = ",
// bufferLength,
// ", tone transport seconds = ",
// Tone.Transport.seconds
// );
setNumClipsPlayed((n) => n + 1);
}, bufferLength);
@@ -72,8 +74,6 @@ export default function AudioPlayer({
// On play/pause button, play/pause the audio with the tone transport
useEffect(() => {
if (!paused) {
console.log("Play");
if (Tone.context.state == "suspended") {
Tone.context.resume();
}
@@ -82,8 +82,6 @@ export default function AudioPlayer({
Tone.Transport.start();
}
} else {
console.log("Pause");
if (tonePlayer) {
Tone.Transport.pause();
}
@@ -115,22 +113,16 @@ export default function AudioPlayer({
setResultCounter((c) => c + 1);
console.log("numClipsPlayed incremented ", Tone.Transport.seconds);
tonePlayer.load(result.audio).then(() => {
console.log(
"Now playing result ",
resultCounter,
", time is ",
Tone.Transport.seconds
);
// Re-jigger the transport so it stops playing old buffers. It seems like this doesn't
// introduce a gap, but watch out for that.
Tone.Transport.pause();
if (!paused) {
Tone.Transport.start();
}
const playerTime = Tone.Transport.seconds;
nowPlayingCallback(result, playerTime);
});
setPrevNumClipsPlayed(numClipsPlayed);

View File

@@ -17,6 +17,8 @@ interface ModelInferenceProps {
seed: number;
appState: AppState;
promptInputs: PromptInput[];
nowPlayingResult: InferenceResult;
paused: boolean;
newResultCallback: (input: InferenceInput, result: InferenceResult) => void;
}
@@ -30,6 +32,8 @@ export default function ModelInference({
seed,
appState,
promptInputs,
nowPlayingResult,
paused,
newResultCallback,
}: ModelInferenceProps) {
// Create parameters for the inference request
@@ -41,6 +45,7 @@ export default function ModelInference({
const [initializedUrlParams, setInitializedUrlParams] = useState(false);
const [numRequestsMade, setNumRequestsMade] = useState(0);
const [numResponsesReceived, setNumResponsesReceived] = useState(0);
// Set initial params from URL query strings
const router = useRouter();
@@ -73,58 +78,76 @@ export default function ModelInference({
}, [router.query]);
// Memoized function to kick off an inference request
const runInference = useCallback(async (
alpha: number,
seed: number,
appState: AppState,
promptInputs: PromptInput[]
) => {
const startPrompt = promptInputs[promptInputs.length - 3].prompt;
const endPrompt = promptInputs[promptInputs.length - 2].prompt;
const runInference = useCallback(
async (
alpha: number,
seed: number,
appState: AppState,
promptInputs: PromptInput[]
) => {
const startPrompt = promptInputs[promptInputs.length - 3].prompt;
const endPrompt = promptInputs[promptInputs.length - 2].prompt;
const transitioning = appState == AppState.TRANSITION;
const transitioning = appState == AppState.TRANSITION;
const inferenceInput = {
alpha: alpha,
num_inference_steps: numInferenceSteps,
seed_image_id: seedImageId,
mask_image_id: maskImageId,
start: {
prompt: startPrompt,
seed: seed,
denoising: denoising,
guidance: guidance,
},
end: {
prompt: transitioning ? endPrompt : startPrompt,
seed: transitioning ? seed : seed + 1,
denoising: denoising,
guidance: guidance,
},
};
const inferenceInput = {
alpha: alpha,
num_inference_steps: numInferenceSteps,
seed_image_id: seedImageId,
mask_image_id: maskImageId,
start: {
prompt: startPrompt,
seed: seed,
denoising: denoising,
guidance: guidance,
},
end: {
prompt: transitioning ? endPrompt : startPrompt,
seed: transitioning ? seed : seed + 1,
denoising: denoising,
guidance: guidance,
},
};
console.log("Running for input: ", inferenceInput);
setNumRequestsMade((n) => n + 1);
console.log(`Inference #${numRequestsMade}: `, {
alpha: alpha,
prompt_a: inferenceInput.start.prompt,
seed_a: inferenceInput.start.seed,
prompt_b: inferenceInput.end.prompt,
seed_b: inferenceInput.end.seed,
appState: appState,
});
const response = await fetch(SERVER_URL, {
method: "POST",
headers: {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
},
body: JSON.stringify(inferenceInput),
});
setNumRequestsMade((n) => n + 1);
const data = await response.json();
const response = await fetch(SERVER_URL, {
method: "POST",
headers: {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
},
body: JSON.stringify(inferenceInput),
});
newResultCallback(inferenceInput, data);
}, [denoising, guidance, maskImageId, numInferenceSteps, seedImageId, newResultCallback]);
const data = await response.json();
console.log(`Got result #${numResponsesReceived}`);
newResultCallback(inferenceInput, data);
setNumResponsesReceived((n) => n + 1);
},
[
denoising,
guidance,
maskImageId,
numInferenceSteps,
seedImageId,
newResultCallback,
]
);
// Kick off the first inference run when everything is ready.
useEffect(() => {
if (numRequestsMade > 0) {
return;
}
// Make sure things are initialized properly
if (
!initializedUrlParams ||
appState == AppState.UNINITIALIZED ||
@@ -133,8 +156,18 @@ export default function ModelInference({
return;
}
console.log("First inference");
runInference(alpha, seed, appState, promptInputs);
if (numRequestsMade == 0) {
runInference(alpha, seed, appState, promptInputs);
} else if (numRequestsMade == numResponsesReceived) {
// TODO(hayk): Replace this with better buffer management
const nowPlayingCounter = nowPlayingResult ? nowPlayingResult.counter : 0;
const numAhead = numRequestsMade - nowPlayingCounter;
if (numAhead < 3) {
runInference(alpha, seed, appState, promptInputs);
}
}
}, [
initializedUrlParams,
numRequestsMade,
@@ -142,16 +175,9 @@ export default function ModelInference({
seed,
appState,
promptInputs,
paused,
runInference,
]);
// Run inference on a timer.
// TODO(hayk): Fix this to properly handle state.
useInterval(() => {
// if (inferenceResults.length < maxNumInferenceResults) {
runInference(alpha, seed, appState, promptInputs);
// }
}, 4900);
return null;
}

View File

@@ -1,3 +1,4 @@
import { useEffect } from "react";
import { FiPause, FiPlay } from "react-icons/fi";
interface PauseProps {
@@ -10,6 +11,15 @@ export default function Pause({
setPaused,
}: PauseProps) {
// Print the state into the console
useEffect(() => {
if (paused) {
console.log("Pause");
} else {
console.log("Play");
}
}, [paused]);
var classNameCondition = ""
if (paused) {
classNameCondition="fixed z-90 top-28 right-8 bg-sky-300 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-sky-900 text-2xl hover:bg-sky-500 hover:drop-shadow-2xl"

View File

@@ -50,6 +50,9 @@ export default function Home() {
[]
);
// Currently playing result, from the audio player
const [nowPlayingResult, setNowPlayingResult] = useState<InferenceResult>(null);
// Set the initial seed from the URL if available
const router = useRouter();
useEffect(() => {
@@ -63,7 +66,6 @@ export default function Home() {
}
if (router.query.seed) {
console.log("setting seed");
setSeed(parseInt(router.query.seed as string));
}
@@ -137,14 +139,21 @@ export default function Home() {
setAlpha(alpha + alphaVelocity);
let results = [...prevResults, result];
console.log(results);
return results;
return [...prevResults, result];
});
};
const nowPlayingCallback = (result: InferenceResult, playerTime: number) => {
console.log(
"Now playing result ",
result.counter,
", player time is ",
playerTime
);
setNowPlayingResult(result);
};
return (
<>
<Head>
@@ -170,9 +179,15 @@ export default function Home() {
seed={seed}
appState={appState}
promptInputs={promptInputs}
paused={paused}
nowPlayingResult={nowPlayingResult}
newResultCallback={newResultCallback}
/>
<AudioPlayer paused={paused} inferenceResults={inferenceResults} />
<AudioPlayer
paused={paused}
inferenceResults={inferenceResults}
nowPlayingCallback={nowPlayingCallback}
/>
<PromptPanel
prompts={promptInputs}