Support baseten fetching

parent 8781df9795
commit 0d8626de30
@@ -11,9 +11,10 @@ import {
 // TODO(hayk): Get this into a configuration.
 const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
 // Baseten worklet api url. Using cors-anywhere to get around CORS issues.
-const BASETEN_URL = "http://cors-anywhere.herokuapp.com/https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
+const BASETEN_URL =
+  "http://cors-anywhere.herokuapp.com/https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
 // Temporary basten API key "irritating-haircut"
-const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom"
+const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";

 interface ModelInferenceProps {
   alpha: number;
@@ -22,6 +23,7 @@ interface ModelInferenceProps {
   promptInputs: PromptInput[];
   nowPlayingResult: InferenceResult;
   newResultCallback: (input: InferenceInput, result: InferenceResult) => void;
+  useBaseten: boolean;
 }

 /**
@@ -36,6 +38,7 @@ export default function ModelInference({
   promptInputs,
   nowPlayingResult,
   newResultCallback,
+  useBaseten,
 }: ModelInferenceProps) {
   // Create parameters for the inference request
   const [denoising, setDenoising] = useState(0.75);
@@ -121,31 +124,43 @@ export default function ModelInference({
     setNumRequestsMade((n) => n + 1);

-    // Server API call
-    // const ServerResponse = await fetch(SERVER_URL, {
-    //   method: "POST",
-    //   headers: {
-    //     "Content-Type": "application/json",
-    //     "Access-Control-Allow-Origin": "*",
-    //   },
-    //   body: JSON.stringify(inferenceInput),
-    // });
+    let headers = {
+      "Content-Type": "application/json",
+      "Access-Control-Allow-Origin": "*",
+    };

-    // Baseten worklet API call
-    const response = await fetch(BASETEN_URL, {
+    // Customize for baseten
+    const serverUrl = useBaseten ? BASETEN_URL : SERVER_URL;
+    const payload = useBaseten
+      ? { worklet_input: inferenceInput }
+      : inferenceInput;
+    if (useBaseten) {
+      headers["Authorization"] = `Api-Key ${BASETEN_API_KEY}`;
+    }
+
+    const response = await fetch(serverUrl, {
       method: "POST",
-      headers: {
-        "Authorization": "Api-Key JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom",
-        "Access-Control-Allow-Origin": "*",
-      },
-      // add the body of the request with {"worklet_input": {}}
-      body: JSON.stringify({worklet_input: inferenceInput}),
+      headers: headers,
+      body: JSON.stringify(payload),
     });

     const data = await response.json();

     console.log(`Got result #${numResponsesReceived}`);
-    newResultCallback(inferenceInput, data);
+
+    if (useBaseten) {
+      if (data?.worklet_output?.model_output) {
+        newResultCallback(
+          inferenceInput,
+          JSON.parse(data.worklet_output.model_output)
+        );
+      } else {
+        console.error("Baseten call failed: ", data);
+      }
+    } else {
+      newResultCallback(inferenceInput, data);
+    }

     setNumResponsesReceived((n) => n + 1);
   },
   [
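For reference, the request flow this hunk implements: the Baseten worklet path wraps the input in worklet_input, authenticates with an Api-Key header, and returns the model output as a JSON string under worklet_output.model_output, while the plain server takes the input directly and returns the result as-is. Below is a minimal standalone sketch of that branching; the runInference helper name and the simplified placeholder types are illustrative only and are not part of this commit.

// Sketch only: simplified stand-ins for the app's InferenceInput/InferenceResult types.
type InferenceInput = Record<string, unknown>;
type InferenceResult = Record<string, unknown>;

const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
const BASETEN_URL =
  "http://cors-anywhere.herokuapp.com/https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";

// Illustrative helper mirroring the branching added in the diff above.
async function runInference(
  inferenceInput: InferenceInput,
  useBaseten: boolean
): Promise<InferenceResult | null> {
  const headers: Record<string, string> = {
    "Content-Type": "application/json",
    "Access-Control-Allow-Origin": "*",
  };
  if (useBaseten) {
    // The Baseten worklet authenticates with an Api-Key header.
    headers["Authorization"] = `Api-Key ${BASETEN_API_KEY}`;
  }

  const response = await fetch(useBaseten ? BASETEN_URL : SERVER_URL, {
    method: "POST",
    headers,
    // Baseten expects the payload nested under worklet_input; the plain server takes it directly.
    body: JSON.stringify(
      useBaseten ? { worklet_input: inferenceInput } : inferenceInput
    ),
  });
  const data = await response.json();

  if (!useBaseten) {
    return data;
  }
  // Baseten returns the model output as a JSON string under worklet_output.
  if (data?.worklet_output?.model_output) {
    return JSON.parse(data.worklet_output.model_output);
  }
  console.error("Baseten call failed: ", data);
  return null;
}

Calling runInference(input, true) exercises the worklet path; passing false keeps the original server endpoint, which matches the useBaseten={false} default wired up in the page below.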
@@ -210,6 +210,7 @@ export default function Home() {
         promptInputs={promptInputs}
         nowPlayingResult={nowPlayingResult}
         newResultCallback={newResultCallback}
+        useBaseten={false}
       />
       <AudioPlayer
         paused={paused}