Implements request queueing (khanon/oai-reverse-proxy!6)
This commit is contained in: parent e9e9f1f8b6, commit e03f3d48dd
@@ -9,6 +9,8 @@
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
# REJECT_SAMPLE_RATE=0.2
# CHECK_KEYS=false
# QUOTA_DISPLAY_MODE=full
# QUEUE_MODE=fair

# Note: CHECK_KEYS is disabled by default in local development mode, but enabled
# by default in production mode.
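For illustration, the new settings combine with the existing ones in a `.env` like the sketch below (the values are examples picked for the sketch, not the defaults shipped in this commit):

# QUEUE_MODE accepts fair, random, or none; QUOTA_DISPLAY_MODE accepts none, simple, or full
QUEUE_MODE=random
QUOTA_DISPLAY_MODE=simple
CHECK_KEYS=true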
@@ -1,3 +0,0 @@
<!-- Don't remove this header, Showdown needs it to parse the file correctly -->

# OAI Reverse Proxy
@@ -3,7 +3,8 @@ dotenv.config();

const isDev = process.env.NODE_ENV !== "production";

type PROMPT_LOGGING_BACKEND = "google_sheets";
type PromptLoggingBackend = "google_sheets";
export type DequeueMode = "fair" | "random" | "none";

type Config = {
  /** The port the proxy server will listen on. */
@@ -25,17 +26,29 @@ type Config = {
  /** Pino log level. */
  logLevel?: "debug" | "info" | "warn" | "error";
  /** Whether prompts and responses should be logged to persistent storage. */
  promptLogging?: boolean; // TODO: Implement prompt logging once we have persistent storage.
  promptLogging?: boolean;
  /** Which prompt logging backend to use. */
  promptLoggingBackend?: PROMPT_LOGGING_BACKEND;
  promptLoggingBackend?: PromptLoggingBackend;
  /** Base64-encoded Google Sheets API key. */
  googleSheetsKey?: string;
  /** Google Sheets spreadsheet ID. */
  googleSheetsSpreadsheetId?: string;
  /** Whether to periodically check keys for usage and validity. */
  checkKeys?: boolean;
  /** Whether to allow streaming completions. This is usually fine but can cause issues on some deployments. */
  allowStreaming?: boolean;
  /**
   * How to display quota information on the info page.
   * 'none' - Hide quota information
   * 'simple' - Display quota information as a percentage
   * 'full' - Display quota information as usage against total capacity
   */
  quotaDisplayMode: "none" | "simple" | "full";
  /**
   * Which request queueing strategy to use when keys are over their rate limit.
   * 'fair' - Requests are serviced in the order they were received (default)
   * 'random' - Requests are serviced randomly
   * 'none' - Requests are not queued and users have to retry manually
   */
  queueMode: DequeueMode;
};

// To change configs, create a file called .env in the root directory.
@@ -54,6 +67,7 @@ export const config: Config = {
  ),
  logLevel: getEnvWithDefault("LOG_LEVEL", "info"),
  checkKeys: getEnvWithDefault("CHECK_KEYS", !isDev),
  quotaDisplayMode: getEnvWithDefault("QUOTA_DISPLAY_MODE", "full"),
  promptLogging: getEnvWithDefault("PROMPT_LOGGING", false),
  promptLoggingBackend: getEnvWithDefault("PROMPT_LOGGING_BACKEND", undefined),
  googleSheetsKey: getEnvWithDefault("GOOGLE_SHEETS_KEY", undefined),
@@ -61,7 +75,7 @@ export const config: Config = {
    "GOOGLE_SHEETS_SPREADSHEET_ID",
    undefined
  ),
  allowStreaming: getEnvWithDefault("ALLOW_STREAMING", true),
  queueMode: getEnvWithDefault("QUEUE_MODE", "fair"),
} as const;

export const SENSITIVE_KEYS: (keyof Config)[] = [
@@ -4,60 +4,89 @@ import showdown from "showdown";
import { config, listConfig } from "./config";
import { keyPool } from "./key-management";
import { getUniqueIps } from "./proxy/rate-limit";
import { getAverageWaitTime, getQueueLength } from "./proxy/queue";

const INFO_PAGE_TTL = 5000;
let infoPageHtml: string | undefined;
let infoPageLastUpdated = 0;

export const handleInfoPage = (req: Request, res: Response) => {
  if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
    res.send(infoPageHtml);
    return;
  }

  // Huggingface puts spaces behind some cloudflare ssl proxy, so `req.protocol` is `http` but the correct URL is actually `https`
  const host = req.get("host");
  const isHuggingface = host?.includes("hf.space");
  const protocol = isHuggingface ? "https" : req.protocol;
  res.send(getInfoPageHtml(protocol + "://" + host));
  res.send(cacheInfoPageHtml(protocol + "://" + host));
};

function getInfoPageHtml(host: string) {
function cacheInfoPageHtml(host: string) {
  const keys = keyPool.list();
  let keyInfo: Record<string, any> = {
    all: keys.length,
    active: keys.filter((k) => !k.isDisabled).length,
  };
  let keyInfo: Record<string, any> = { all: keys.length };

  if (keyPool.anyUnchecked()) {
    const uncheckedKeys = keys.filter((k) => !k.lastChecked);
    keyInfo = {
      ...keyInfo,
      active: keys.filter((k) => !k.isDisabled).length,
      status: `Still checking ${uncheckedKeys.length} keys...`,
    };
  } else if (config.checkKeys) {
    const trialKeys = keys.filter((k) => k.isTrial);
    const turboKeys = keys.filter((k) => !k.isGpt4 && !k.isDisabled);
    const gpt4Keys = keys.filter((k) => k.isGpt4 && !k.isDisabled);

    const quota: Record<string, string> = { turbo: "", gpt4: "" };
    const hasGpt4 = keys.some((k) => k.isGpt4);

    if (config.quotaDisplayMode === "full") {
      quota.turbo = `${keyPool.usageInUsd()} (${Math.round(
        keyPool.remainingQuota() * 100
      )}% remaining)`;
      quota.gpt4 = `${keyPool.usageInUsd(true)} (${Math.round(
        keyPool.remainingQuota(true) * 100
      )}% remaining)`;
    } else {
      quota.turbo = `${Math.round(keyPool.remainingQuota() * 100)}%`;
      quota.gpt4 = `${Math.round(keyPool.remainingQuota(true) * 100)}%`;
    }

    if (!hasGpt4) {
      delete quota.gpt4;
    }

    keyInfo = {
      ...keyInfo,
      trial: keys.filter((k) => k.isTrial).length,
      gpt4: keys.filter((k) => k.isGpt4).length,
      quotaLeft: {
        all: `${Math.round(keyPool.remainingQuota() * 100)}%`,
        ...(hasGpt4
          ? { gpt4: `${Math.round(keyPool.remainingQuota(true) * 100)}%` }
          : {}),
      trial: trialKeys.length,
      active: {
        turbo: turboKeys.length,
        ...(hasGpt4 ? { gpt4: gpt4Keys.length } : {}),
      },
      ...(config.quotaDisplayMode !== "none" ? { quota: quota } : {}),
    };
  }

  const info = {
    uptime: process.uptime(),
    timestamp: Date.now(),
    endpoints: {
      kobold: host,
      openai: host + "/proxy/openai",
    },
    proompts: keys.reduce((acc, k) => acc + k.promptCount, 0),
    ...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
    keyInfo,
    ...getQueueInformation(),
    keys: keyInfo,
    config: listConfig(),
    commitSha: process.env.COMMIT_SHA || "dev",
  };

  const title = process.env.SPACE_ID
    ? `${process.env.SPACE_AUTHOR_NAME} / ${process.env.SPACE_TITLE}`
    : "OAI Reverse Proxy";
  const headerHtml = buildInfoPageHeader(new showdown.Converter(), title);

  const pageBody = `<!DOCTYPE html>
<html lang="en">
@@ -66,29 +95,30 @@ function getInfoPageHtml(host: string) {
    <title>${title}</title>
  </head>
  <body style="font-family: sans-serif; background-color: #f0f0f0; padding: 1em;"
    ${infoPageHeaderHtml}
    ${headerHtml}
    <hr />
    <h2>Service Info</h2>
    <pre>${JSON.stringify(info, null, 2)}</pre>
  </body>
</html>`;

  infoPageHtml = pageBody;
  infoPageLastUpdated = Date.now();

  return pageBody;
}

const infoPageHeaderHtml = buildInfoPageHeader(new showdown.Converter());

/**
 * If the server operator provides a `greeting.md` file, it will be included in
 * the rendered info page.
 **/
function buildInfoPageHeader(converter: showdown.Converter) {
  const genericInfoPage = fs.readFileSync("info-page.md", "utf8");
function buildInfoPageHeader(converter: showdown.Converter, title: string) {
  const customGreeting = fs.existsSync("greeting.md")
    ? fs.readFileSync("greeting.md", "utf8")
    : null;

  let infoBody = genericInfoPage;
  let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
# ${title}`;
  if (config.promptLogging) {
    infoBody += `\n## Prompt logging is enabled!
The server operator has enabled prompt logging. The prompts you send to this proxy and the AI responses you receive may be saved.
@@ -97,9 +127,32 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t

**If you are uncomfortable with this, don't send prompts to this proxy!**`;
  }

  if (config.queueMode !== "none") {
    infoBody += `\n### Queueing is enabled
Requests are queued to mitigate the effects of OpenAI's rate limits. If the AI is busy, your prompt will be queued and processed when a slot is available.

You can check wait times below. **Be sure to enable streaming in your client, or your request will likely time out.**`;
  }

  if (customGreeting) {
    infoBody += `\n## Server Greeting\n
${customGreeting}`;
  }
  return converter.makeHtml(infoBody);
}

function getQueueInformation() {
  if (config.queueMode === "none") {
    return {};
  }
  const waitMs = getAverageWaitTime();
  const waitTime =
    waitMs < 60000
      ? `${Math.round(waitMs / 1000)} seconds`
      : `${Math.round(waitMs / 60000)} minutes`;
  return {
    proomptersWaiting: getQueueLength(),
    estimatedWaitTime: waitMs > 3000 ? waitTime : "no wait",
  };
}
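As a rough illustration of how the three quotaDisplayMode settings render on the info page (the dollar figures below are invented for the example, not taken from the commit):

// QUOTA_DISPLAY_MODE=full   -> quota.turbo === "$21.60 / $120.00 (82% remaining)"
// QUOTA_DISPLAY_MODE=simple -> quota.turbo === "82%"
// QUOTA_DISPLAY_MODE=none   -> the quota object is omitted from keyInfo entirely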
@@ -191,9 +191,6 @@ export class KeyChecker {
    return data;
  }

  // TODO: This endpoint seems to be very delayed. I think we will need to track
  // the time it last changed and estimate token usage ourselves in between
  // changes by inspecting request payloads for prompt and completion tokens.
  private async getUsage(key: Key) {
    const querystring = KeyChecker.getUsageQuerystring(key.isTrial);
    const url = `${GET_USAGE_URL}?${querystring}`;
@@ -2,28 +2,26 @@
   round-robin access to keys. Keys are stored in the OPENAI_KEY environment
   variable as a comma-separated list of keys. */
import crypto from "crypto";
import fs from "fs";
import http from "http";
import path from "path";
import { config } from "../config";
import { logger } from "../logger";
import { KeyChecker } from "./key-checker";

// TODO: Made too many assumptions about OpenAI being the only provider and now
// this doesn't really work for Anthropic. Create a Provider interface and
// implement Pool, Checker, and Models for each provider.

export type Model = OpenAIModel | AnthropicModel;
export type OpenAIModel =
  | "gpt-3.5-turbo"
  | "gpt-4"
export type AnthropicModel =
  | "claude-v1"
  | "claude-instant-v1"
export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4";
export type AnthropicModel = "claude-v1" | "claude-instant-v1";
export const SUPPORTED_MODELS: readonly Model[] = [
  "gpt-3.5-turbo",
  "gpt-4",
  "claude-v1",
  "claude-instant-v1",
] as const;

export type Key = {
  /** The OpenAI API key itself. */
@@ -50,6 +48,30 @@ export type Key = {
  lastChecked: number;
  /** Key hash for displaying usage in the dashboard. */
  hash: string;
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /**
   * Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
   * number.
   * Formatted as a `\d+(m|s)` string denoting the time until the limit resets.
   * Specifically, it seems to indicate the time until the key's quota will be
   * fully restored; the key may be usable before this time as the limit is a
   * rolling window.
   *
   * Requests which return a 429 do not count against the quota.
   *
   * Requests which fail for other reasons (e.g. 401) count against the quota.
   */
  rateLimitRequestsReset: number;
  /**
   * Last known X-RateLimit-Tokens-Reset header from OpenAI, converted to a
   * number.
   * Appears to follow the same format as `rateLimitRequestsReset`.
   *
   * Requests which fail do not count against the quota as they do not consume
   * tokens.
   */
  rateLimitTokensReset: number;
};

export type KeyUpdate = Omit<
@@ -84,6 +106,9 @@ export class KeyPool {
      lastChecked: 0,
      promptCount: 0,
      hash: crypto.createHash("sha256").update(k).digest("hex").slice(0, 8),
      rateLimitedAt: 0,
      rateLimitRequestsReset: 0,
      rateLimitTokensReset: 0,
    };
    this.keys.push(newKey);

@@ -113,9 +138,9 @@ export class KeyPool {

  public get(model: Model) {
    const needGpt4 = model.startsWith("gpt-4");
    const availableKeys = this.keys
      .filter((key) => !key.isDisabled && (!needGpt4 || key.isGpt4))
      .sort((a, b) => a.lastUsed - b.lastUsed);
    const availableKeys = this.keys.filter(
      (key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
    );
    if (availableKeys.length === 0) {
      let message = "No keys available. Please add more keys.";
      if (needGpt4) {
@@ -125,26 +150,52 @@ export class KeyPool {
      throw new Error(message);
    }

    // Prioritize trial keys
    const trialKeys = availableKeys.filter((key) => key.isTrial);
    if (trialKeys.length > 0) {
      trialKeys[0].lastUsed = Date.now();
      return trialKeys[0];
    }
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. We can assume any rate limits over a minute ago are expired
    //    b. If all keys were rate limited in the last minute, select the
    //       least recently rate limited key
    // 2. Keys which are trials
    // 3. Keys which have not been used in the longest time

    // Otherwise, return the oldest key
    const oldestKey = availableKeys[0];
    oldestKey.lastUsed = Date.now();
    return { ...oldestKey };
    const now = Date.now();
    const rateLimitThreshold = 60 * 1000;

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold;
      const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      if (a.isTrial && !b.isTrial) return -1;
      if (!a.isTrial && b.isTrial) return 1;

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = Date.now();

    // When a key is selected, we rate-limit it for a brief period of time to
    // prevent the queue processor from immediately flooding it with requests
    // while the initial request is still being processed (which is when we will
    // get new rate limit headers).
    // Instead, we will let a request through every second until the key
    // becomes fully saturated and locked out again.
    selectedKey.rateLimitedAt = Date.now();
    selectedKey.rateLimitRequestsReset = 1000;
    return { ...selectedKey };
  }

  /** Called by the key checker to update key information. */
  public update(keyHash: string, update: KeyUpdate) {
    const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
    Object.assign(keyFromPool, { ...update, lastChecked: Date.now() });
    // if (update.usage && keyFromPool.usage >= keyFromPool.hardLimit) {
    //   this.disable(keyFromPool);
    // }
    // this.writeKeyStatus();
  }

  public disable(key: Key) {
@@ -165,22 +216,104 @@ export class KeyPool {
    return config.checkKeys && this.keys.some((key) => !key.lastChecked);
  }

  /**
   * Given a model, returns the period until a key will be available to service
   * the request, or returns 0 if a key is ready immediately.
   */
  public getLockoutPeriod(model: Model = "gpt-4"): number {
    const needGpt4 = model.startsWith("gpt-4");
    const activeKeys = this.keys.filter(
      (key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
    );

    if (activeKeys.length === 0) {
      // If there are no active keys for this model we can't fulfill requests.
      // We'll return 0 to let the request through and return an error,
      // otherwise the request will be stuck in the queue forever.
      return 0;
    }

    // A key is rate-limited if its `rateLimitedAt` plus the greater of its
    // `rateLimitRequestsReset` and `rateLimitTokensReset` is after the
    // current time.

    // If there are any keys that are not rate-limited, we can fulfill requests.
    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((key) => {
      const resetTime = Math.max(
        key.rateLimitRequestsReset,
        key.rateLimitTokensReset
      );
      return now < key.rateLimitedAt + resetTime;
    }).length;
    const anyNotRateLimited = rateLimitedKeys < activeKeys.length;

    if (anyNotRateLimited) {
      return 0;
    }

    // If all keys are rate-limited, return the time until the first key is
    // ready.
    const timeUntilFirstReady = Math.min(
      ...activeKeys.map((key) => {
        const resetTime = Math.max(
          key.rateLimitRequestsReset,
          key.rateLimitTokensReset
        );
        return key.rateLimitedAt + resetTime - now;
      })
    );
    return timeUntilFirstReady;
  }

  public markRateLimited(keyHash: string) {
    this.log.warn({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    key.rateLimitedAt = Date.now();
  }

  public incrementPrompt(keyHash?: string) {
    if (!keyHash) return;
    const key = this.keys.find((k) => k.hash === keyHash)!;
    key.promptCount++;
  }

  public downgradeKey(keyHash?: string) {
    if (!keyHash) return;
    this.log.warn({ key: keyHash }, "Downgrading key to GPT-3.5.");
  public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
    const key = this.keys.find((k) => k.hash === keyHash)!;
    key.isGpt4 = false;
    const requestsReset = headers["x-ratelimit-reset-requests"];
    const tokensReset = headers["x-ratelimit-reset-tokens"];

    // Sometimes OpenAI only sends one of the two rate limit headers, it's
    // unclear why.

    if (requestsReset && typeof requestsReset === "string") {
      this.log.info(
        { key: key.hash, requestsReset },
        `Updating rate limit requests reset time`
      );
      key.rateLimitRequestsReset = getResetDurationMillis(requestsReset);
    }

    if (tokensReset && typeof tokensReset === "string") {
      this.log.info(
        { key: key.hash, tokensReset },
        `Updating rate limit tokens reset time`
      );
      key.rateLimitTokensReset = getResetDurationMillis(tokensReset);
    }

    if (!requestsReset && !tokensReset) {
      this.log.warn(
        { key: key.hash },
        `No rate limit headers in OpenAI response; skipping update`
      );
      return;
    }
  }

  /** Returns the remaining aggregate quota for all keys as a percentage. */
  public remainingQuota(gpt4Only = false) {
    const keys = this.keys.filter((k) => !gpt4Only || k.isGpt4);
  public remainingQuota(gpt4 = false) {
    const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
    if (keys.length === 0) return 0;

    const totalUsage = keys.reduce((acc, key) => {
@@ -191,4 +324,55 @@ export class KeyPool {

    return 1 - totalUsage / totalLimit;
  }

  /** Returns used and available usage in USD. */
  public usageInUsd(gpt4 = false) {
    const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
    if (keys.length === 0) return "???";

    const totalHardLimit = keys.reduce(
      (acc, { hardLimit }) => acc + hardLimit,
      0
    );
    const totalUsage = keys.reduce((acc, key) => {
      // Keys can slightly exceed their quota
      return acc + Math.min(key.usage, key.hardLimit);
    }, 0);

    return `$${totalUsage.toFixed(2)} / $${totalHardLimit.toFixed(2)}`;
  }

  /** Writes key status to disk. */
  // public writeKeyStatus() {
  //   const keys = this.keys.map((key) => ({
  //     key: key.key,
  //     isGpt4: key.isGpt4,
  //     usage: key.usage,
  //     hardLimit: key.hardLimit,
  //     isDisabled: key.isDisabled,
  //   }));
  //   fs.writeFileSync(
  //     path.join(__dirname, "..", "keys.json"),
  //     JSON.stringify(keys, null, 2)
  //   );
  // }
}

/**
 * Converts reset string ("21.0032s" or "21ms") to a number of milliseconds.
 * Result is clamped to 10s even though the API returns up to 60s, because the
 * API returns the time until the entire quota is reset, even if a key may be
 * able to fulfill requests before then due to partial resets.
 **/
function getResetDurationMillis(resetDuration?: string): number {
  const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/);
  if (match) {
    const [, time, , unit] = match;
    const value = parseFloat(time);
    const result = unit === "s" ? value * 1000 : value;
    return Math.min(result, 10000);
  }
  return 0;
}
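To make the clamping behaviour concrete, here is what getResetDurationMillis returns for header strings of the shape described in its doc comment (the inputs are illustrative):

// getResetDurationMillis("21.0032s") -> 21003.2, clamped to 10000
// getResetDurationMillis("348ms")    -> 348
// getResetDurationMillis(undefined)  -> 0 (missing or unparsable header)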
@@ -1,4 +1,3 @@
import { config } from "../../../config";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";

/**
@@ -19,7 +18,7 @@ export const checkStreaming: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
      req.body.stream = false;
      return;
    }
    req.body.stream = config.allowStreaming;
    req.isStreaming = config.allowStreaming;
    req.body.stream = true;
    req.isStreaming = true;
  }
};
@@ -40,10 +40,17 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
    { api: req.api, key: req.key?.hash },
    `Starting to proxy SSE stream.`
  );
  res.setHeader("Content-Type", "text/event-stream");
  res.setHeader("Cache-Control", "no-cache");
  res.setHeader("Connection", "keep-alive");
  copyHeaders(proxyRes, res);

  // Queued streaming requests will already have a connection open and headers
  // sent due to the heartbeat handler. In that case we can just start
  // streaming the response without sending headers.
  if (!res.headersSent) {
    res.setHeader("Content-Type", "text/event-stream");
    res.setHeader("Cache-Control", "no-cache");
    res.setHeader("Connection", "keep-alive");
    copyHeaders(proxyRes, res);
    res.flushHeaders();
  }

  const chunks: Buffer[] = [];
  proxyRes.on("data", (chunk) => {
@@ -65,6 +72,24 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
      { error: err, api: req.api, key: req.key?.hash },
      `Error while streaming response.`
    );
    // OAI's spec doesn't allow for error events and clients wouldn't know
    // what to do with them anyway, so we'll just send a completion event
    // with the error message.
    const fakeErrorEvent = {
      id: "chatcmpl-error",
      object: "chat.completion.chunk",
      created: Date.now(),
      model: "",
      choices: [
        {
          delta: { content: "[Proxy streaming error: " + err.message + "]" },
          index: 0,
          finish_reason: "error",
        },
      ],
    };
    res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`);
    res.write("data: [DONE]\n\n");
    res.end();
    reject(err);
  });
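For reference, the frames a streaming client receives when the proxy hits a mid-stream error look roughly like this (the timestamp and error message are invented for the example):

data: {"id":"chatcmpl-error","object":"chat.completion.chunk","created":1683000000000,"model":"","choices":[{"delta":{"content":"[Proxy streaming error: socket hang up]"},"index":0,"finish_reason":"error"}]}

data: [DONE]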
@@ -3,10 +3,12 @@ import * as http from "http";
import util from "util";
import zlib from "zlib";
import * as httpProxy from "http-proxy";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { keyPool } from "../../../key-management";
import { logPrompt } from "./log-prompt";
import { buildFakeSseMessage, enqueue, trackWaitTime } from "../../queue";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";

export const QUOTA_ROUTES = ["/v1/chat/completions"];
const DECODER_MAP = {
@@ -21,6 +23,13 @@ const isSupportedContentEncoding = (
  return contentEncoding in DECODER_MAP;
};

class RetryableError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "RetryableError";
  }
}

/**
 * Either decodes or streams the entire response body and then passes it as the
 * last argument to the rest of the middleware stack.
@@ -54,7 +63,7 @@ export type ProxyResMiddleware = ProxyResHandlerWithBody[];
 * the client. Once the stream is closed, the finalized body will be attached
 * to res.body and the remaining middleware will execute.
 */
export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
  return async (
    proxyRes: http.IncomingMessage,
    req: Request,
@@ -66,43 +75,23 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {

    let lastMiddlewareName = initialHandler.name;

    req.log.debug(
      {
        api: req.api,
        route: req.path,
        method: req.method,
        stream: req.isStreaming,
        middleware: lastMiddlewareName,
      },
      "Handling proxy response"
    );

    try {
      const body = await initialHandler(proxyRes, req, res);

      const middlewareStack: ProxyResMiddleware = [];

      if (req.isStreaming) {
        // Anything that touches the response will break streaming requests so
        // certain middleware can't be used. This includes whatever API-specific
        // middleware is passed in, which isn't ideal but it's what we've got
        // for now.
        // Streamed requests will be treated as non-streaming if the upstream
        // service returns a non-200 status code, so no need to include the
        // error handler here.

        // This is a little too easy to accidentally screw up so I need to add a
        // better way to differentiate between middleware that can be used for
        // streaming requests and those that can't. Probably a separate type
        // or function signature for streaming-compatible middleware.
        middlewareStack.push(incrementKeyUsage, logPrompt);
        // `handleStreamedResponse` writes to the response and ends it, so
        // we can only execute middleware that doesn't write to the response.
        middlewareStack.push(trackRateLimit, incrementKeyUsage, logPrompt);
      } else {
        middlewareStack.push(
          trackRateLimit,
          handleUpstreamErrors,
          incrementKeyUsage,
          copyHttpHeaders,
          logPrompt,
          ...middleware
          ...apiMiddleware
        );
      }

@@ -110,25 +99,31 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
        lastMiddlewareName = middleware.name;
        await middleware(proxyRes, req, res, body);
      }

      trackWaitTime(req);
    } catch (error: any) {
      if (res.headersSent) {
        req.log.error(
          `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`
        );
        // Either the upstream error handler got to it first, or we're mid-
        // stream and we can't do anything about it.
      // Hack: if the error is a retryable rate-limit error, the request has
      // been re-enqueued and we can just return without doing anything else.
      if (error instanceof RetryableError) {
        return;
      }

      const errorData = {
        error: error.stack,
        thrownBy: lastMiddlewareName,
        key: req.key?.hash,
      };
      const message = `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`;
        logger.error(
          {
            error: error.stack,
            thrownBy: lastMiddlewareName,
            key: req.key?.hash,
          },
          message
        );
      if (res.headersSent) {
        req.log.error(errorData, message);
        // This should have already been handled by the error handler, but
        // just in case...
        if (!res.writableEnded) {
          res.end();
        }
        return;
      }
      logger.error(errorData, message);
      res
        .status(500)
        .json({ error: "Internal server error", proxy_note: message });
@@ -136,6 +131,15 @@ export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
  };
};

function reenqueueRequest(req: Request) {
  req.log.info(
    { key: req.key?.hash, retryCount: req.retryCount },
    `Re-enqueueing request due to rate-limit error`
  );
  req.retryCount++;
  enqueue(req);
}

/**
 * Handles the response from the upstream service and decodes the body if
 * necessary. If the response is JSON, it will be parsed and returned as an
@@ -160,8 +164,8 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
    proxyRes.on("data", (chunk) => chunks.push(chunk));
    proxyRes.on("end", async () => {
      let body = Buffer.concat(chunks);
      const contentEncoding = proxyRes.headers["content-encoding"];

      const contentEncoding = proxyRes.headers["content-encoding"];
      if (contentEncoding) {
        if (isSupportedContentEncoding(contentEncoding)) {
          const decoder = DECODER_MAP[contentEncoding];
@@ -169,7 +173,10 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
        } else {
          const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
          logger.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
          res.status(500).json({ error: errorMessage, contentEncoding });
          writeErrorResponse(res, 500, {
            error: errorMessage,
            contentEncoding,
          });
          return reject(errorMessage);
        }
      }
@@ -183,7 +190,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
      } catch (error: any) {
        const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
        logger.warn({ error, key: req.key?.hash }, errorMessage);
        res.status(500).json({ error: errorMessage });
        writeErrorResponse(res, 500, { error: errorMessage });
        return reject(errorMessage);
      }
    });
@@ -197,7 +204,9 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
 * Handles non-2xx responses from the upstream service. If the proxied response
 * is an error, this will respond to the client with an error payload and throw
 * an error to stop the middleware stack.
 * @throws {Error} HTTP error status code from upstream service
 * On 429 errors, if request queueing is enabled, the request will be silently
 * re-enqueued. Otherwise, the request will be rejected with an error payload.
 * @throws {Error} On HTTP error status code from upstream service
 */
const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  proxyRes,
@@ -206,6 +215,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  body
) => {
  const statusCode = proxyRes.statusCode || 500;

  if (statusCode < 400) {
    return;
  }
@@ -222,7 +232,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
    if (typeof body === "object") {
      errorPayload = body;
    } else {
      throw new Error("Received non-JSON error response from upstream.");
      throw new Error("Received unparsable error response from upstream.");
    }
  } catch (parseError: any) {
    const statusMessage = proxyRes.statusMessage || "Unknown error";
@@ -238,8 +248,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
      error: parseError.message,
      proxy_note: `This is likely a temporary error with the upstream service.`,
    };

    res.status(statusCode).json(errorObject);
    writeErrorResponse(res, statusCode, errorObject);
    throw new Error(parseError.message);
  }

@@ -261,30 +270,35 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
    keyPool.disable(req.key!);
    errorPayload.proxy_note = `The OpenAI key is invalid or revoked. ${tryAgainMessage}`;
  } else if (statusCode === 429) {
    // One of:
    // - Quota exceeded (key is dead, disable it)
    // - Rate limit exceeded (key is fine, just try again)
    // - Model overloaded (their fault, just try again)
    if (errorPayload.error?.type === "insufficient_quota") {
    const type = errorPayload.error?.type;
    if (type === "insufficient_quota") {
      // Billing quota exceeded (key is dead, disable it)
      keyPool.disable(req.key!);
      errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
    } else if (errorPayload.error?.type === "billing_not_active") {
    } else if (type === "billing_not_active") {
      // Billing is not active (key is dead, disable it)
      keyPool.disable(req.key!);
      errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`;
    } else if (type === "requests" || type === "tokens") {
      // Per-minute request or token rate limit is exceeded, which we can retry
      keyPool.markRateLimited(req.key!.hash);
      if (config.queueMode !== "none") {
        reenqueueRequest(req);
        // TODO: I don't like using an error to control flow here
        throw new RetryableError("Rate-limited request re-enqueued.");
      }
      errorPayload.proxy_note = `Assigned key's '${type}' rate limit has been exceeded. Try again later.`;
    } else {
      // OpenAI probably overloaded
      errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
    }
  } else if (statusCode === 404) {
    // Most likely model not found
    // TODO: this probably doesn't handle GPT-4-32k variants properly if the
    // proxy has keys for both the 8k and 32k context models at the same time.
    if (errorPayload.error?.code === "model_not_found") {
      if (req.key!.isGpt4) {
        // Malicious users can request a model that `startsWith` gpt-4 but is
        // not actually a valid model name and force the key to be downgraded.
        // I don't feel like fixing this so I'm just going to disable the key
        // downgrading feature for now.
        // keyPool.downgradeKey(req.key?.hash);
        // errorPayload.proxy_note = `This key was incorrectly assigned to GPT-4. It has been downgraded to Turbo.`;
        errorPayload.proxy_note = `This key was incorrectly flagged as GPT-4, or you requested a GPT-4 snapshot for which this key is not authorized. Try again to get a different key, or use Turbo.`;
        errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`;
      } else {
        errorPayload.proxy_note = `No model was found for this key.`;
      }
@@ -301,10 +315,33 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
    );
  }

  res.status(statusCode).json(errorPayload);
  writeErrorResponse(res, statusCode, errorPayload);
  throw new Error(errorPayload.error?.message);
};

function writeErrorResponse(
  res: Response,
  statusCode: number,
  errorPayload: Record<string, any>
) {
  // If we're mid-SSE stream, send a data event with the error payload and end
  // the stream. Otherwise just send a normal error response.
  if (
    res.headersSent ||
    res.getHeader("content-type") === "text/event-stream"
  ) {
    const msg = buildFakeSseMessage(
      `upstream error (${statusCode})`,
      JSON.stringify(errorPayload, null, 2)
    );
    res.write(msg);
    res.write(`data: [DONE]\n\n`);
    res.end();
  } else {
    res.status(statusCode).json(errorPayload);
  }
}

/** Handles errors in rewriter pipelines. */
export const handleInternalError: httpProxy.ErrorCallback = (
  err,
@@ -313,19 +350,14 @@ export const handleInternalError: httpProxy.ErrorCallback = (
) => {
  logger.error({ error: err }, "Error in http-proxy-middleware pipeline.");
  try {
    if ("setHeader" in res && !res.headersSent) {
      res.writeHead(500, { "Content-Type": "application/json" });
    }
    res.end(
      JSON.stringify({
        error: {
          type: "proxy_error",
          message: err.message,
          stack: err.stack,
          proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
        },
      })
    );
    writeErrorResponse(res as Response, 500, {
      error: {
        type: "proxy_error",
        message: err.message,
        stack: err.stack,
        proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
      },
    });
  } catch (e) {
    logger.error(
      { error: e },
@@ -340,6 +372,10 @@ const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
  }
};

const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
  keyPool.updateRateLimits(req.key!.hash, proxyRes.headers);
};

const copyHttpHeaders: ProxyResHandlerWithBody = async (
  proxyRes,
  _req,
@@ -3,6 +3,7 @@ import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import {
  addKey,
@@ -72,6 +73,7 @@ const openaiProxy = createProxyMiddleware({
  selfHandleResponse: true,
  logger,
});
const queuedOpenaiProxy = createQueueMiddleware(openaiProxy);

const openaiRouter = Router();
// Some clients don't include the /v1/ prefix in their requests and users get
@@ -84,7 +86,7 @@ openaiRouter.use((req, _res, next) => {
  next();
});
openaiRouter.get("/v1/models", openaiProxy);
openaiRouter.post("/v1/chat/completions", ipLimiter, openaiProxy);
openaiRouter.post("/v1/chat/completions", ipLimiter, queuedOpenaiProxy);
// If a browser tries to visit a route that doesn't exist, redirect to the info
// page to help them find the right URL.
openaiRouter.get("*", (req, res, next) => {
@@ -0,0 +1,367 @@
/**
 * Very scuffed request queue. OpenAI's GPT-4 keys have a very strict rate limit
 * of 40000 generated tokens per minute. We don't actually know how many tokens
 * a given key has generated, so our queue will simply retry requests that fail
 * with a non-billing related 429 over and over again until they succeed.
 *
 * Dequeueing can operate in one of two modes:
 * - 'fair': requests are dequeued in the order they were enqueued.
 * - 'random': requests are dequeued randomly, not really a queue at all.
 *
 * When a request to a proxied endpoint is received, we create a closure around
 * the call to http-proxy-middleware and attach it to the request. This allows
 * us to pause the request until we have a key available. Further, if the
 * proxied request encounters a retryable error, we can simply put the request
 * back in the queue and it will be retried later using the same closure.
 */

import type { Handler, Request } from "express";
import { config, DequeueMode } from "../config";
import { keyPool } from "../key-management";
import { logger } from "../logger";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";

const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });

let dequeueMode: DequeueMode = "fair";

/** Maximum number of queue slots for Agnai.chat requests. */
const AGNAI_CONCURRENCY_LIMIT = 15;
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1;

export function enqueue(req: Request) {
  // All agnai.chat requests come from the same IP, so we allow them to have
  // more spots in the queue. Can't make it unlimited because people will
  // intentionally abuse it.
  const maxConcurrentQueuedRequests =
    req.ip === AGNAI_DOT_CHAT_IP
      ? AGNAI_CONCURRENCY_LIMIT
      : USER_CONCURRENCY_LIMIT;
  const reqCount = queue.filter((r) => r.ip === req.ip).length;
  if (reqCount >= maxConcurrentQueuedRequests) {
    if (req.ip === AGNAI_DOT_CHAT_IP) {
      // Re-enqueued requests are not counted towards the limit since they
      // already made it through the queue once.
      if (req.retryCount === 0) {
        throw new Error("Too many agnai.chat requests are already queued");
      }
    } else {
      throw new Error("Request is already queued for this IP");
    }
  }

  queue.push(req);
  req.queueOutTime = 0;

  // shitty hack to remove hpm's event listeners on retried requests
  removeProxyMiddlewareEventListeners(req);

  // If the request opted into streaming, we need to register a heartbeat
  // handler to keep the connection alive while it waits in the queue. We
  // deregister the handler when the request is dequeued.
  if (req.body.stream) {
    const res = req.res!;
    if (!res.headersSent) {
      initStreaming(req);
    }
    req.heartbeatInterval = setInterval(() => {
      if (process.env.NODE_ENV === "production") {
        req.res!.write(": queue heartbeat\n\n");
      } else {
        req.log.info(`Sending heartbeat to request in queue.`);
        const avgWait = Math.round(getAverageWaitTime() / 1000);
        const currentDuration = Math.round((Date.now() - req.startTime) / 1000);
        const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`;
        req.res!.write(buildFakeSseMessage("heartbeat", debugMsg));
      }
    }, 10000);
  }

  // Register a handler to remove the request from the queue if the connection
  // is aborted or closed before it is dequeued.
  const removeFromQueue = () => {
    req.log.info(`Removing aborted request from queue.`);
    const index = queue.indexOf(req);
    if (index !== -1) {
      queue.splice(index, 1);
    }
    if (req.heartbeatInterval) {
      clearInterval(req.heartbeatInterval);
    }
  };
  req.onAborted = removeFromQueue;
  req.res!.once("close", removeFromQueue);

  if (req.retryCount ?? 0 > 0) {
    req.log.info({ retries: req.retryCount }, `Enqueued request for retry.`);
  } else {
    req.log.info(`Enqueued new request.`);
  }
}

export function dequeue(model: string): Request | undefined {
  // TODO: This should be set by some middleware that checks the request body.
  const modelQueue =
    model === "gpt-4"
      ? queue.filter((req) => req.body.model?.startsWith("gpt-4"))
      : queue.filter((req) => !req.body.model?.startsWith("gpt-4"));

  if (modelQueue.length === 0) {
    return undefined;
  }

  let req: Request;

  if (dequeueMode === "fair") {
    // Dequeue the request that has been waiting the longest
    req = modelQueue.reduce((prev, curr) =>
      prev.startTime < curr.startTime ? prev : curr
    );
  } else {
    // Dequeue a random request
    const index = Math.floor(Math.random() * modelQueue.length);
    req = modelQueue[index];
  }
  queue.splice(queue.indexOf(req), 1);

  if (req.onAborted) {
    req.res!.off("close", req.onAborted);
    req.onAborted = undefined;
  }

  if (req.heartbeatInterval) {
    clearInterval(req.heartbeatInterval);
  }

  // Track the time leaving the queue now, but don't add it to the wait times
  // yet because we don't know if the request will succeed or fail. We track
  // the time now and not after the request succeeds because we don't want to
  // include the model processing time.
  req.queueOutTime = Date.now();
  return req;
}

/**
 * Naive way to keep the queue moving by continuously dequeuing requests. Not
 * ideal because it limits throughput but we probably won't have enough traffic
 * or keys for this to be a problem. If it does we can dequeue multiple
 * per tick.
 **/
function processQueue() {
  // This isn't completely correct, because a key can service multiple models.
  // Currently if a key is locked out on one model it will also stop servicing
  // the others, because we only track one rate limit per key.
  const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
  const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");

  const reqs: (Request | undefined)[] = [];
  if (gpt4Lockout === 0) {
    reqs.push(dequeue("gpt-4"));
  }
  if (turboLockout === 0) {
    reqs.push(dequeue("gpt-3.5-turbo"));
  }

  reqs.filter(Boolean).forEach((req) => {
    if (req?.proceed) {
      req.log.info({ retries: req.retryCount }, `Dequeuing request.`);
      req.proceed();
    }
  });
  setTimeout(processQueue, 50);
}

/**
 * Kill stalled requests after 5 minutes, and remove tracked wait times after 2
 * minutes.
 **/
function cleanQueue() {
  const now = Date.now();
  const oldRequests = queue.filter(
    (req) => now - (req.startTime ?? now) > 5 * 60 * 1000
  );
  oldRequests.forEach((req) => {
    req.log.info(`Removing request from queue after 5 minutes.`);
    killQueuedRequest(req);
  });

  const index = waitTimes.findIndex(
    (waitTime) => now - waitTime.end > 120 * 1000
  );
  const removed = waitTimes.splice(0, index + 1);
  log.info(
    { stalledRequests: oldRequests.length, prunedWaitTimes: removed.length },
    `Cleaning up request queue.`
  );
  setTimeout(cleanQueue, 20 * 1000);
}

export function start() {
  processQueue();
  cleanQueue();
  log.info(`Started request queue.`);
}

const waitTimes: { start: number; end: number }[] = [];

/** Adds a successful request to the list of wait times. */
export function trackWaitTime(req: Request) {
  waitTimes.push({
    start: req.startTime!,
    end: req.queueOutTime ?? Date.now(),
  });
}

/** Returns average wait time in milliseconds. */
export function getAverageWaitTime() {
  if (waitTimes.length === 0) {
    return 0;
  }

  // Include requests that are still in the queue right now
  const now = Date.now();
  const waitTimesWithCurrent = [
    ...waitTimes,
    ...queue.map((req) => ({
      start: req.startTime!,
      end: now,
    })),
  ];

  return (
    waitTimesWithCurrent.reduce((acc, curr) => acc + curr.end - curr.start, 0) /
    waitTimesWithCurrent.length
  );
}

export function getQueueLength() {
  return queue.length;
}

export function createQueueMiddleware(proxyMiddleware: Handler): Handler {
  return (req, res, next) => {
    if (config.queueMode === "none") {
      return proxyMiddleware(req, res, next);
    }

    req.proceed = () => {
      proxyMiddleware(req, res, next);
    };

    try {
      enqueue(req);
    } catch (err: any) {
      req.res!.status(429).json({
        type: "proxy_error",
        message: err.message,
        stack: err.stack,
        proxy_note: `Only one request per IP can be queued at a time. If you don't have another request queued, your IP may be in use by another user.`,
      });
    }
  };
}

function killQueuedRequest(req: Request) {
  if (!req.res || req.res.writableEnded) {
    req.log.warn(`Attempted to terminate request that has already ended.`);
    return;
  }
  const res = req.res;
  try {
    const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. The queue is currently ${queue.length} requests long.`;
    if (res.headersSent) {
      const fakeErrorEvent = buildFakeSseMessage("proxy queue error", message);
      res.write(fakeErrorEvent);
      res.end();
    } else {
      res.status(500).json({ error: message });
    }
  } catch (e) {
    req.log.error(e, `Error killing stalled request.`);
  }
}

function initStreaming(req: Request) {
  req.log.info(`Initiating streaming for new queued request.`);
  const res = req.res!;
  res.statusCode = 200;
  res.setHeader("Content-Type", "text/event-stream");
  res.setHeader("Cache-Control", "no-cache");
  res.setHeader("Connection", "keep-alive");
  res.flushHeaders();
  res.write("\n");
  res.write(": joining queue\n\n");
}

export function buildFakeSseMessage(type: string, string: string) {
  const fakeEvent = {
    id: "chatcmpl-" + type,
    object: "chat.completion.chunk",
    created: Date.now(),
    model: "",
    choices: [
      {
        delta: { content: `[${type}: ${string}]\n` },
        index: 0,
        finish_reason: type,
      },
    ],
  };
  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}

/**
 * http-proxy-middleware attaches a bunch of event listeners to the req and
 * res objects which causes problems with our approach to re-enqueuing failed
 * proxied requests. This function removes those event listeners.
 * We don't have references to the original event listeners, so we have to
 * look through the list and remove HPM's listeners by looking for particular
 * strings in the listener functions. This is an astoundingly shitty way to do
 * this, but it's the best I can come up with.
 */
function removeProxyMiddlewareEventListeners(req: Request) {
  // node_modules/http-proxy-middleware/dist/plugins/default/debug-proxy-errors-plugin.js:29
  // res.listeners('close')
  const RES_ONCLOSE = `Destroying proxyRes in proxyRes close event`;
  // node_modules/http-proxy-middleware/dist/plugins/default/debug-proxy-errors-plugin.js:19
  // res.listeners('error')
  const RES_ONERROR = `Socket error in proxyReq event`;
  // node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js:146
  // req.listeners('aborted')
  const REQ_ONABORTED = `proxyReq.abort()`;
  // node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js:156
  // req.listeners('error')
  const REQ_ONERROR = `if (req.socket.destroyed`;

  const res = req.res!;

  const resOnClose = res
    .listeners("close")
    .find((listener) => listener.toString().includes(RES_ONCLOSE));
  if (resOnClose) {
    res.removeListener("close", resOnClose as any);
  }

  const resOnError = res
    .listeners("error")
    .find((listener) => listener.toString().includes(RES_ONERROR));
  if (resOnError) {
    res.removeListener("error", resOnError as any);
  }

  const reqOnAborted = req
    .listeners("aborted")
    .find((listener) => listener.toString().includes(REQ_ONABORTED));
  if (reqOnAborted) {
    req.removeListener("aborted", reqOnAborted as any);
  }

  const reqOnError = req
    .listeners("error")
    .find((listener) => listener.toString().includes(REQ_ONERROR));
  if (reqOnError) {
    req.removeListener("error", reqOnError as any);
  }
}
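Tying the queue and streaming pieces together, a client that opted into streaming and is waiting in the queue receives an SSE response opened by initStreaming followed by periodic keep-alives, roughly like this illustrative transcript (in development the heartbeat is instead a fake chat.completion.chunk built by buildFakeSseMessage):

HTTP/1.1 200 OK
Content-Type: text/event-stream
Cache-Control: no-cache
Connection: keep-alive

: joining queue

: queue heartbeat

: queue heartbeat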
@@ -2,6 +2,7 @@ import { Request, Response, NextFunction } from "express";
import { config } from "../config";
import { logger } from "../logger";

export const AGNAI_DOT_CHAT_IP = "157.230.249.32";
const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit);
const RATE_LIMIT = Math.max(1, config.modelRateLimit);
const ONE_MINUTE_MS = 60 * 1000;
@@ -66,7 +67,7 @@ export const ipLimiter = (req: Request, res: Response, next: NextFunction) => {
  // Exempt Agnai.chat from rate limiting since it's shared between a lot of
  // users. Dunno how to prevent this from being abused without some sort of
  // identifier sent from Agnaistic to identify specific users.
  if (req.ip === "157.230.249.32") {
  if (req.ip === AGNAI_DOT_CHAT_IP) {
    next();
    return;
  }
@@ -9,6 +9,7 @@ import { keyPool } from "./key-management";
import { proxyRouter, rewriteTavernRequests } from "./proxy/routes";
import { handleInfoPage } from "./info-page";
import { logQueue } from "./prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";

const PORT = config.port;

@@ -17,6 +18,7 @@ const app = express();
app.use("/", rewriteTavernRequests);
app.use(
  pinoHttp({
    quietReqLogger: true,
    logger,
    // SillyTavern spams the hell out of this endpoint so don't log it
    autoLogging: { ignore: (req) => req.url === "/proxy/kobold/api/v1/model" },
@@ -31,6 +33,11 @@ app.use(
    },
  })
);
app.use((req, _res, next) => {
  req.startTime = Date.now();
  req.retryCount = 0;
  next();
});
app.use(cors());
app.use(
  express.json({ limit: "10mb" }),
@@ -40,21 +47,31 @@ app.use(
// deploy this somewhere without a load balancer then incoming requests can
// spoof the X-Forwarded-For header and bypass the rate limiting.
app.set("trust proxy", true);

// routes
app.get("/", handleInfoPage);
app.use("/proxy", proxyRouter);

// 500 and 404
app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => {
  if (err.status) {
    res.status(err.status).json({ error: err.message });
  } else {
    logger.error(err);
    res.status(500).json({ error: "Internal server error" });
    res.status(500).json({
      error: {
        type: "proxy_error",
        message: err.message,
        stack: err.stack,
        proxy_note: `Reverse proxy encountered an internal server error.`,
      },
    });
  }
});
app.use((_req: unknown, res: express.Response) => {
  res.status(404).json({ error: "Not found" });
});

// start server and load keys
app.listen(PORT, async () => {
  try {
@@ -108,4 +125,8 @@ app.listen(PORT, async () => {
    logger.info("Starting prompt logging...");
    logQueue.start();
  }
  if (config.queueMode !== "none") {
    logger.info("Starting request queue...");
    startRequestQueue();
  }
});
@@ -7,6 +7,12 @@ declare global {
      key?: Key;
      api: "kobold" | "openai" | "anthropic";
      isStreaming?: boolean;
      startTime: number;
      retryCount: number;
      queueOutTime?: number;
      onAborted?: () => void;
      proceed: () => void;
      heartbeatInterval?: NodeJS.Timeout;
    }
  }
}