From d5bd8167c29facf167288e132e362746762543bb Mon Sep 17 00:00:00 2001
From: Hayk Martiros
Date: Mon, 12 Dec 2022 18:23:03 -0800
Subject: [PATCH] Environment variables for model inference and updated readme

---
 README.md                     | 41 +++++++++++++++++++++--------------
 components/ModelInference.tsx | 12 +++++-----
 pages/about.tsx               |  2 +-
 pages/api/baseten.js          |  9 +++-----
 pages/api/server.js           |  6 ++---
 pages/index.tsx               | 14 +++++-------
 6 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/README.md b/README.md
index 4124125..4b9199f 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,22 @@
 # Riffusion App
 
+Riffusion generates audio using Stable Diffusion. See https://www.riffusion.com/about for details.
+
+* Web app: https://github.com/hmartiro/riffusion-app
+* Inference backend: https://github.com/hmartiro/riffusion-inference
+* Model checkpoint: TODO
+
+## Run
+
 This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).
 
-## Getting Started
-
-Prerequisites:
+Install:
 
 ```bash
 npm install
 ```
 
-Then, run the development server:
+Run the development server:
 
 ```bash
 npm run dev
@@ -18,25 +24,28 @@ npm run dev
 yarn dev
 ```
 
-Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
+Open [http://localhost:3000](http://localhost:3000) with your browser to see the app.
 
-You can start editing the page by modifying `pages/index.js`. The page auto-updates as you edit the file.
-
-[API routes](https://nextjs.org/docs/api-routes/introduction) can be accessed on [http://localhost:3000/api/hello](http://localhost:3000/api/hello). This endpoint can be edited in `pages/api/hello.js`.
+The app home page is at `pages/index.tsx` and auto-updates as you edit the file. The about page is at `pages/about.tsx`.
 
 The `pages/api` directory is mapped to `/api/*`. Files in this directory are treated as [API routes](https://nextjs.org/docs/api-routes/introduction) instead of React pages.
 
-## Learn More
+## Inference Server
 
-To learn more about Next.js, take a look at the following resources:
+To generate model outputs, the app needs a backend that serves inference requests over an API. If you have a large enough GPU to run Stable Diffusion in under five seconds, clone the [inference server](https://github.com/hmartiro/riffusion-inference) and follow its instructions to run the Flask app.
 
-- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
-- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
+The app can also run against [BaseTen](https://www.baseten.co/) for auto-scaling and load balancing; this requires a BaseTen API key.
 
-You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome!
+To configure these backends, add a `.env.local` file:
 
-## Deploy on Vercel
+```
+# URL to your flask instance
+RIFFUSION_FLASK_URL=http://localhost:3013/run_inference/
 
-The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
+# Whether to use baseten as the model backend
+NEXT_PUBLIC_RIFFUSION_USE_BASETEN=false
 
-Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details.
+# If using BaseTen, the URL and API key
+RIFFUSION_BASETEN_URL=https://app.baseten.co/applications/XXX
+RIFFUSION_BASETEN_API_KEY=XXX
+```
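A note on scoping, since it motivates the variable names above: in Next.js, only variables prefixed with `NEXT_PUBLIC_` are inlined into the client bundle at build time, while un-prefixed variables are readable only from server-side code such as the API routes changed below. A minimal sketch of that split (illustrative, not part of this patch):

```ts
// NEXT_PUBLIC_* values are inlined into the browser bundle at build time,
// so they should carry only non-secret flags like this one.
const useBaseten = process.env.NEXT_PUBLIC_RIFFUSION_USE_BASETEN === "true";

// Un-prefixed values exist only in server-side code (e.g. pages/api/*),
// which keeps the BaseTen API key and backend URLs out of shipped JavaScript.
const basetenKey = process.env.RIFFUSION_BASETEN_API_KEY; // undefined in browser
```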
diff --git a/components/ModelInference.tsx b/components/ModelInference.tsx
index 418ee27..007fdeb 100644
--- a/components/ModelInference.tsx
+++ b/components/ModelInference.tsx
@@ -1,7 +1,6 @@
 import { useRouter } from "next/router";
 import { useCallback, useEffect, useState } from "react";
 
-
 import {
   AppState,
   InferenceInput,
@@ -44,6 +43,10 @@ export default function ModelInference({
   const [numRequestsMade, setNumRequestsMade] = useState(0);
   const [numResponsesReceived, setNumResponsesReceived] = useState(0);
 
+  useEffect(() => {
+    console.log("Using baseten: ", useBaseten);
+  }, [useBaseten]);
+
   // Set initial params from URL query strings
   const router = useRouter();
   useEffect(() => {
@@ -127,7 +130,7 @@ export default function ModelInference({
       method: "POST",
       body: JSON.stringify(payload),
     });
-    
+
     const data = await response.json();
 
     console.log(`Got result #${numResponsesReceived}`);
@@ -145,12 +148,11 @@ export default function ModelInference({
           inferenceInput,
           JSON.parse(data.data.worklet_output.model_output)
         );
-      }
-      else {
+      } else {
         console.error("Baseten call failed: ", data);
       }
     } else {
-      // Note, data is currently wrapped in a data field 
+      // Note, data is currently wrapped in a data field
       newResultCallback(inferenceInput, data.data);
     }
 
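The `useBaseten` flag logged above is what selects the backend. The call site is outside these hunks, but the selection this patch implies looks roughly like the sketch below (an assumption pieced together from the two API files in this patch, not code from the commit):

```ts
// Assumed sketch: post the inference payload to one of the two Next.js API
// routes below; each forwards the body to its backend and wraps the reply
// in a { data } field.
async function requestInference(payload: unknown, useBaseten: boolean) {
  const endpoint = useBaseten ? "/api/baseten" : "/api/server";
  const response = await fetch(endpoint, {
    method: "POST",
    body: JSON.stringify(payload),
  });
  return (await response.json()).data;
}
```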

diff --git a/pages/about.tsx b/pages/about.tsx
index b59504c..6307410 100644
--- a/pages/about.tsx
+++ b/pages/about.tsx
@@ -401,7 +401,7 @@ export default function Home() {
         The app communicates over an API to run the inference calls on a GPU
         server. We built a flask server for testing, and deployed to production
         on Baseten for
-        autoscaling and load balancing.
+        auto-scaling and load balancing.
 
         The web app code is at{" "}
diff --git a/pages/api/baseten.js b/pages/api/baseten.js
index 35b5b4a..a8573e9 100644
--- a/pages/api/baseten.js
+++ b/pages/api/baseten.js
@@ -1,14 +1,11 @@
-const BASETEN_URL = "https://app.baseten.co/applications/2qREaXP/production/worklets/mP7KkLP/invoke";
-const BASETEN_API_KEY = "JocxKmyo.g0JreAA8dZy5F20PdMxGAV34a4VGGpom";
-
 export default async function handler(req, res) {
   let headers = {
     "Content-Type": "application/json",
     "Access-Control-Allow-Origin": "*",
-    "Authorization": `Api-Key ${BASETEN_API_KEY}`
+    "Authorization": `Api-Key ${process.env.RIFFUSION_BASETEN_API_KEY}`
   };
 
-  const response = await fetch(BASETEN_URL, {
+  const response = await fetch(process.env.RIFFUSION_BASETEN_URL, {
     method: "POST",
     headers: headers,
     body: req.body,
@@ -16,4 +13,4 @@ export default async function handler(req, res) {
   const data = await response.json();
 
   res.status(200).json({ data });
-}
\ No newline at end of file
+}
diff --git a/pages/api/server.js b/pages/api/server.js
index be01a20..4ac8532 100644
--- a/pages/api/server.js
+++ b/pages/api/server.js
@@ -1,12 +1,10 @@
-const SERVER_URL = "http://129.146.52.68:3013/run_inference/";
-
 export default async function handler(req, res) {
   let headers = {
     "Content-Type": "application/json",
     "Access-Control-Allow-Origin": "*",
   };
 
-  const response = await fetch(SERVER_URL, {
+  const response = await fetch(process.env.RIFFUSION_FLASK_URL, {
     method: "POST",
     headers: headers,
     body: req.body,
@@ -14,4 +12,4 @@ export default async function handler(req, res) {
   const data = await response.json();
 
   res.status(200).json({ data });
-}
\ No newline at end of file
+}
diff --git a/pages/index.tsx b/pages/index.tsx
index c365164..ead7894 100644
--- a/pages/index.tsx
+++ b/pages/index.tsx
@@ -3,9 +3,7 @@ import { useEffect, useState } from "react";
 import * as Tone from "tone";
 
 import AudioPlayer from "../components/AudioPlayer";
-import DebugView from "../components/DebugView";
 import PageHead from "../components/PageHead";
-import Info from "../components/Info";
 import Share from "../components/Share";
 import Settings from "../components/Settings";
 import ModelInference from "../components/ModelInference";
@@ -215,9 +213,9 @@ export default function Home() {
         promptInputs={promptInputs}
         nowPlayingResult={nowPlayingResult}
         newResultCallback={newResultCallback}
-        useBaseten={true}
+        useBaseten={process.env.NEXT_PUBLIC_RIFFUSION_USE_BASETEN == "true"}
       />
-
+
-
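With the hardcoded URLs and key gone, both routes silently depend on `.env.local` being populated. A quick end-to-end smoke test of the Flask path (assumes `npm run dev` is serving on port 3000 and the env file from the README; the payload is a placeholder, since the real request shape is the app's `InferenceInput` type):

```ts
// Smoke test for the env-var-driven proxy route (assumptions noted above).
async function smokeTest() {
  const response = await fetch("http://localhost:3000/api/server", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ alpha: 0.5 }), // hypothetical placeholder payload
  });
  // The route wraps the backend's JSON reply in a { data } field.
  const { data } = await response.json();
  console.log(data);
}

smokeTest().catch(console.error);
```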