diff --git a/components/ModelInference.tsx b/components/ModelInference.tsx index b3c36c0..418ee27 100644 --- a/components/ModelInference.tsx +++ b/components/ModelInference.tsx @@ -1,6 +1,7 @@ import { useRouter } from "next/router"; import { useCallback, useEffect, useState } from "react"; + import { AppState, InferenceInput, @@ -8,14 +9,6 @@ import { PromptInput, } from "../types"; -// TODO(hayk): Get this into a configuration. -const SERVER_URL = "http://129.146.52.68:3013/run_inference/"; -// Baseten worklet api url. Using cors-anywhere to get around CORS issues. -const BASETEN_URL = - "https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke"; -// Temporary basten API key "irritating-haircut" -const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom"; - interface ModelInferenceProps { alpha: number; seed: number; @@ -124,26 +117,17 @@ export default function ModelInference({ setNumRequestsMade((n) => n + 1); - let headers = { - "Content-Type": "application/json", - "Access-Control-Allow-Origin": "*", - }; - // Customize for baseten - const serverUrl = useBaseten ? BASETEN_URL : SERVER_URL; + const apiHandler = useBaseten ? "/api/baseten" : "/api/server"; const payload = useBaseten ? { worklet_input: inferenceInput } : inferenceInput; - if (useBaseten) { - headers["Authorization"] = `Api-Key ${BASETEN_API_KEY}`; - } - const response = await fetch(serverUrl, { + const response = await fetch(apiHandler, { method: "POST", - headers: headers, body: JSON.stringify(payload), }); - + const data = await response.json(); console.log(`Got result #${numResponsesReceived}`); @@ -154,11 +138,20 @@ export default function ModelInference({ inferenceInput, JSON.parse(data.worklet_output.model_output) ); - } else { + } + // Note, data is currently wrapped in a data field + else if (data?.data?.worklet_output?.model_output) { + newResultCallback( + inferenceInput, + JSON.parse(data.data.worklet_output.model_output) + ); + } + else { console.error("Baseten call failed: ", data); } } else { - newResultCallback(inferenceInput, data); + // Note, data is currently wrapped in a data field + newResultCallback(inferenceInput, data.data); } setNumResponsesReceived((n) => n + 1); diff --git a/components/Pause.tsx b/components/Pause.tsx index 8ef75d4..7b879bb 100644 --- a/components/Pause.tsx +++ b/components/Pause.tsx @@ -22,7 +22,7 @@ export default function Pause({ var classNameCondition = "" if (paused) { - classNameCondition="animate-pulse fixed z-20 top-4 right-4 md:top-8 md:right-8 bg-sky-400 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-white text-2xl hover:bg-sky-500 focus:ring-4 focus:outline-none focus:ring-sky-600 hover:drop-shadow-2xl" + classNameCondition="animate-pulse fixed z-20 top-4 right-4 md:top-8 md:right-8 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-white text-2xl bg-red-500 hover:bg-red-600 ring-4 ring-red-700 focus:outline-none hover:drop-shadow-2xl" } else { classNameCondition="fixed z-20 top-4 right-4 md:top-8 md:right-8 bg-slate-100 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-sky-900 text-2xl hover:text-white hover:bg-sky-600 hover:drop-shadow-2xl" } diff --git a/components/Share.tsx b/components/Share.tsx index 4cd4263..d5ded67 100644 --- a/components/Share.tsx +++ b/components/Share.tsx @@ -79,15 +79,15 @@ export default function Share({ } // TODO: Consider start or end here. End is something the the user hasn't actually heard yet. Start is perhaps a previous prompt than where the user is headed - + // TODO: Consider only including in the link the things that are different from the default values prompt = selectedInput.start.prompt seed = selectedInput.start.seed denoising = selectedInput.start.denoising maskImageId = selectedInput.mask_image_id + seedImageId = nowPlayingResult.input.seed_image_id // TODO, selectively add these based on whether we give user option to change them - // seedImageId = nowPlayingResult.input.seed_image_id // guidance = nowPlayingResult.input.guidance // numInferenceSteps = nowPlayingResult.input.num_inference_steps // alphaVelocity = nowPlayingResult.input.alpha_velocity diff --git a/pages/api/baseten.js b/pages/api/baseten.js new file mode 100644 index 0000000..35b5b4a --- /dev/null +++ b/pages/api/baseten.js @@ -0,0 +1,19 @@ +const BASETEN_URL = "https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke"; +const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom"; + +export default async function handler(req, res) { + let headers = { + "Content-Type": "application/json", + "Access-Control-Allow-Origin": "*", + "Authorization": `Api-Key ${BASETEN_API_KEY}` + }; + + const response = await fetch(BASETEN_URL, { + method: "POST", + headers: headers, + body: req.body, + }); + + const data = await response.json(); + res.status(200).json({ data }); +} \ No newline at end of file diff --git a/pages/api/server.js b/pages/api/server.js new file mode 100644 index 0000000..be01a20 --- /dev/null +++ b/pages/api/server.js @@ -0,0 +1,17 @@ +const SERVER_URL = "http://129.146.52.68:3013/run_inference/"; + +export default async function handler(req, res) { + let headers = { + "Content-Type": "application/json", + "Access-Control-Allow-Origin": "*", + }; + + const response = await fetch(SERVER_URL, { + method: "POST", + headers: headers, + body: req.body, + }); + + const data = await response.json(); + res.status(200).json({ data }); +} \ No newline at end of file