From 5ed37bf0355a0e9cf51af357cd24a7b528138751 Mon Sep 17 00:00:00 2001
From: nai-degen <44111-khanon@users.noreply.gitgud.io>
Date: Sat, 8 Apr 2023 04:42:36 -0500
Subject: [PATCH] implements preliminary openai proxy

---
 src/auth.ts   |  1 +
 src/keys.ts   | 12 +++++-----
 src/kobold.ts |  6 +++++
 src/openai.ts | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 75 insertions(+), 6 deletions(-)
 create mode 100644 src/kobold.ts
 create mode 100644 src/openai.ts

diff --git a/src/auth.ts b/src/auth.ts
index 6a6edae..743afef 100644
--- a/src/auth.ts
+++ b/src/auth.ts
@@ -8,6 +8,7 @@ export function auth(req: Request, res: Response, next: NextFunction) {
     return;
   }
   if (req.headers.authorization === `Bearer ${PROXY_KEY}`) {
+    delete req.headers.authorization;
     next();
   } else {
     res.status(401).json({ error: "Unauthorized" });
diff --git a/src/keys.ts b/src/keys.ts
index f2b8259..640a5ad 100644
--- a/src/keys.ts
+++ b/src/keys.ts
@@ -32,7 +32,7 @@ type Key = KeySchema & {
   hash: string;
 };
 
-const keys: Key[] = [];
+const keyPool: Key[] = [];
 
 function init() {
   const keyString = process.env.OPENAI_KEY;
@@ -47,7 +47,7 @@ function init() {
     keyList = [{ key: keyString, isTrial: false, isGpt4: true }];
   }
   for (const key of keyList) {
-    keys.push({
+    keyPool.push({
       ...key,
       isDisabled: false,
       softLimit: 0,
@@ -65,15 +65,15 @@
 }
 
 function list() {
-  return keys.map((key) => ({
+  return keyPool.map((key) => ({
     ...key,
     key: undefined,
   }));
 }
 
-function getKey(model: string) {
+function get(model: string) {
   const needsGpt4Key = model.startsWith("gpt-4");
-  const availableKeys = keys.filter(
+  const availableKeys = keyPool.filter(
     (key) => !key.isDisabled && (!needsGpt4Key || key.isGpt4)
   );
   if (availableKeys.length === 0) {
@@ -99,4 +99,4 @@
   return oldestKey;
 }
 
-export { init, list, getKey };
+export const keys = { init, list, get };
diff --git a/src/kobold.ts b/src/kobold.ts
new file mode 100644
index 0000000..afd971a
--- /dev/null
+++ b/src/kobold.ts
@@ -0,0 +1,6 @@
+import { Request, Response, NextFunction } from "express";
+
+export const kobold = (req: Request, res: Response, next: NextFunction) => {
+  // TODO: Implement kobold
+  res.status(501).json({ error: "Not implemented" });
+};
diff --git a/src/openai.ts b/src/openai.ts
new file mode 100644
index 0000000..f760305
--- /dev/null
+++ b/src/openai.ts
@@ -0,0 +1,62 @@
+import { Request, Response, NextFunction, Router } from "express";
+import * as http from "http";
+import { createProxyMiddleware } from "http-proxy-middleware";
+import { logger } from "./logger";
+import { keys } from "./keys";
+
+/**
+ * Modifies the request body to add a randomly selected API key.
+ */
+const rewriteRequest = (proxyReq: http.ClientRequest, req: Request) => {
+  const key = keys.get(req.body?.model || "gpt-3.5")!;
+
+  proxyReq.setHeader("Authorization", `Bearer ${key}`);
+  if (req.body?.stream) {
+    req.body.stream = false;
+    const updatedBody = JSON.stringify(req.body);
+    proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody));
+    proxyReq.write(updatedBody);
+    proxyReq.end();
+  }
+};
+
+const handleResponse = (
+  proxyRes: http.IncomingMessage,
+  req: Request,
+  res: Response
+) => {
+  const { method, path } = req;
+  const statusCode = proxyRes.statusCode || 500;
+
+  if (statusCode === 429) {
+    // TODO: Handle rate limit by temporarily removing that key from the pool
+    logger.warn(`OpenAI rate limit exceeded: ${method} ${path}`);
+  } else if (statusCode >= 400) {
+    logger.warn(`OpenAI error: ${method} ${path} ${statusCode}`);
+  } else {
+    logger.info(`OpenAI request: ${method} ${path} ${statusCode}`);
+  }
+
+  proxyRes.pipe(res);
+};
+
+const openaiProxy = createProxyMiddleware({
+  target: "https://api.openai.com",
+  changeOrigin: true,
+  onProxyReq: rewriteRequest,
+  onProxyRes: handleResponse,
+  selfHandleResponse: true,
+  pathRewrite: {
+    "^/proxy/openai": "",
+  },
+});
+
+export const openaiRouter = Router();
+openaiRouter.post("/v1/chat/completions", openaiProxy);
+// openaiRouter.post("/v1/completions", openaiProxy);
+// openaiRouter.get("/v1/models", handleModels);
+// openaiRouter.get("/dashboard/billing/usage", handleUsage);
+openaiRouter.use((req, res) => {
+  logger.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
+  res.status(404).json({ error: "Not found" });
+});