From e03f3d48ddf1e23d4010d124b804b17ac675cc66 Mon Sep 17 00:00:00 2001 From: nai-degen <44111-khanon@users.noreply.gitgud.io> Date: Tue, 9 May 2023 23:11:57 +0000 Subject: [PATCH] Implements request queueing (khanon/oai-reverse-proxy!6) --- .env.example | 2 + info-page.md | 3 - src/config.ts | 26 +- src/info-page.ts | 97 +++-- src/key-management/key-checker.ts | 3 - src/key-management/key-pool.ts | 244 ++++++++++-- .../middleware/request/check-streaming.ts | 5 +- .../response/handle-streamed-response.ts | 33 +- src/proxy/middleware/response/index.ts | 186 +++++---- src/proxy/openai.ts | 4 +- src/proxy/queue.ts | 367 ++++++++++++++++++ src/proxy/rate-limit.ts | 3 +- src/server.ts | 23 +- src/types/custom.d.ts | 6 + 14 files changed, 853 insertions(+), 149 deletions(-) delete mode 100644 info-page.md create mode 100644 src/proxy/queue.ts diff --git a/.env.example b/.env.example index 1b6cdf7..95eac7f 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,8 @@ # REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy." # REJECT_SAMPLE_RATE=0.2 # CHECK_KEYS=false +# QUOTA_DISPLAY_MODE=full +# QUEUE_MODE=fair # Note: CHECK_KEYS is disabled by default in local development mode, but enabled # by default in production mode. diff --git a/info-page.md b/info-page.md deleted file mode 100644 index ed1d22f..0000000 --- a/info-page.md +++ /dev/null @@ -1,3 +0,0 @@ - - -# OAI Reverse Proxy diff --git a/src/config.ts b/src/config.ts index 144807c..7c727d5 100644 --- a/src/config.ts +++ b/src/config.ts @@ -3,7 +3,8 @@ dotenv.config(); const isDev = process.env.NODE_ENV !== "production"; -type PROMPT_LOGGING_BACKEND = "google_sheets"; +type PromptLoggingBackend = "google_sheets"; +export type DequeueMode = "fair" | "random" | "none"; type Config = { /** The port the proxy server will listen on. */ @@ -25,17 +26,29 @@ type Config = { /** Pino log level. */ logLevel?: "debug" | "info" | "warn" | "error"; /** Whether prompts and responses should be logged to persistent storage. */ - promptLogging?: boolean; // TODO: Implement prompt logging once we have persistent storage. + promptLogging?: boolean; /** Which prompt logging backend to use. */ - promptLoggingBackend?: PROMPT_LOGGING_BACKEND; + promptLoggingBackend?: PromptLoggingBackend; /** Base64-encoded Google Sheets API key. */ googleSheetsKey?: string; /** Google Sheets spreadsheet ID. */ googleSheetsSpreadsheetId?: string; /** Whether to periodically check keys for usage and validity. */ checkKeys?: boolean; - /** Whether to allow streaming completions. This is usually fine but can cause issues on some deployments. */ - allowStreaming?: boolean; + /** + * How to display quota information on the info page. + * 'none' - Hide quota information + * 'simple' - Display quota information as a percentage + * 'full' - Display quota information as usage against total capacity + */ + quotaDisplayMode: "none" | "simple" | "full"; + /** + * Which request queueing strategy to use when keys are over their rate limit. + * 'fair' - Requests are serviced in the order they were received (default) + * 'random' - Requests are serviced randomly + * 'none' - Requests are not queued and users have to retry manually + */ + queueMode: DequeueMode; }; // To change configs, create a file called .env in the root directory. 
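For reference, both new options flow through the same `getEnvWithDefault` calls shown in the next hunk. The helper itself is not part of this diff; the sketch below is only an illustration, under the assumption that it returns the default when the variable is unset and coerces JSON-ish values (booleans, numbers) while passing plain strings such as "fair" through unchanged.

// Hypothetical helper sketch (not from this patch) compatible with the
// getEnvWithDefault("QUEUE_MODE", "fair") style of call used in this file.
function getEnvWithDefault<T>(name: string, defaultValue: T): T {
  const raw = process.env[name];
  if (raw === undefined) return defaultValue;
  try {
    // "false" -> false, "0.2" -> 0.2; plain strings like "fair" throw here
    // and are returned verbatim below.
    return JSON.parse(raw) as T;
  } catch {
    return raw as unknown as T;
  }
}

With such a helper, QUEUE_MODE=random yields config.queueMode === "random", and leaving QUOTA_DISPLAY_MODE unset falls back to "full".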
@@ -54,6 +67,7 @@ export const config: Config = { ), logLevel: getEnvWithDefault("LOG_LEVEL", "info"), checkKeys: getEnvWithDefault("CHECK_KEYS", !isDev), + quotaDisplayMode: getEnvWithDefault("QUOTA_DISPLAY_MODE", "full"), promptLogging: getEnvWithDefault("PROMPT_LOGGING", false), promptLoggingBackend: getEnvWithDefault("PROMPT_LOGGING_BACKEND", undefined), googleSheetsKey: getEnvWithDefault("GOOGLE_SHEETS_KEY", undefined), @@ -61,7 +75,7 @@ export const config: Config = { "GOOGLE_SHEETS_SPREADSHEET_ID", undefined ), - allowStreaming: getEnvWithDefault("ALLOW_STREAMING", true), + queueMode: getEnvWithDefault("QUEUE_MODE", "fair"), } as const; export const SENSITIVE_KEYS: (keyof Config)[] = [ diff --git a/src/info-page.ts b/src/info-page.ts index 008107b..8f8ab23 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -4,60 +4,89 @@ import showdown from "showdown"; import { config, listConfig } from "./config"; import { keyPool } from "./key-management"; import { getUniqueIps } from "./proxy/rate-limit"; +import { getAverageWaitTime, getQueueLength } from "./proxy/queue"; + +const INFO_PAGE_TTL = 5000; +let infoPageHtml: string | undefined; +let infoPageLastUpdated = 0; export const handleInfoPage = (req: Request, res: Response) => { + if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) { + res.send(infoPageHtml); + return; + } + // Huggingface puts spaces behind some cloudflare ssl proxy, so `req.protocol` is `http` but the correct URL is actually `https` const host = req.get("host"); const isHuggingface = host?.includes("hf.space"); const protocol = isHuggingface ? "https" : req.protocol; - res.send(getInfoPageHtml(protocol + "://" + host)); + res.send(cacheInfoPageHtml(protocol + "://" + host)); }; -function getInfoPageHtml(host: string) { +function cacheInfoPageHtml(host: string) { const keys = keyPool.list(); - let keyInfo: Record = { - all: keys.length, - active: keys.filter((k) => !k.isDisabled).length, - }; + let keyInfo: Record = { all: keys.length }; if (keyPool.anyUnchecked()) { const uncheckedKeys = keys.filter((k) => !k.lastChecked); keyInfo = { ...keyInfo, + active: keys.filter((k) => !k.isDisabled).length, status: `Still checking ${uncheckedKeys.length} keys...`, }; } else if (config.checkKeys) { + const trialKeys = keys.filter((k) => k.isTrial); + const turboKeys = keys.filter((k) => !k.isGpt4 && !k.isDisabled); + const gpt4Keys = keys.filter((k) => k.isGpt4 && !k.isDisabled); + + const quota: Record = { turbo: "", gpt4: "" }; const hasGpt4 = keys.some((k) => k.isGpt4); + + if (config.quotaDisplayMode === "full") { + quota.turbo = `${keyPool.usageInUsd()} (${Math.round( + keyPool.remainingQuota() * 100 + )}% remaining)`; + quota.gpt4 = `${keyPool.usageInUsd(true)} (${Math.round( + keyPool.remainingQuota(true) * 100 + )}% remaining)`; + } else { + quota.turbo = `${Math.round(keyPool.remainingQuota() * 100)}%`; + quota.gpt4 = `${Math.round(keyPool.remainingQuota(true) * 100)}%`; + } + + if (!hasGpt4) { + delete quota.gpt4; + } + keyInfo = { ...keyInfo, - trial: keys.filter((k) => k.isTrial).length, - gpt4: keys.filter((k) => k.isGpt4).length, - quotaLeft: { - all: `${Math.round(keyPool.remainingQuota() * 100)}%`, - ...(hasGpt4 - ? { gpt4: `${Math.round(keyPool.remainingQuota(true) * 100)}%` } - : {}), + trial: trialKeys.length, + active: { + turbo: turboKeys.length, + ...(hasGpt4 ? { gpt4: gpt4Keys.length } : {}), }, + ...(config.quotaDisplayMode !== "none" ? 
{ quota: quota } : {}), }; } const info = { uptime: process.uptime(), - timestamp: Date.now(), endpoints: { kobold: host, openai: host + "/proxy/openai", }, proompts: keys.reduce((acc, k) => acc + k.promptCount, 0), ...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}), - keyInfo, + ...getQueueInformation(), + keys: keyInfo, config: listConfig(), commitSha: process.env.COMMIT_SHA || "dev", }; - + const title = process.env.SPACE_ID ? `${process.env.SPACE_AUTHOR_NAME} / ${process.env.SPACE_TITLE}` : "OAI Reverse Proxy"; + const headerHtml = buildInfoPageHeader(new showdown.Converter(), title); const pageBody = ` @@ -66,29 +95,30 @@ function getInfoPageHtml(host: string) { ${title}

    <h2>Service Info</h2>

    <pre>${JSON.stringify(info, null, 2)}</pre>
`; + infoPageHtml = pageBody; + infoPageLastUpdated = Date.now(); + return pageBody; } -const infoPageHeaderHtml = buildInfoPageHeader(new showdown.Converter()); - /** * If the server operator provides a `greeting.md` file, it will be included in * the rendered info page. **/ -function buildInfoPageHeader(converter: showdown.Converter) { - const genericInfoPage = fs.readFileSync("info-page.md", "utf8"); +function buildInfoPageHeader(converter: showdown.Converter, title: string) { const customGreeting = fs.existsSync("greeting.md") ? fs.readFileSync("greeting.md", "utf8") : null; - let infoBody = genericInfoPage; + let infoBody = ` +# ${title}`; if (config.promptLogging) { infoBody += `\n## Prompt logging is enabled! The server operator has enabled prompt logging. The prompts you send to this proxy and the AI responses you receive may be saved. @@ -97,9 +127,32 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t **If you are uncomfortable with this, don't send prompts to this proxy!**`; } + + if (config.queueMode !== "none") { + infoBody += `\n### Queueing is enabled +Requests are queued to mitigate the effects of OpenAI's rate limits. If the AI is busy, your prompt will be queued and processed when a slot is available. + +You can check wait times below. **Be sure to enable streaming in your client, or your request will likely time out.**`; + } + if (customGreeting) { infoBody += `\n## Server Greeting\n ${customGreeting}`; } return converter.makeHtml(infoBody); } + +function getQueueInformation() { + if (config.queueMode === "none") { + return {}; + } + const waitMs = getAverageWaitTime(); + const waitTime = + waitMs < 60000 + ? `${Math.round(waitMs / 1000)} seconds` + : `${Math.round(waitMs / 60000)} minutes`; + return { + proomptersWaiting: getQueueLength(), + estimatedWaitTime: waitMs > 3000 ? waitTime : "no wait", + }; +} diff --git a/src/key-management/key-checker.ts b/src/key-management/key-checker.ts index ef44eb0..d8d086f 100644 --- a/src/key-management/key-checker.ts +++ b/src/key-management/key-checker.ts @@ -191,9 +191,6 @@ export class KeyChecker { return data; } - // TODO: This endpoint seems to be very delayed. I think we will need to track - // the time it last changed and estimate token usage ourselves in between - // changes by inspecting request payloads for prompt and completion tokens. private async getUsage(key: Key) { const querystring = KeyChecker.getUsageQuerystring(key.isTrial); const url = `${GET_USAGE_URL}?${querystring}`; diff --git a/src/key-management/key-pool.ts b/src/key-management/key-pool.ts index 129a9a5..163ef0d 100644 --- a/src/key-management/key-pool.ts +++ b/src/key-management/key-pool.ts @@ -2,28 +2,26 @@ round-robin access to keys. Keys are stored in the OPENAI_KEY environment variable as a comma-separated list of keys. */ import crypto from "crypto"; +import fs from "fs"; +import http from "http"; +import path from "path"; import { config } from "../config"; import { logger } from "../logger"; import { KeyChecker } from "./key-checker"; -// TODO: Made too many assumptions about OpenAI being the only provider and now +// TODO: Made too many assumptions about OpenAI being the only provider and now // this doesn't really work for Anthropic. Create a Provider interface and // implement Pool, Checker, and Models for each provider. 
export type Model = OpenAIModel | AnthropicModel; -export type OpenAIModel = -| "gpt-3.5-turbo" -| "gpt-4" -export type AnthropicModel = -| "claude-v1" -| "claude-instant-v1" +export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4"; +export type AnthropicModel = "claude-v1" | "claude-instant-v1"; export const SUPPORTED_MODELS: readonly Model[] = [ "gpt-3.5-turbo", "gpt-4", "claude-v1", "claude-instant-v1", ] as const; - export type Key = { /** The OpenAI API key itself. */ @@ -50,6 +48,30 @@ export type Key = { lastChecked: number; /** Key hash for displaying usage in the dashboard. */ hash: string; + /** The time at which this key was last rate limited. */ + rateLimitedAt: number; + /** + * Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a + * number. + * Formatted as a `\d+(m|s)` string denoting the time until the limit resets. + * Specifically, it seems to indicate the time until the key's quota will be + * fully restored; the key may be usable before this time as the limit is a + * rolling window. + * + * Requests which return a 429 do not count against the quota. + * + * Requests which fail for other reasons (e.g. 401) count against the quota. + */ + rateLimitRequestsReset: number; + /** + * Last known X-RateLimit-Tokens-Reset header from OpenAI, converted to a + * number. + * Appears to follow the same format as `rateLimitRequestsReset`. + * + * Requests which fail do not count against the quota as they do not consume + * tokens. + */ + rateLimitTokensReset: number; }; export type KeyUpdate = Omit< @@ -84,6 +106,9 @@ export class KeyPool { lastChecked: 0, promptCount: 0, hash: crypto.createHash("sha256").update(k).digest("hex").slice(0, 8), + rateLimitedAt: 0, + rateLimitRequestsReset: 0, + rateLimitTokensReset: 0, }; this.keys.push(newKey); @@ -113,9 +138,9 @@ export class KeyPool { public get(model: Model) { const needGpt4 = model.startsWith("gpt-4"); - const availableKeys = this.keys - .filter((key) => !key.isDisabled && (!needGpt4 || key.isGpt4)) - .sort((a, b) => a.lastUsed - b.lastUsed); + const availableKeys = this.keys.filter( + (key) => !key.isDisabled && (!needGpt4 || key.isGpt4) + ); if (availableKeys.length === 0) { let message = "No keys available. Please add more keys."; if (needGpt4) { @@ -125,26 +150,52 @@ export class KeyPool { throw new Error(message); } - // Prioritize trial keys - const trialKeys = availableKeys.filter((key) => key.isTrial); - if (trialKeys.length > 0) { - trialKeys[0].lastUsed = Date.now(); - return trialKeys[0]; - } + // Select a key, from highest priority to lowest priority: + // 1. Keys which are not rate limited + // a. We can assume any rate limits over a minute ago are expired + // b. If all keys were rate limited in the last minute, select the + // least recently rate limited key + // 2. Keys which are trials + // 3. 
Keys which have not been used in the longest time - // Otherwise, return the oldest key - const oldestKey = availableKeys[0]; - oldestKey.lastUsed = Date.now(); - return { ...oldestKey }; + const now = Date.now(); + const rateLimitThreshold = 60 * 1000; + + const keysByPriority = availableKeys.sort((a, b) => { + const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold; + const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold; + + if (aRateLimited && !bRateLimited) return 1; + if (!aRateLimited && bRateLimited) return -1; + if (aRateLimited && bRateLimited) { + return a.rateLimitedAt - b.rateLimitedAt; + } + + if (a.isTrial && !b.isTrial) return -1; + if (!a.isTrial && b.isTrial) return 1; + + return a.lastUsed - b.lastUsed; + }); + + const selectedKey = keysByPriority[0]; + selectedKey.lastUsed = Date.now(); + + // When a key is selected, we rate-limit it for a brief period of time to + // prevent the queue processor from immediately flooding it with requests + // while the initial request is still being processed (which is when we will + // get new rate limit headers). + // Instead, we will let a request through every second until the key + // becomes fully saturated and locked out again. + selectedKey.rateLimitedAt = Date.now(); + selectedKey.rateLimitRequestsReset = 1000; + return { ...selectedKey }; } /** Called by the key checker to update key information. */ public update(keyHash: string, update: KeyUpdate) { const keyFromPool = this.keys.find((k) => k.hash === keyHash)!; Object.assign(keyFromPool, { ...update, lastChecked: Date.now() }); - // if (update.usage && keyFromPool.usage >= keyFromPool.hardLimit) { - // this.disable(keyFromPool); - // } + // this.writeKeyStatus(); } public disable(key: Key) { @@ -165,22 +216,104 @@ export class KeyPool { return config.checkKeys && this.keys.some((key) => !key.lastChecked); } + /** + * Given a model, returns the period until a key will be available to service + * the request, or returns 0 if a key is ready immediately. + */ + public getLockoutPeriod(model: Model = "gpt-4"): number { + const needGpt4 = model.startsWith("gpt-4"); + const activeKeys = this.keys.filter( + (key) => !key.isDisabled && (!needGpt4 || key.isGpt4) + ); + + if (activeKeys.length === 0) { + // If there are no active keys for this model we can't fulfill requests. + // We'll return 0 to let the request through and return an error, + // otherwise the request will be stuck in the queue forever. + return 0; + } + + // A key is rate-limited if its `rateLimitedAt` plus the greater of its + // `rateLimitRequestsReset` and `rateLimitTokensReset` is after the + // current time. + + // If there are any keys that are not rate-limited, we can fulfill requests. + const now = Date.now(); + const rateLimitedKeys = activeKeys.filter((key) => { + const resetTime = Math.max( + key.rateLimitRequestsReset, + key.rateLimitTokensReset + ); + return now < key.rateLimitedAt + resetTime; + }).length; + const anyNotRateLimited = rateLimitedKeys < activeKeys.length; + + if (anyNotRateLimited) { + return 0; + } + + // If all keys are rate-limited, return the time until the first key is + // ready. 
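+    // Equivalently: a key's lockout ends at
+    //   rateLimitedAt + max(rateLimitRequestsReset, rateLimitTokensReset),
+    // and the pool is usable again once the smallest of those instants passes.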
+ const timeUntilFirstReady = Math.min( + ...activeKeys.map((key) => { + const resetTime = Math.max( + key.rateLimitRequestsReset, + key.rateLimitTokensReset + ); + return key.rateLimitedAt + resetTime - now; + }) + ); + return timeUntilFirstReady; + } + + public markRateLimited(keyHash: string) { + this.log.warn({ key: keyHash }, "Key rate limited"); + const key = this.keys.find((k) => k.hash === keyHash)!; + key.rateLimitedAt = Date.now(); + } + public incrementPrompt(keyHash?: string) { if (!keyHash) return; const key = this.keys.find((k) => k.hash === keyHash)!; key.promptCount++; } - public downgradeKey(keyHash?: string) { - if (!keyHash) return; - this.log.warn({ key: keyHash }, "Downgrading key to GPT-3.5."); + public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) { const key = this.keys.find((k) => k.hash === keyHash)!; - key.isGpt4 = false; + const requestsReset = headers["x-ratelimit-reset-requests"]; + const tokensReset = headers["x-ratelimit-reset-tokens"]; + + // Sometimes OpenAI only sends one of the two rate limit headers, it's + // unclear why. + + if (requestsReset && typeof requestsReset === "string") { + this.log.info( + { key: key.hash, requestsReset }, + `Updating rate limit requests reset time` + ); + key.rateLimitRequestsReset = getResetDurationMillis(requestsReset); + } + + if (tokensReset && typeof tokensReset === "string") { + this.log.info( + { key: key.hash, tokensReset }, + `Updating rate limit tokens reset time` + ); + key.rateLimitTokensReset = getResetDurationMillis(tokensReset); + } + + if (!requestsReset && !tokensReset) { + this.log.warn( + { key: key.hash }, + `No rate limit headers in OpenAI response; skipping update` + ); + return; + } } /** Returns the remaining aggregate quota for all keys as a percentage. */ - public remainingQuota(gpt4Only = false) { - const keys = this.keys.filter((k) => !gpt4Only || k.isGpt4); + public remainingQuota(gpt4 = false) { + const keys = this.keys.filter((k) => k.isGpt4 === gpt4); if (keys.length === 0) return 0; const totalUsage = keys.reduce((acc, key) => { @@ -191,4 +324,55 @@ export class KeyPool { return 1 - totalUsage / totalLimit; } + + /** Returns used and available usage in USD. */ + public usageInUsd(gpt4 = false) { + const keys = this.keys.filter((k) => k.isGpt4 === gpt4); + if (keys.length === 0) return "???"; + + const totalHardLimit = keys.reduce( + (acc, { hardLimit }) => acc + hardLimit, + 0 + ); + const totalUsage = keys.reduce((acc, key) => { + // Keys can slightly exceed their quota + return acc + Math.min(key.usage, key.hardLimit); + }, 0); + + return `$${totalUsage.toFixed(2)} / $${totalHardLimit.toFixed(2)}`; + } + + /** Writes key status to disk. */ + // public writeKeyStatus() { + // const keys = this.keys.map((key) => ({ + // key: key.key, + // isGpt4: key.isGpt4, + // usage: key.usage, + // hardLimit: key.hardLimit, + // isDisabled: key.isDisabled, + // })); + // fs.writeFileSync( + // path.join(__dirname, "..", "keys.json"), + // JSON.stringify(keys, null, 2) + // ); + // } +} + + + +/** + * Converts reset string ("21.0032s" or "21ms") to a number of milliseconds. + * Result is clamped to 10s even though the API returns up to 60s, because the + * API returns the time until the entire quota is reset, even if a key may be + * able to fulfill requests before then due to partial resets. 
+ **/ +function getResetDurationMillis(resetDuration?: string): number { + const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/); + if (match) { + const [, time, , unit] = match; + const value = parseFloat(time); + const result = unit === "s" ? value * 1000 : value; + return Math.min(result, 10000); + } + return 0; } diff --git a/src/proxy/middleware/request/check-streaming.ts b/src/proxy/middleware/request/check-streaming.ts index 858d4f8..9d30bb9 100644 --- a/src/proxy/middleware/request/check-streaming.ts +++ b/src/proxy/middleware/request/check-streaming.ts @@ -1,4 +1,3 @@ -import { config } from "../../../config"; import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; /** @@ -19,7 +18,7 @@ export const checkStreaming: ExpressHttpProxyReqCallback = (_proxyReq, req) => { req.body.stream = false; return; } - req.body.stream = config.allowStreaming; - req.isStreaming = config.allowStreaming; + req.body.stream = true; + req.isStreaming = true; } }; diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index c08c335..290b5d9 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -40,10 +40,17 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( { api: req.api, key: req.key?.hash }, `Starting to proxy SSE stream.` ); - res.setHeader("Content-Type", "text/event-stream"); - res.setHeader("Cache-Control", "no-cache"); - res.setHeader("Connection", "keep-alive"); - copyHeaders(proxyRes, res); + + // Queued streaming requests will already have a connection open and headers + // sent due to the heartbeat handler. In that case we can just start + // streaming the response without sending headers. + if (!res.headersSent) { + res.setHeader("Content-Type", "text/event-stream"); + res.setHeader("Cache-Control", "no-cache"); + res.setHeader("Connection", "keep-alive"); + copyHeaders(proxyRes, res); + res.flushHeaders(); + } const chunks: Buffer[] = []; proxyRes.on("data", (chunk) => { @@ -65,6 +72,24 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( { error: err, api: req.api, key: req.key?.hash }, `Error while streaming response.` ); + // OAI's spec doesn't allow for error events and clients wouldn't know + // what to do with them anyway, so we'll just send a completion event + // with the error message. 
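+      // Standard SSE framing: a `data: <json>` line followed by a blank line,
+      // then the `data: [DONE]` sentinel that OpenAI streaming clients already
+      // treat as end-of-stream.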
+ const fakeErrorEvent = { + id: "chatcmpl-error", + object: "chat.completion.chunk", + created: Date.now(), + model: "", + choices: [ + { + delta: { content: "[Proxy streaming error: " + err.message + "]" }, + index: 0, + finish_reason: "error", + }, + ], + }; + res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`); + res.write("data: [DONE]\n\n"); res.end(); reject(err); }); diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index 08c4115..eb67a5e 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -3,10 +3,12 @@ import * as http from "http"; import util from "util"; import zlib from "zlib"; import * as httpProxy from "http-proxy"; +import { config } from "../../../config"; import { logger } from "../../../logger"; import { keyPool } from "../../../key-management"; -import { logPrompt } from "./log-prompt"; +import { buildFakeSseMessage, enqueue, trackWaitTime } from "../../queue"; import { handleStreamedResponse } from "./handle-streamed-response"; +import { logPrompt } from "./log-prompt"; export const QUOTA_ROUTES = ["/v1/chat/completions"]; const DECODER_MAP = { @@ -21,6 +23,13 @@ const isSupportedContentEncoding = ( return contentEncoding in DECODER_MAP; }; +class RetryableError extends Error { + constructor(message: string) { + super(message); + this.name = "RetryableError"; + } +} + /** * Either decodes or streams the entire response body and then passes it as the * last argument to the rest of the middleware stack. @@ -54,7 +63,7 @@ export type ProxyResMiddleware = ProxyResHandlerWithBody[]; * the client. Once the stream is closed, the finalized body will be attached * to res.body and the remaining middleware will execute. */ -export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => { +export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => { return async ( proxyRes: http.IncomingMessage, req: Request, @@ -66,43 +75,23 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => { let lastMiddlewareName = initialHandler.name; - req.log.debug( - { - api: req.api, - route: req.path, - method: req.method, - stream: req.isStreaming, - middleware: lastMiddlewareName, - }, - "Handling proxy response" - ); - try { const body = await initialHandler(proxyRes, req, res); const middlewareStack: ProxyResMiddleware = []; if (req.isStreaming) { - // Anything that touches the response will break streaming requests so - // certain middleware can't be used. This includes whatever API-specific - // middleware is passed in, which isn't ideal but it's what we've got - // for now. - // Streamed requests will be treated as non-streaming if the upstream - // service returns a non-200 status code, so no need to include the - // error handler here. - - // This is a little too easy to accidentally screw up so I need to add a - // better way to differentiate between middleware that can be used for - // streaming requests and those that can't. Probably a separate type - // or function signature for streaming-compatible middleware. - middlewareStack.push(incrementKeyUsage, logPrompt); + // `handleStreamedResponse` writes to the response and ends it, so + // we can only execute middleware that doesn't write to the response. 
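+        // (None of these write to the response; they only update key-pool
+        // state or log the prompt after the stream has been handled.)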
+ middlewareStack.push(trackRateLimit, incrementKeyUsage, logPrompt); } else { middlewareStack.push( + trackRateLimit, handleUpstreamErrors, incrementKeyUsage, copyHttpHeaders, logPrompt, - ...middleware + ...apiMiddleware ); } @@ -110,25 +99,31 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => { lastMiddlewareName = middleware.name; await middleware(proxyRes, req, res, body); } + + trackWaitTime(req); } catch (error: any) { - if (res.headersSent) { - req.log.error( - `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})` - ); - // Either the upstream error handler got to it first, or we're mid- - // stream and we can't do anything about it. + // Hack: if the error is a retryable rate-limit error, the request has + // been re-enqueued and we can just return without doing anything else. + if (error instanceof RetryableError) { return; } + const errorData = { + error: error.stack, + thrownBy: lastMiddlewareName, + key: req.key?.hash, + }; const message = `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`; - logger.error( - { - error: error.stack, - thrownBy: lastMiddlewareName, - key: req.key?.hash, - }, - message - ); + if (res.headersSent) { + req.log.error(errorData, message); + // This should have already been handled by the error handler, but + // just in case... + if (!res.writableEnded) { + res.end(); + } + return; + } + logger.error(errorData, message); res .status(500) .json({ error: "Internal server error", proxy_note: message }); @@ -136,6 +131,15 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => { }; }; +function reenqueueRequest(req: Request) { + req.log.info( + { key: req.key?.hash, retryCount: req.retryCount }, + `Re-enqueueing request due to rate-limit error` + ); + req.retryCount++; + enqueue(req); +} + /** * Handles the response from the upstream service and decodes the body if * necessary. If the response is JSON, it will be parsed and returned as an @@ -160,8 +164,8 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( proxyRes.on("data", (chunk) => chunks.push(chunk)); proxyRes.on("end", async () => { let body = Buffer.concat(chunks); - const contentEncoding = proxyRes.headers["content-encoding"]; + const contentEncoding = proxyRes.headers["content-encoding"]; if (contentEncoding) { if (isSupportedContentEncoding(contentEncoding)) { const decoder = DECODER_MAP[contentEncoding]; @@ -169,7 +173,10 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( } else { const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`; logger.warn({ contentEncoding, key: req.key?.hash }, errorMessage); - res.status(500).json({ error: errorMessage, contentEncoding }); + writeErrorResponse(res, 500, { + error: errorMessage, + contentEncoding, + }); return reject(errorMessage); } } @@ -183,7 +190,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( } catch (error: any) { const errorMessage = `Proxy received response with invalid JSON: ${error.message}`; logger.warn({ error, key: req.key?.hash }, errorMessage); - res.status(500).json({ error: errorMessage }); + writeErrorResponse(res, 500, { error: errorMessage }); return reject(errorMessage); } }); @@ -197,7 +204,9 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( * Handles non-2xx responses from the upstream service. 
If the proxied response * is an error, this will respond to the client with an error payload and throw * an error to stop the middleware stack. - * @throws {Error} HTTP error status code from upstream service + * On 429 errors, if request queueing is enabled, the request will be silently + * re-enqueued. Otherwise, the request will be rejected with an error payload. + * @throws {Error} On HTTP error status code from upstream service */ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( proxyRes, @@ -206,6 +215,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( body ) => { const statusCode = proxyRes.statusCode || 500; + if (statusCode < 400) { return; } @@ -222,7 +232,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( if (typeof body === "object") { errorPayload = body; } else { - throw new Error("Received non-JSON error response from upstream."); + throw new Error("Received unparsable error response from upstream."); } } catch (parseError: any) { const statusMessage = proxyRes.statusMessage || "Unknown error"; @@ -238,8 +248,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( error: parseError.message, proxy_note: `This is likely a temporary error with the upstream service.`, }; - - res.status(statusCode).json(errorObject); + writeErrorResponse(res, statusCode, errorObject); throw new Error(parseError.message); } @@ -261,30 +270,35 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( keyPool.disable(req.key!); errorPayload.proxy_note = `The OpenAI key is invalid or revoked. ${tryAgainMessage}`; } else if (statusCode === 429) { - // One of: - // - Quota exceeded (key is dead, disable it) - // - Rate limit exceeded (key is fine, just try again) - // - Model overloaded (their fault, just try again) - if (errorPayload.error?.type === "insufficient_quota") { + const type = errorPayload.error?.type; + if (type === "insufficient_quota") { + // Billing quota exceeded (key is dead, disable it) keyPool.disable(req.key!); errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`; - } else if (errorPayload.error?.type === "billing_not_active") { + } else if (type === "billing_not_active") { + // Billing is not active (key is dead, disable it) keyPool.disable(req.key!); errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`; + } else if (type === "requests" || type === "tokens") { + // Per-minute request or token rate limit is exceeded, which we can retry + keyPool.markRateLimited(req.key!.hash); + if (config.queueMode !== "none") { + reenqueueRequest(req); + // TODO: I don't like using an error to control flow here + throw new RetryableError("Rate-limited request re-enqueued."); + } + errorPayload.proxy_note = `Assigned key's '${type}' rate limit has been exceeded. Try again later.`; } else { + // OpenAI probably overloaded errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`; } } else if (statusCode === 404) { // Most likely model not found + // TODO: this probably doesn't handle GPT-4-32k variants properly if the + // proxy has keys for both the 8k and 32k context models at the same time. if (errorPayload.error?.code === "model_not_found") { if (req.key!.isGpt4) { - // Malicious users can request a model that `startsWith` gpt-4 but is - // not actually a valid model name and force the key to be downgraded. - // I don't feel like fixing this so I'm just going to disable the key - // downgrading feature for now. 
- // keyPool.downgradeKey(req.key?.hash); - // errorPayload.proxy_note = `This key was incorrectly assigned to GPT-4. It has been downgraded to Turbo.`; - errorPayload.proxy_note = `This key was incorrectly flagged as GPT-4, or you requested a GPT-4 snapshot for which this key is not authorized. Try again to get a different key, or use Turbo.`; + errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`; } else { errorPayload.proxy_note = `No model was found for this key.`; } @@ -301,10 +315,33 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( ); } - res.status(statusCode).json(errorPayload); + writeErrorResponse(res, statusCode, errorPayload); throw new Error(errorPayload.error?.message); }; +function writeErrorResponse( + res: Response, + statusCode: number, + errorPayload: Record +) { + // If we're mid-SSE stream, send a data event with the error payload and end + // the stream. Otherwise just send a normal error response. + if ( + res.headersSent || + res.getHeader("content-type") === "text/event-stream" + ) { + const msg = buildFakeSseMessage( + `upstream error (${statusCode})`, + JSON.stringify(errorPayload, null, 2) + ); + res.write(msg); + res.write(`data: [DONE]\n\n`); + res.end(); + } else { + res.status(statusCode).json(errorPayload); + } +} + /** Handles errors in rewriter pipelines. */ export const handleInternalError: httpProxy.ErrorCallback = ( err, @@ -313,19 +350,14 @@ export const handleInternalError: httpProxy.ErrorCallback = ( ) => { logger.error({ error: err }, "Error in http-proxy-middleware pipeline."); try { - if ("setHeader" in res && !res.headersSent) { - res.writeHead(500, { "Content-Type": "application/json" }); - } - res.end( - JSON.stringify({ - error: { - type: "proxy_error", - message: err.message, - stack: err.stack, - proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`, - }, - }) - ); + writeErrorResponse(res as Response, 500, { + error: { + type: "proxy_error", + message: err.message, + stack: err.stack, + proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`, + }, + }); } catch (e) { logger.error( { error: e }, @@ -340,6 +372,10 @@ const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => { } }; +const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => { + keyPool.updateRateLimits(req.key!.hash, proxyRes.headers); +}; + const copyHttpHeaders: ProxyResHandlerWithBody = async ( proxyRes, _req, diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 660b36c..780a56f 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -3,6 +3,7 @@ import * as http from "http"; import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; import { logger } from "../logger"; +import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; import { addKey, @@ -72,6 +73,7 @@ const openaiProxy = createProxyMiddleware({ selfHandleResponse: true, logger, }); +const queuedOpenaiProxy = createQueueMiddleware(openaiProxy); const openaiRouter = Router(); // Some clients don't include the /v1/ prefix in their requests and users get @@ -84,7 +86,7 @@ openaiRouter.use((req, _res, next) => { next(); }); openaiRouter.get("/v1/models", openaiProxy); -openaiRouter.post("/v1/chat/completions", ipLimiter, openaiProxy); +openaiRouter.post("/v1/chat/completions", ipLimiter, queuedOpenaiProxy); // If a browser 
tries to visit a route that doesn't exist, redirect to the info // page to help them find the right URL. openaiRouter.get("*", (req, res, next) => { diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts new file mode 100644 index 0000000..bdee281 --- /dev/null +++ b/src/proxy/queue.ts @@ -0,0 +1,367 @@ +/** + * Very scuffed request queue. OpenAI's GPT-4 keys have a very strict rate limit + * of 40000 generated tokens per minute. We don't actually know how many tokens + * a given key has generated, so our queue will simply retry requests that fail + * with a non-billing related 429 over and over again until they succeed. + * + * Dequeueing can operate in one of two modes: + * - 'fair': requests are dequeued in the order they were enqueued. + * - 'random': requests are dequeued randomly, not really a queue at all. + * + * When a request to a proxied endpoint is received, we create a closure around + * the call to http-proxy-middleware and attach it to the request. This allows + * us to pause the request until we have a key available. Further, if the + * proxied request encounters a retryable error, we can simply put the request + * back in the queue and it will be retried later using the same closure. + */ + +import type { Handler, Request } from "express"; +import { config, DequeueMode } from "../config"; +import { keyPool } from "../key-management"; +import { logger } from "../logger"; +import { AGNAI_DOT_CHAT_IP } from "./rate-limit"; + +const queue: Request[] = []; +const log = logger.child({ module: "request-queue" }); + +let dequeueMode: DequeueMode = "fair"; + +/** Maximum number of queue slots for Agnai.chat requests. */ +const AGNAI_CONCURRENCY_LIMIT = 15; +/** Maximum number of queue slots for individual users. */ +const USER_CONCURRENCY_LIMIT = 1; + +export function enqueue(req: Request) { + // All agnai.chat requests come from the same IP, so we allow them to have + // more spots in the queue. Can't make it unlimited because people will + // intentionally abuse it. + const maxConcurrentQueuedRequests = + req.ip === AGNAI_DOT_CHAT_IP + ? AGNAI_CONCURRENCY_LIMIT + : USER_CONCURRENCY_LIMIT; + const reqCount = queue.filter((r) => r.ip === req.ip).length; + if (reqCount >= maxConcurrentQueuedRequests) { + if (req.ip === AGNAI_DOT_CHAT_IP) { + // Re-enqueued requests are not counted towards the limit since they + // already made it through the queue once. + if (req.retryCount === 0) { + throw new Error("Too many agnai.chat requests are already queued"); + } + } else { + throw new Error("Request is already queued for this IP"); + } + } + + queue.push(req); + req.queueOutTime = 0; + + // shitty hack to remove hpm's event listeners on retried requests + removeProxyMiddlewareEventListeners(req); + + // If the request opted into streaming, we need to register a heartbeat + // handler to keep the connection alive while it waits in the queue. We + // deregister the handler when the request is dequeued. 
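+  // In production the heartbeat is an SSE comment line (leading colon), which
+  // clients must ignore per the SSE spec; it only exists to keep intermediate
+  // proxies from closing the idle connection. In development a fake completion
+  // chunk with queue stats is sent instead so the wait is visible.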
+ if (req.body.stream) { + const res = req.res!; + if (!res.headersSent) { + initStreaming(req); + } + req.heartbeatInterval = setInterval(() => { + if (process.env.NODE_ENV === "production") { + req.res!.write(": queue heartbeat\n\n"); + } else { + req.log.info(`Sending heartbeat to request in queue.`); + const avgWait = Math.round(getAverageWaitTime() / 1000); + const currentDuration = Math.round((Date.now() - req.startTime) / 1000); + const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`; + req.res!.write(buildFakeSseMessage("heartbeat", debugMsg)); + } + }, 10000); + } + + // Register a handler to remove the request from the queue if the connection + // is aborted or closed before it is dequeued. + const removeFromQueue = () => { + req.log.info(`Removing aborted request from queue.`); + const index = queue.indexOf(req); + if (index !== -1) { + queue.splice(index, 1); + } + if (req.heartbeatInterval) { + clearInterval(req.heartbeatInterval); + } + }; + req.onAborted = removeFromQueue; + req.res!.once("close", removeFromQueue); + + if (req.retryCount ?? 0 > 0) { + req.log.info({ retries: req.retryCount }, `Enqueued request for retry.`); + } else { + req.log.info(`Enqueued new request.`); + } +} + +export function dequeue(model: string): Request | undefined { + // TODO: This should be set by some middleware that checks the request body. + const modelQueue = + model === "gpt-4" + ? queue.filter((req) => req.body.model?.startsWith("gpt-4")) + : queue.filter((req) => !req.body.model?.startsWith("gpt-4")); + + if (modelQueue.length === 0) { + return undefined; + } + + let req: Request; + + if (dequeueMode === "fair") { + // Dequeue the request that has been waiting the longest + req = modelQueue.reduce((prev, curr) => + prev.startTime < curr.startTime ? prev : curr + ); + } else { + // Dequeue a random request + const index = Math.floor(Math.random() * modelQueue.length); + req = modelQueue[index]; + } + queue.splice(queue.indexOf(req), 1); + + if (req.onAborted) { + req.res!.off("close", req.onAborted); + req.onAborted = undefined; + } + + if (req.heartbeatInterval) { + clearInterval(req.heartbeatInterval); + } + + // Track the time leaving the queue now, but don't add it to the wait times + // yet because we don't know if the request will succeed or fail. We track + // the time now and not after the request succeeds because we don't want to + // include the model processing time. + req.queueOutTime = Date.now(); + return req; +} + +/** + * Naive way to keep the queue moving by continuously dequeuing requests. Not + * ideal because it limits throughput but we probably won't have enough traffic + * or keys for this to be a problem. If it does we can dequeue multiple + * per tick. + **/ +function processQueue() { + // This isn't completely correct, because a key can service multiple models. + // Currently if a key is locked out on one model it will also stop servicing + // the others, because we only track one rate limit per key. 
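+  // In practice this means a 429 on a gpt-4 request also pauses that key for
+  // turbo traffic, since getLockoutPeriod() reads the same rateLimitedAt and
+  // reset fields regardless of which model tripped the limit.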
+ const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4"); + const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo"); + + const reqs: (Request | undefined)[] = []; + if (gpt4Lockout === 0) { + reqs.push(dequeue("gpt-4")); + } + if (turboLockout === 0) { + reqs.push(dequeue("gpt-3.5-turbo")); + } + + reqs.filter(Boolean).forEach((req) => { + if (req?.proceed) { + req.log.info({ retries: req.retryCount }, `Dequeuing request.`); + req.proceed(); + } + }); + setTimeout(processQueue, 50); +} + +/** + * Kill stalled requests after 5 minutes, and remove tracked wait times after 2 + * minutes. + **/ +function cleanQueue() { + const now = Date.now(); + const oldRequests = queue.filter( + (req) => now - (req.startTime ?? now) > 5 * 60 * 1000 + ); + oldRequests.forEach((req) => { + req.log.info(`Removing request from queue after 5 minutes.`); + killQueuedRequest(req); + }); + + const index = waitTimes.findIndex( + (waitTime) => now - waitTime.end > 120 * 1000 + ); + const removed = waitTimes.splice(0, index + 1); + log.info( + { stalledRequests: oldRequests.length, prunedWaitTimes: removed.length }, + `Cleaning up request queue.` + ); + setTimeout(cleanQueue, 20 * 1000); +} + +export function start() { + processQueue(); + cleanQueue(); + log.info(`Started request queue.`); +} + +const waitTimes: { start: number; end: number }[] = []; + +/** Adds a successful request to the list of wait times. */ +export function trackWaitTime(req: Request) { + waitTimes.push({ + start: req.startTime!, + end: req.queueOutTime ?? Date.now(), + }); +} + +/** Returns average wait time in milliseconds. */ +export function getAverageWaitTime() { + if (waitTimes.length === 0) { + return 0; + } + + // Include requests that are still in the queue right now + const now = Date.now(); + const waitTimesWithCurrent = [ + ...waitTimes, + ...queue.map((req) => ({ + start: req.startTime!, + end: now, + })), + ]; + + return ( + waitTimesWithCurrent.reduce((acc, curr) => acc + curr.end - curr.start, 0) / + waitTimesWithCurrent.length + ); +} + +export function getQueueLength() { + return queue.length; +} + +export function createQueueMiddleware(proxyMiddleware: Handler): Handler { + return (req, res, next) => { + if (config.queueMode === "none") { + return proxyMiddleware(req, res, next); + } + + req.proceed = () => { + proxyMiddleware(req, res, next); + }; + + try { + enqueue(req); + } catch (err: any) { + req.res!.status(429).json({ + type: "proxy_error", + message: err.message, + stack: err.stack, + proxy_note: `Only one request per IP can be queued at a time. If you don't have another request queued, your IP may be in use by another user.`, + }); + } + }; +} + +function killQueuedRequest(req: Request) { + if (!req.res || req.res.writableEnded) { + req.log.warn(`Attempted to terminate request that has already ended.`); + return; + } + const res = req.res; + try { + const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. 
The queue is currently ${queue.length} requests long.`; + if (res.headersSent) { + const fakeErrorEvent = buildFakeSseMessage("proxy queue error", message); + res.write(fakeErrorEvent); + res.end(); + } else { + res.status(500).json({ error: message }); + } + } catch (e) { + req.log.error(e, `Error killing stalled request.`); + } +} + +function initStreaming(req: Request) { + req.log.info(`Initiating streaming for new queued request.`); + const res = req.res!; + res.statusCode = 200; + res.setHeader("Content-Type", "text/event-stream"); + res.setHeader("Cache-Control", "no-cache"); + res.setHeader("Connection", "keep-alive"); + res.flushHeaders(); + res.write("\n"); + res.write(": joining queue\n\n"); +} + +export function buildFakeSseMessage(type: string, string: string) { + const fakeEvent = { + id: "chatcmpl-" + type, + object: "chat.completion.chunk", + created: Date.now(), + model: "", + choices: [ + { + delta: { content: `[${type}: ${string}]\n` }, + index: 0, + finish_reason: type, + }, + ], + }; + return `data: ${JSON.stringify(fakeEvent)}\n\n`; +} + +/** + * http-proxy-middleware attaches a bunch of event listeners to the req and + * res objects which causes problems with our approach to re-enqueuing failed + * proxied requests. This function removes those event listeners. + * We don't have references to the original event listeners, so we have to + * look through the list and remove HPM's listeners by looking for particular + * strings in the listener functions. This is an astoundingly shitty way to do + * this, but it's the best I can come up with. + */ +function removeProxyMiddlewareEventListeners(req: Request) { + // node_modules/http-proxy-middleware/dist/plugins/default/debug-proxy-errors-plugin.js:29 + // res.listeners('close') + const RES_ONCLOSE = `Destroying proxyRes in proxyRes close event`; + // node_modules/http-proxy-middleware/dist/plugins/default/debug-proxy-errors-plugin.js:19 + // res.listeners('error') + const RES_ONERROR = `Socket error in proxyReq event`; + // node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js:146 + // req.listeners('aborted') + const REQ_ONABORTED = `proxyReq.abort()`; + // node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js:156 + // req.listeners('error') + const REQ_ONERROR = `if (req.socket.destroyed`; + + const res = req.res!; + + const resOnClose = res + .listeners("close") + .find((listener) => listener.toString().includes(RES_ONCLOSE)); + if (resOnClose) { + res.removeListener("close", resOnClose as any); + } + + const resOnError = res + .listeners("error") + .find((listener) => listener.toString().includes(RES_ONERROR)); + if (resOnError) { + res.removeListener("error", resOnError as any); + } + + const reqOnAborted = req + .listeners("aborted") + .find((listener) => listener.toString().includes(REQ_ONABORTED)); + if (reqOnAborted) { + req.removeListener("aborted", reqOnAborted as any); + } + + const reqOnError = req + .listeners("error") + .find((listener) => listener.toString().includes(REQ_ONERROR)); + if (reqOnError) { + req.removeListener("error", reqOnError as any); + } +} diff --git a/src/proxy/rate-limit.ts b/src/proxy/rate-limit.ts index fe92498..29dba3d 100644 --- a/src/proxy/rate-limit.ts +++ b/src/proxy/rate-limit.ts @@ -2,6 +2,7 @@ import { Request, Response, NextFunction } from "express"; import { config } from "../config"; import { logger } from "../logger"; +export const AGNAI_DOT_CHAT_IP = "157.230.249.32"; const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit); const RATE_LIMIT = Math.max(1, 
config.modelRateLimit); const ONE_MINUTE_MS = 60 * 1000; @@ -66,7 +67,7 @@ export const ipLimiter = (req: Request, res: Response, next: NextFunction) => { // Exempt Agnai.chat from rate limiting since it's shared between a lot of // users. Dunno how to prevent this from being abused without some sort of // identifier sent from Agnaistic to identify specific users. - if (req.ip === "157.230.249.32") { + if (req.ip === AGNAI_DOT_CHAT_IP) { next(); return; } diff --git a/src/server.ts b/src/server.ts index 135142e..86f8d8b 100644 --- a/src/server.ts +++ b/src/server.ts @@ -9,6 +9,7 @@ import { keyPool } from "./key-management"; import { proxyRouter, rewriteTavernRequests } from "./proxy/routes"; import { handleInfoPage } from "./info-page"; import { logQueue } from "./prompt-logging"; +import { start as startRequestQueue } from "./proxy/queue"; const PORT = config.port; @@ -17,6 +18,7 @@ const app = express(); app.use("/", rewriteTavernRequests); app.use( pinoHttp({ + quietReqLogger: true, logger, // SillyTavern spams the hell out of this endpoint so don't log it autoLogging: { ignore: (req) => req.url === "/proxy/kobold/api/v1/model" }, @@ -31,6 +33,11 @@ app.use( }, }) ); +app.use((req, _res, next) => { + req.startTime = Date.now(); + req.retryCount = 0; + next(); +}); app.use(cors()); app.use( express.json({ limit: "10mb" }), @@ -40,21 +47,31 @@ app.use( // deploy this somewhere without a load balancer then incoming requests can // spoof the X-Forwarded-For header and bypass the rate limiting. app.set("trust proxy", true); + // routes app.get("/", handleInfoPage); app.use("/proxy", proxyRouter); + // 500 and 404 app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => { if (err.status) { res.status(err.status).json({ error: err.message }); } else { logger.error(err); - res.status(500).json({ error: "Internal server error" }); + res.status(500).json({ + error: { + type: "proxy_error", + message: err.message, + stack: err.stack, + proxy_note: `Reverse proxy encountered an internal server error.`, + }, + }); } }); app.use((_req: unknown, res: express.Response) => { res.status(404).json({ error: "Not found" }); }); + // start server and load keys app.listen(PORT, async () => { try { @@ -108,4 +125,8 @@ app.listen(PORT, async () => { logger.info("Starting prompt logging..."); logQueue.start(); } + if (config.queueMode !== "none") { + logger.info("Starting request queue..."); + startRequestQueue(); + } }); diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts index 76f133d..0471397 100644 --- a/src/types/custom.d.ts +++ b/src/types/custom.d.ts @@ -7,6 +7,12 @@ declare global { key?: Key; api: "kobold" | "openai" | "anthropic"; isStreaming?: boolean; + startTime: number; + retryCount: number; + queueOutTime?: number; + onAborted?: () => void; + proceed: () => void; + heartbeatInterval?: NodeJS.Timeout; } } }
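As a usage illustration only (this client is hypothetical and not part of the patch), a streaming request against the queued endpoint can simply skip the colon-prefixed comment lines (": joining queue", ": queue heartbeat") that the queue writes while the request waits, and otherwise parse the stream as usual. The URL path and payload shape mirror the ones the proxy exposes above; authentication and error handling are omitted.

// Hypothetical client sketch (assumes Node 18+ fetch and web streams).
async function streamViaProxy(proxyHost: string, prompt: string) {
  const res = await fetch(`${proxyHost}/proxy/openai/v1/chat/completions`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      model: "gpt-3.5-turbo",
      stream: true, // required, or the queued request will likely time out
      messages: [{ role: "user", content: prompt }],
    }),
  });
  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const events = buffer.split("\n\n");
    buffer = events.pop()!; // keep any partial event for the next chunk
    for (const rawEvent of events) {
      const event = rawEvent.trim();
      if (!event || event.startsWith(":")) continue; // queue notices, heartbeats
      if (!event.startsWith("data: ")) continue;
      const data = event.slice("data: ".length);
      if (data === "[DONE]") return;
      process.stdout.write(JSON.parse(data).choices[0]?.delta?.content ?? "");
    }
  }
}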