diff --git a/components/ModelInference.tsx b/components/ModelInference.tsx
index b3c36c0..418ee27 100644
--- a/components/ModelInference.tsx
+++ b/components/ModelInference.tsx
@@ -1,6 +1,7 @@
 import { useRouter } from "next/router";
 import { useCallback, useEffect, useState } from "react";
+
 import {
   AppState,
   InferenceInput,
@@ -8,14 +9,6 @@ import {
   PromptInput,
 } from "../types";
 
-// TODO(hayk): Get this into a configuration.
-const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
-// Baseten worklet api url. Using cors-anywhere to get around CORS issues.
-const BASETEN_URL =
-  "https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
-// Temporary basten API key "irritating-haircut"
-const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";
-
 interface ModelInferenceProps {
   alpha: number;
   seed: number;
@@ -124,26 +117,17 @@ export default function ModelInference({
 
     setNumRequestsMade((n) => n + 1);
 
-    let headers = {
-      "Content-Type": "application/json",
-      "Access-Control-Allow-Origin": "*",
-    };
-
     // Customize for baseten
-    const serverUrl = useBaseten ? BASETEN_URL : SERVER_URL;
+    const apiHandler = useBaseten ? "/api/baseten" : "/api/server";
     const payload = useBaseten
       ? { worklet_input: inferenceInput }
       : inferenceInput;
-    if (useBaseten) {
-      headers["Authorization"] = `Api-Key ${BASETEN_API_KEY}`;
-    }
 
-    const response = await fetch(serverUrl, {
+    const response = await fetch(apiHandler, {
      method: "POST",
-      headers: headers,
       body: JSON.stringify(payload),
     });
-
+
     const data = await response.json();
 
     console.log(`Got result #${numResponsesReceived}`);
@@ -154,11 +138,20 @@ export default function ModelInference({
           inferenceInput,
           JSON.parse(data.worklet_output.model_output)
         );
-      } else {
+      }
+      // Note, data is currently wrapped in a data field
+      else if (data?.data?.worklet_output?.model_output) {
+        newResultCallback(
+          inferenceInput,
+          JSON.parse(data.data.worklet_output.model_output)
+        );
+      }
+      else {
         console.error("Baseten call failed: ", data);
       }
     } else {
-      newResultCallback(inferenceInput, data);
+      // Note, data is currently wrapped in a data field
+      newResultCallback(inferenceInput, data.data);
     }
 
     setNumResponsesReceived((n) => n + 1);
diff --git a/pages/api/baseten.js b/pages/api/baseten.js
new file mode 100644
index 0000000..35b5b4a
--- /dev/null
+++ b/pages/api/baseten.js
@@ -0,0 +1,19 @@
+const BASETEN_URL = "https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
+const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";
+
+export default async function handler(req, res) {
+  let headers = {
+    "Content-Type": "application/json",
+    "Access-Control-Allow-Origin": "*",
+    "Authorization": `Api-Key ${BASETEN_API_KEY}`
+  };
+
+  const response = await fetch(BASETEN_URL, {
+    method: "POST",
+    headers: headers,
+    body: req.body,
+  });
+
+  const data = await response.json();
+  res.status(200).json({ data });
+}
\ No newline at end of file
diff --git a/pages/api/server.js b/pages/api/server.js
new file mode 100644
index 0000000..be01a20
--- /dev/null
+++ b/pages/api/server.js
@@ -0,0 +1,17 @@
+const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
+
+export default async function handler(req, res) {
+  let headers = {
+    "Content-Type": "application/json",
+    "Access-Control-Allow-Origin": "*",
+  };
+
+  const response = await fetch(SERVER_URL, {
+    method: "POST",
+    headers: headers,
+    body: req.body,
+  });
+
+  const data = await response.json();
+  res.status(200).json({ data });
+}
\ No newline at end of file
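
Note: the removed TODO ("Get this into a configuration") still applies to the new API routes, which hard-code the Baseten URL and API key on the server side. Below is a minimal sketch of how pages/api/baseten.js could instead read those values from environment variables; reading from process.env (e.g. via .env.local) is an assumption for illustration, not something this diff implements.

// pages/api/baseten.js: hypothetical variant of the handler above.
// Assumes BASETEN_URL and BASETEN_API_KEY are set in the environment
// (for example in .env.local) rather than hard-coded in the file.
export default async function handler(req, res) {
  const response = await fetch(process.env.BASETEN_URL, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Api-Key ${process.env.BASETEN_API_KEY}`,
    },
    // req.body is the JSON string the client already built with JSON.stringify
    body: req.body,
  });

  const data = await response.json();
  // Keep the same { data } wrapper that the client-side code now unwraps
  res.status(200).json({ data });
}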