Merge branch 'main' into draft-settingsViewDaisy

This commit is contained in:
Seth Forsgren 2022-12-12 09:51:02 -08:00
commit 01134dc364
5 changed files with 54 additions and 25 deletions

View File

@ -1,6 +1,7 @@
import { useRouter } from "next/router"; import { useRouter } from "next/router";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { import {
AppState, AppState,
InferenceInput, InferenceInput,
@ -8,14 +9,6 @@ import {
PromptInput, PromptInput,
} from "../types"; } from "../types";
// TODO(hayk): Get this into a configuration.
const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
// Baseten worklet api url. Using cors-anywhere to get around CORS issues.
const BASETEN_URL =
"https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
// Temporary basten API key "irritating-haircut"
const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";
interface ModelInferenceProps { interface ModelInferenceProps {
alpha: number; alpha: number;
seed: number; seed: number;
@ -124,23 +117,14 @@ export default function ModelInference({
setNumRequestsMade((n) => n + 1); setNumRequestsMade((n) => n + 1);
let headers = {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
};
// Customize for baseten // Customize for baseten
const serverUrl = useBaseten ? BASETEN_URL : SERVER_URL; const apiHandler = useBaseten ? "/api/baseten" : "/api/server";
const payload = useBaseten const payload = useBaseten
? { worklet_input: inferenceInput } ? { worklet_input: inferenceInput }
: inferenceInput; : inferenceInput;
if (useBaseten) {
headers["Authorization"] = `Api-Key ${BASETEN_API_KEY}`;
}
const response = await fetch(serverUrl, { const response = await fetch(apiHandler, {
method: "POST", method: "POST",
headers: headers,
body: JSON.stringify(payload), body: JSON.stringify(payload),
}); });
@ -154,11 +138,20 @@ export default function ModelInference({
inferenceInput, inferenceInput,
JSON.parse(data.worklet_output.model_output) JSON.parse(data.worklet_output.model_output)
); );
} else { }
// Note, data is currently wrapped in a data field
else if (data?.data?.worklet_output?.model_output) {
newResultCallback(
inferenceInput,
JSON.parse(data.data.worklet_output.model_output)
);
}
else {
console.error("Baseten call failed: ", data); console.error("Baseten call failed: ", data);
} }
} else { } else {
newResultCallback(inferenceInput, data); // Note, data is currently wrapped in a data field
newResultCallback(inferenceInput, data.data);
} }
setNumResponsesReceived((n) => n + 1); setNumResponsesReceived((n) => n + 1);

View File

@ -22,7 +22,7 @@ export default function Pause({
var classNameCondition = "" var classNameCondition = ""
if (paused) { if (paused) {
classNameCondition="animate-pulse fixed z-20 top-4 right-4 md:top-8 md:right-8 bg-sky-400 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-white text-2xl hover:bg-sky-500 focus:ring-4 focus:outline-none focus:ring-sky-600 hover:drop-shadow-2xl" classNameCondition="animate-pulse fixed z-20 top-4 right-4 md:top-8 md:right-8 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-white text-2xl bg-red-500 hover:bg-red-600 ring-4 ring-red-700 focus:outline-none hover:drop-shadow-2xl"
} else { } else {
classNameCondition="fixed z-20 top-4 right-4 md:top-8 md:right-8 bg-slate-100 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-sky-900 text-2xl hover:text-white hover:bg-sky-600 hover:drop-shadow-2xl" classNameCondition="fixed z-20 top-4 right-4 md:top-8 md:right-8 bg-slate-100 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-sky-900 text-2xl hover:text-white hover:bg-sky-600 hover:drop-shadow-2xl"
} }

View File

@ -79,15 +79,15 @@ export default function Share({
} }
// TODO: Consider start or end here. End is something the the user hasn't actually heard yet. Start is perhaps a previous prompt than where the user is headed // TODO: Consider start or end here. End is something the the user hasn't actually heard yet. Start is perhaps a previous prompt than where the user is headed
// TODO: Consider only including in the link the things that are different from the default values
prompt = selectedInput.start.prompt prompt = selectedInput.start.prompt
seed = selectedInput.start.seed seed = selectedInput.start.seed
denoising = selectedInput.start.denoising denoising = selectedInput.start.denoising
maskImageId = selectedInput.mask_image_id maskImageId = selectedInput.mask_image_id
seedImageId = nowPlayingResult.input.seed_image_id
// TODO, selectively add these based on whether we give user option to change them // TODO, selectively add these based on whether we give user option to change them
// seedImageId = nowPlayingResult.input.seed_image_id
// guidance = nowPlayingResult.input.guidance // guidance = nowPlayingResult.input.guidance
// numInferenceSteps = nowPlayingResult.input.num_inference_steps // numInferenceSteps = nowPlayingResult.input.num_inference_steps
// alphaVelocity = nowPlayingResult.input.alpha_velocity // alphaVelocity = nowPlayingResult.input.alpha_velocity

19
pages/api/baseten.js Normal file
View File

@ -0,0 +1,19 @@
const BASETEN_URL = "https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";
export default async function handler(req, res) {
let headers = {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
"Authorization": `Api-Key ${BASETEN_API_KEY}`
};
const response = await fetch(BASETEN_URL, {
method: "POST",
headers: headers,
body: req.body,
});
const data = await response.json();
res.status(200).json({ data });
}

17
pages/api/server.js Normal file
View File

@ -0,0 +1,17 @@
const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
export default async function handler(req, res) {
let headers = {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
};
const response = await fetch(SERVER_URL, {
method: "POST",
headers: headers,
body: req.body,
});
const data = await response.json();
res.status(200).json({ data });
}