diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index 0de463e..584ef90 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -11,6 +11,7 @@ * back in the queue and it will be retried later using the same closure. */ +import crypto from "crypto"; import type { Handler, Request } from "express"; import { keyPool } from "../shared/key-management"; import { @@ -23,7 +24,7 @@ import { import { buildFakeSse, initializeSseStream } from "../shared/streaming"; import { assertNever } from "../shared/utils"; import { logger } from "../logger"; -import { SHARED_IP_ADDRESSES } from "./rate-limit"; +import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit"; import { RequestPreprocessor } from "./middleware/request"; const queue: Request[] = []; @@ -33,6 +34,15 @@ const log = logger.child({ module: "request-queue" }); const AGNAI_CONCURRENCY_LIMIT = 5; /** Maximum number of queue slots for individual users. */ const USER_CONCURRENCY_LIMIT = 1; +const MIN_HEARTBEAT_SIZE = 512; +const MAX_HEARTBEAT_SIZE = + 1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024"); +const HEARTBEAT_INTERVAL = + 1000 * parseInt(process.env.HEARTBEAT_INTERVAL_SEC ?? "5"); +const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "50"); +const PAYLOAD_SCALE_FACTOR = parseFloat( + process.env.PAYLOAD_SCALE_FACTOR ?? "6" +); /** * Returns an identifier for a request. 
This is used to determine if a @@ -93,31 +103,21 @@ export function enqueue(req: Request) { if (!res.headersSent) { initStreaming(req); } - req.heartbeatInterval = setInterval(() => { - if (process.env.NODE_ENV === "production") { - if (!req.query.badSseParser) req.res!.write(": queue heartbeat\n\n"); - } else { - req.log.info(`Sending heartbeat to request in queue.`); - const partition = getPartitionForRequest(req); - const avgWait = Math.round(getEstimatedWaitTime(partition) / 1000); - const currentDuration = Math.round((Date.now() - req.startTime) / 1000); - const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`; - req.res!.write(buildFakeSse("heartbeat", debugMsg, req)); - } - }, 10000); + registerHeartbeat(req); + } else if (getProxyLoad() > LOAD_THRESHOLD) { + throw new Error( + "Due to heavy traffic on this proxy, you must enable streaming for your request." + ); } - // Register a handler to remove the request from the queue if the connection - // is aborted or closed before it is dequeued. const removeFromQueue = () => { req.log.info(`Removing aborted request from queue.`); const index = queue.indexOf(req); if (index !== -1) { queue.splice(index, 1); } - if (req.heartbeatInterval) { - clearInterval(req.heartbeatInterval); - } + if (req.heartbeatInterval) clearInterval(req.heartbeatInterval); + if (req.monitorInterval) clearInterval(req.monitorInterval); }; req.onAborted = removeFromQueue; req.res!.once("close", removeFromQueue); @@ -188,9 +188,8 @@ export function dequeue(partition: ModelFamily): Request | undefined { req.onAborted = undefined; } - if (req.heartbeatInterval) { - clearInterval(req.heartbeatInterval); - } + if (req.heartbeatInterval) clearInterval(req.heartbeatInterval); + if (req.monitorInterval) clearInterval(req.monitorInterval); // Track the time leaving the queue now, but don't add it to the wait times // yet because we don't know if the request will succeed or fail. 
We track @@ -385,6 +384,7 @@ export function createQueueMiddleware({ function killQueuedRequest(req: Request) { if (!req.res || req.res.writableEnded) { req.log.warn(`Attempted to terminate request that has already ended.`); + if (queue.includes(req)) queue.splice(queue.indexOf(req), 1); return; } const res = req.res; @@ -469,3 +469,84 @@ function removeProxyMiddlewareEventListeners(req: Request) { req.removeListener("error", reqOnError as any); } } + +export function registerHeartbeat(req: Request) { + const res = req.res!; + + let isBufferFull = false; + let bufferFullCount = 0; + req.heartbeatInterval = setInterval(() => { + if (isBufferFull) { + bufferFullCount++; + if (bufferFullCount >= 3) { + req.log.error("Heartbeat skipped too many times; killing connection."); + res.destroy(); + } else { + req.log.warn({ bufferFullCount }, "Heartbeat skipped; buffer is full."); + } + return; + } + + const data = getHeartbeatPayload(); + if (!res.write(data)) { + isBufferFull = true; + res.once("drain", () => (isBufferFull = false)); + } + }, HEARTBEAT_INTERVAL); + monitorHeartbeat(req); +} + +function monitorHeartbeat(req: Request) { + const res = req.res!; + + let lastBytesSent = 0; + req.monitorInterval = setInterval(() => { + const bytesSent = res.socket?.bytesWritten ?? 0; + const bytesSinceLast = bytesSent - lastBytesSent; + req.log.debug( + { + previousBytesSent: lastBytesSent, + currentBytesSent: bytesSent, + }, + "Heartbeat monitor check." + ); + lastBytesSent = bytesSent; + + const minBytes = Math.floor(getHeartbeatSize() / 2); + if (bytesSinceLast < minBytes) { + req.log.warn( + { minBytes, bytesSinceLast }, + "Queued request is not consuming heartbeat data fast enough or server is overloaded; killing connection."
+ ); + res.destroy(); + } + }, HEARTBEAT_INTERVAL * 2); +} + +/** Sends larger heartbeats when the queue is overloaded */ +function getHeartbeatSize() { + const load = getProxyLoad(); + + if (load <= LOAD_THRESHOLD) { + return MIN_HEARTBEAT_SIZE; + } else { + const excessLoad = load - LOAD_THRESHOLD; + const size = + MIN_HEARTBEAT_SIZE + Math.pow(excessLoad * PAYLOAD_SCALE_FACTOR, 2); + if (size > MAX_HEARTBEAT_SIZE) return MAX_HEARTBEAT_SIZE; + return size; + } +} + +function getHeartbeatPayload() { + const size = getHeartbeatSize(); + const data = + process.env.NODE_ENV === "production" + ? crypto.randomBytes(Math.ceil(size / 2)).toString("hex") + : `payload size: ${size}`; + return `: queue heartbeat ${data}\n\n`; +} + +function getProxyLoad() { + return Math.max(getUniqueIps(), queue.length); } diff --git a/src/shared/streaming.ts b/src/shared/streaming.ts index f3943b2..2fc59af 100644 --- a/src/shared/streaming.ts +++ b/src/shared/streaming.ts @@ -39,11 +39,7 @@ export function copySseResponseHeaders( * that the request is being proxied to. Used to send error messages to the * client in the middle of a streaming request. 
*/ -export function buildFakeSse( - type: string, - string: string, - req: Request -) { +export function buildFakeSse(type: string, string: string, req: Request) { let fakeEvent; const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`; @@ -54,7 +50,7 @@ export function buildFakeSse( object: "chat.completion.chunk", created: Date.now(), model: req.body?.model, - choices: [{ delta: { content }, index: 0, finish_reason: type }] + choices: [{ delta: { content }, index: 0, finish_reason: type }], }; break; case "openai-text": @@ -63,9 +59,9 @@ export function buildFakeSse( object: "text_completion", created: Date.now(), choices: [ - { text: content, index: 0, logprobs: null, finish_reason: type } + { text: content, index: 0, logprobs: null, finish_reason: type }, ], - model: req.body?.model + model: req.body?.model, }; break; case "anthropic": @@ -75,7 +71,7 @@ export function buildFakeSse( truncated: false, // I've never seen this be true stop: null, model: req.body?.model, - log_id: "proxy-req-" + req.id + log_id: "proxy-req-" + req.id, }; break; case "google-palm": @@ -86,10 +82,10 @@ export function buildFakeSse( } if (req.inboundApi === "anthropic") { - return [ - "event: completion", - `data: ${JSON.stringify(fakeEvent)}`, - ].join("\n") + "\n\n"; + return ( + ["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") + + "\n\n" + ); } return `data: ${JSON.stringify(fakeEvent)}\n\n`; diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts index bbafc84..546b55e 100644 --- a/src/types/custom.d.ts +++ b/src/types/custom.d.ts @@ -22,6 +22,7 @@ declare global { onAborted?: () => void; proceed: () => void; heartbeatInterval?: NodeJS.Timeout; + monitorInterval?: NodeJS.Timeout; promptTokens?: number; outputTokens?: number; tokenizerInfo: Record;