Scale SSE heartbeat size with traffic (khanon/oai-reverse-proxy!53)

This commit is contained in:
khanon 2023-11-16 05:45:35 +00:00
parent 6acdf35914
commit 6aa6bebf08
3 changed files with 112 additions and 34 deletions

View File

@ -11,6 +11,7 @@
* back in the queue and it will be retried later using the same closure. * back in the queue and it will be retried later using the same closure.
*/ */
import crypto from "crypto";
import type { Handler, Request } from "express"; import type { Handler, Request } from "express";
import { keyPool } from "../shared/key-management"; import { keyPool } from "../shared/key-management";
import { import {
@ -23,7 +24,7 @@ import {
import { buildFakeSse, initializeSseStream } from "../shared/streaming"; import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import { assertNever } from "../shared/utils"; import { assertNever } from "../shared/utils";
import { logger } from "../logger"; import { logger } from "../logger";
import { SHARED_IP_ADDRESSES } from "./rate-limit"; import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request"; import { RequestPreprocessor } from "./middleware/request";
const queue: Request[] = []; const queue: Request[] = [];
@ -33,6 +34,15 @@ const log = logger.child({ module: "request-queue" });
const AGNAI_CONCURRENCY_LIMIT = 5; const AGNAI_CONCURRENCY_LIMIT = 5;
/** Maximum number of queue slots for individual users. */ /** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1; const USER_CONCURRENCY_LIMIT = 1;
/**
 * Reads a numeric setting from the environment.
 *
 * Returns `fallback` when the variable is unset, empty, or does not parse to
 * a finite number, so a typo in the environment cannot turn a timer interval
 * or a size limit into `NaN`.
 *
 * @param name Environment variable to read.
 * @param fallback Value used when the variable is missing or invalid.
 * @param parse Parser applied to the raw string (defaults to float parsing).
 */
function numberFromEnv(
  name: string,
  fallback: number,
  parse: (raw: string) => number = Number.parseFloat
): number {
  const raw = process.env[name];
  if (raw === undefined) return fallback;
  const value = parse(raw);
  return Number.isFinite(value) ? value : fallback;
}

/** Parses a base-10 integer; the explicit radix avoids hex/legacy surprises. */
const parseBase10 = (raw: string) => Number.parseInt(raw, 10);

/** Heartbeat size used while the proxy load is at or below the threshold. */
const MIN_HEARTBEAT_SIZE = 512;
/** Upper bound for the heartbeat size; configured in KB via the environment. */
const MAX_HEARTBEAT_SIZE =
  1024 * numberFromEnv("MAX_HEARTBEAT_SIZE_KB", 1024, parseBase10);
/** Delay between heartbeats in milliseconds; configured in seconds. */
const HEARTBEAT_INTERVAL =
  1000 * numberFromEnv("HEARTBEAT_INTERVAL_SEC", 5, parseBase10);
/** Proxy load above which heartbeats start growing. */
const LOAD_THRESHOLD = numberFromEnv("LOAD_THRESHOLD", 50);
/** Multiplier applied to the excess load before it is squared. */
const PAYLOAD_SCALE_FACTOR = numberFromEnv("PAYLOAD_SCALE_FACTOR", 6);
/** /**
* Returns an identifier for a request. This is used to determine if a * Returns an identifier for a request. This is used to determine if a
@ -93,31 +103,21 @@ export function enqueue(req: Request) {
if (!res.headersSent) { if (!res.headersSent) {
initStreaming(req); initStreaming(req);
} }
req.heartbeatInterval = setInterval(() => { registerHeartbeat(req);
if (process.env.NODE_ENV === "production") { } else if (getProxyLoad() > LOAD_THRESHOLD) {
if (!req.query.badSseParser) req.res!.write(": queue heartbeat\n\n"); throw new Error(
} else { "Due to heavy traffic on this proxy, you must enable streaming for your request."
req.log.info(`Sending heartbeat to request in queue.`); );
const partition = getPartitionForRequest(req);
const avgWait = Math.round(getEstimatedWaitTime(partition) / 1000);
const currentDuration = Math.round((Date.now() - req.startTime) / 1000);
const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`;
req.res!.write(buildFakeSse("heartbeat", debugMsg, req));
}
}, 10000);
} }
// Register a handler to remove the request from the queue if the connection
// is aborted or closed before it is dequeued.
const removeFromQueue = () => { const removeFromQueue = () => {
req.log.info(`Removing aborted request from queue.`); req.log.info(`Removing aborted request from queue.`);
const index = queue.indexOf(req); const index = queue.indexOf(req);
if (index !== -1) { if (index !== -1) {
queue.splice(index, 1); queue.splice(index, 1);
} }
if (req.heartbeatInterval) { if (req.heartbeatInterval) clearInterval(req.heartbeatInterval);
clearInterval(req.heartbeatInterval); if (req.monitorInterval) clearInterval(req.monitorInterval);
}
}; };
req.onAborted = removeFromQueue; req.onAborted = removeFromQueue;
req.res!.once("close", removeFromQueue); req.res!.once("close", removeFromQueue);
@ -188,9 +188,8 @@ export function dequeue(partition: ModelFamily): Request | undefined {
req.onAborted = undefined; req.onAborted = undefined;
} }
if (req.heartbeatInterval) { if (req.heartbeatInterval) clearInterval(req.heartbeatInterval);
clearInterval(req.heartbeatInterval); if (req.monitorInterval) clearInterval(req.monitorInterval);
}
// Track the time leaving the queue now, but don't add it to the wait times // Track the time leaving the queue now, but don't add it to the wait times
// yet because we don't know if the request will succeed or fail. We track // yet because we don't know if the request will succeed or fail. We track
@ -385,6 +384,7 @@ export function createQueueMiddleware({
function killQueuedRequest(req: Request) { function killQueuedRequest(req: Request) {
if (!req.res || req.res.writableEnded) { if (!req.res || req.res.writableEnded) {
req.log.warn(`Attempted to terminate request that has already ended.`); req.log.warn(`Attempted to terminate request that has already ended.`);
queue.splice(queue.indexOf(req), 1);
return; return;
} }
const res = req.res; const res = req.res;
@ -469,3 +469,84 @@ function removeProxyMiddlewareEventListeners(req: Request) {
req.removeListener("error", reqOnError as any); req.removeListener("error", reqOnError as any);
} }
} }
/**
 * Starts the periodic heartbeat for a request and then starts its monitor.
 *
 * Every `HEARTBEAT_INTERVAL` milliseconds a heartbeat payload is written to
 * the response. If a write reports a full buffer, later ticks are skipped
 * until the response emits `drain`. Skipped ticks are counted for the
 * lifetime of the timer; on the third one the response is destroyed.
 *
 * The timer handle is stored on `req.heartbeatInterval`.
 */
export function registerHeartbeat(req: Request) {
  const response = req.res!;
  // Backpressure state shared between ticks.
  let awaitingDrain = false;
  let skippedBeats = 0;

  const sendHeartbeat = () => {
    if (!awaitingDrain) {
      const payload = getHeartbeatPayload();
      const flushed = response.write(payload);
      if (!flushed) {
        // The buffer is full; hold further heartbeats until it drains.
        awaitingDrain = true;
        response.once("drain", () => {
          awaitingDrain = false;
        });
      }
      return;
    }

    skippedBeats += 1;
    if (skippedBeats < 3) {
      req.log.warn(
        { bufferFullCount: skippedBeats },
        "Heartbeat skipped; buffer is full."
      );
      return;
    }
    req.log.error("Heartbeat skipped too many times; killing connection.");
    response.destroy();
  };

  req.heartbeatInterval = setInterval(sendHeartbeat, HEARTBEAT_INTERVAL);
  monitorHeartbeat(req);
}
/**
 * Watches how many bytes the response socket writes while a request is queued.
 *
 * Every two heartbeat intervals the socket's `bytesWritten` counter is
 * compared with the previous reading. If fewer than half of the current
 * heartbeat size was written in that window, the response is destroyed.
 *
 * The timer handle is stored on `req.monitorInterval`.
 */
function monitorHeartbeat(req: Request) {
  const res = req.res!;
  let lastBytesSent = 0;
  req.monitorInterval = setInterval(() => {
    // A missing socket counts as zero bytes written.
    const bytesSent = res.socket?.bytesWritten ?? 0;
    const bytesSinceLast = bytesSent - lastBytesSent;
    req.log.debug(
      {
        previousBytesSent: lastBytesSent,
        currentBytesSent: bytesSent,
      },
      "Heartbeat monitor check."
    );
    lastBytesSent = bytesSent;
    // Require at least half of one current-size heartbeat per window.
    const minBytes = Math.floor(getHeartbeatSize() / 2);
    if (bytesSinceLast < minBytes) {
      req.log.warn(
        { minBytes, bytesSinceLast },
        "Queued request is not processing heartbeats fast enough or server is overloaded; killing connection."
      );
      res.destroy();
    }
  }, HEARTBEAT_INTERVAL * 2);
}
/**
 * Computes the heartbeat size for the current proxy load.
 *
 * At or below `LOAD_THRESHOLD` the size is `MIN_HEARTBEAT_SIZE`. Above it,
 * the size grows with the square of the excess load multiplied by
 * `PAYLOAD_SCALE_FACTOR`, and is capped at `MAX_HEARTBEAT_SIZE`.
 *
 * @returns The heartbeat size, floored to a whole number when scaled.
 */
function getHeartbeatSize() {
  const load = getProxyLoad();
  if (load <= LOAD_THRESHOLD) {
    return MIN_HEARTBEAT_SIZE;
  }
  const excessLoad = load - LOAD_THRESHOLD;
  // Fractional threshold or scale settings would otherwise give a fractional
  // size, so round down to a whole number.
  const size = Math.floor(
    MIN_HEARTBEAT_SIZE + Math.pow(excessLoad * PAYLOAD_SCALE_FACTOR, 2)
  );
  if (size > MAX_HEARTBEAT_SIZE) return MAX_HEARTBEAT_SIZE;
  return size;
}
/**
 * Builds the SSE comment line sent as a queue heartbeat.
 *
 * In production the filler is random data of about the current heartbeat
 * size; otherwise it is a short text stating that size.
 *
 * @returns A line starting with `:` and terminated by a blank line.
 */
function getHeartbeatPayload() {
  const size = getHeartbeatSize();
  let data: string;
  if (process.env.NODE_ENV === "production") {
    // Interpolating the raw Buffer would decode random bytes as UTF-8, which
    // can inject CR/LF (ending the comment line early) and changes the
    // on-wire size. Base64 keeps the filler single-line ASCII; three random
    // bytes yield four characters, so generate 3/4 of `size` and trim.
    data = crypto
      .randomBytes(Math.ceil((size * 3) / 4))
      .toString("base64")
      .slice(0, size);
  } else {
    data = `payload size: ${size}`;
  }
  return `: queue heartbeat ${data}\n\n`;
}
/**
 * Returns the current proxy load: the larger of the unique IP count and the
 * number of requests in the queue.
 */
function getProxyLoad() {
  const uniqueIpCount = getUniqueIps();
  const queuedRequests = queue.length;
  return Math.max(uniqueIpCount, queuedRequests);
}

View File

@ -39,11 +39,7 @@ export function copySseResponseHeaders(
* that the request is being proxied to. Used to send error messages to the * that the request is being proxied to. Used to send error messages to the
* client in the middle of a streaming request. * client in the middle of a streaming request.
*/ */
export function buildFakeSse( export function buildFakeSse(type: string, string: string, req: Request) {
type: string,
string: string,
req: Request
) {
let fakeEvent; let fakeEvent;
const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`; const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
@ -54,7 +50,7 @@ export function buildFakeSse(
object: "chat.completion.chunk", object: "chat.completion.chunk",
created: Date.now(), created: Date.now(),
model: req.body?.model, model: req.body?.model,
choices: [{ delta: { content }, index: 0, finish_reason: type }] choices: [{ delta: { content }, index: 0, finish_reason: type }],
}; };
break; break;
case "openai-text": case "openai-text":
@ -63,9 +59,9 @@ export function buildFakeSse(
object: "text_completion", object: "text_completion",
created: Date.now(), created: Date.now(),
choices: [ choices: [
{ text: content, index: 0, logprobs: null, finish_reason: type } { text: content, index: 0, logprobs: null, finish_reason: type },
], ],
model: req.body?.model model: req.body?.model,
}; };
break; break;
case "anthropic": case "anthropic":
@ -75,7 +71,7 @@ export function buildFakeSse(
truncated: false, // I've never seen this be true truncated: false, // I've never seen this be true
stop: null, stop: null,
model: req.body?.model, model: req.body?.model,
log_id: "proxy-req-" + req.id log_id: "proxy-req-" + req.id,
}; };
break; break;
case "google-palm": case "google-palm":
@ -86,10 +82,10 @@ export function buildFakeSse(
} }
if (req.inboundApi === "anthropic") { if (req.inboundApi === "anthropic") {
return [ return (
"event: completion", ["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") +
`data: ${JSON.stringify(fakeEvent)}`, "\n\n"
].join("\n") + "\n\n"; );
} }
return `data: ${JSON.stringify(fakeEvent)}\n\n`; return `data: ${JSON.stringify(fakeEvent)}\n\n`;

View File

@ -22,6 +22,7 @@ declare global {
onAborted?: () => void; onAborted?: () => void;
proceed: () => void; proceed: () => void;
heartbeatInterval?: NodeJS.Timeout; heartbeatInterval?: NodeJS.Timeout;
monitorInterval?: NodeJS.Timeout;
promptTokens?: number; promptTokens?: number;
outputTokens?: number; outputTokens?: number;
tokenizerInfo: Record<string, any>; tokenizerInfo: Record<string, any>;