Support baseten fetching

This commit is contained in:
Hayk Martiros 2022-12-11 15:55:25 -08:00
parent 8781df9795
commit 0d8626de30
2 changed files with 36 additions and 20 deletions

View File

@ -11,9 +11,10 @@ import {
// TODO(hayk): Get this into a configuration.
const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
// Baseten worklet api url. Using cors-anywhere to get around CORS issues.
const BASETEN_URL = "http://cors-anywhere.herokuapp.com/https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
const BASETEN_URL =
"http://cors-anywhere.herokuapp.com/https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
// Temporary basten API key "irritating-haircut"
const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom"
const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";
interface ModelInferenceProps {
alpha: number;
@ -22,6 +23,7 @@ interface ModelInferenceProps {
promptInputs: PromptInput[];
nowPlayingResult: InferenceResult;
newResultCallback: (input: InferenceInput, result: InferenceResult) => void;
useBaseten: boolean;
}
/**
@ -36,6 +38,7 @@ export default function ModelInference({
promptInputs,
nowPlayingResult,
newResultCallback,
useBaseten,
}: ModelInferenceProps) {
// Create parameters for the inference request
const [denoising, setDenoising] = useState(0.75);
@ -121,31 +124,43 @@ export default function ModelInference({
setNumRequestsMade((n) => n + 1);
// Server API call
// const ServerResponse = await fetch(SERVER_URL, {
// method: "POST",
// headers: {
// "Content-Type": "application/json",
// "Access-Control-Allow-Origin": "*",
// },
// body: JSON.stringify(inferenceInput),
// });
let headers = {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
};
// Baseten worklet API call
const response = await fetch(BASETEN_URL, {
// Customize for baseten
const serverUrl = useBaseten ? BASETEN_URL : SERVER_URL;
const payload = useBaseten
? { worklet_input: inferenceInput }
: inferenceInput;
if (useBaseten) {
headers["Authorization"] = `Api-Key ${BASETEN_API_KEY}`;
}
const response = await fetch(serverUrl, {
method: "POST",
headers: {
"Authorization": "Api-Key JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom",
"Access-Control-Allow-Origin": "*",
},
// add the body of the request with {"worklet_input": {}}
body: JSON.stringify({worklet_input: inferenceInput}),
headers: headers,
body: JSON.stringify(payload),
});
const data = await response.json();
console.log(`Got result #${numResponsesReceived}`);
newResultCallback(inferenceInput, data);
if (useBaseten) {
if (data?.worklet_output?.model_output) {
newResultCallback(
inferenceInput,
JSON.parse(data.worklet_output.model_output)
);
} else {
console.error("Baseten call failed: ", data);
}
} else {
newResultCallback(inferenceInput, data);
}
setNumResponsesReceived((n) => n + 1);
},
[

View File

@ -210,6 +210,7 @@ export default function Home() {
promptInputs={promptInputs}
nowPlayingResult={nowPlayingResult}
newResultCallback={newResultCallback}
useBaseten={false}
/>
<AudioPlayer
paused={paused}