Scale SSE heartbeat size with traffic (khanon/oai-reverse-proxy!53)
This commit is contained in:
parent
6acdf35914
commit
6aa6bebf08
|
@ -11,6 +11,7 @@
|
||||||
* back in the queue and it will be retried later using the same closure.
|
* back in the queue and it will be retried later using the same closure.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import crypto from "crypto";
|
||||||
import type { Handler, Request } from "express";
|
import type { Handler, Request } from "express";
|
||||||
import { keyPool } from "../shared/key-management";
|
import { keyPool } from "../shared/key-management";
|
||||||
import {
|
import {
|
||||||
|
@ -23,7 +24,7 @@ import {
|
||||||
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
|
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
|
||||||
import { assertNever } from "../shared/utils";
|
import { assertNever } from "../shared/utils";
|
||||||
import { logger } from "../logger";
|
import { logger } from "../logger";
|
||||||
import { SHARED_IP_ADDRESSES } from "./rate-limit";
|
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
|
||||||
import { RequestPreprocessor } from "./middleware/request";
|
import { RequestPreprocessor } from "./middleware/request";
|
||||||
|
|
||||||
const queue: Request[] = [];
|
const queue: Request[] = [];
|
||||||
|
@ -33,6 +34,15 @@ const log = logger.child({ module: "request-queue" });
|
||||||
const AGNAI_CONCURRENCY_LIMIT = 5;
|
const AGNAI_CONCURRENCY_LIMIT = 5;
|
||||||
/** Maximum number of queue slots for individual users. */
|
/** Maximum number of queue slots for individual users. */
|
||||||
const USER_CONCURRENCY_LIMIT = 1;
|
const USER_CONCURRENCY_LIMIT = 1;
|
||||||
|
/**
 * Parses a numeric setting from its raw environment string.
 *
 * Returns `fallback` when the variable is unset, empty, not a finite number,
 * or below `min`. Without this guard a malformed value becomes NaN, and NaN
 * passed to `setInterval` is coerced to a 1ms delay.
 *
 * @param raw The raw environment value, if any.
 * @param fallback Value to use when `raw` is missing or invalid.
 * @param parse Converts the raw string to a number.
 * @param min Smallest acceptable parsed value (inclusive).
 */
function parseEnvNumber(
  raw: string | undefined,
  fallback: number,
  parse: (value: string) => number,
  min = 0
): number {
  if (raw === undefined) return fallback;
  const parsed = parse(raw);
  return Number.isFinite(parsed) && parsed >= min ? parsed : fallback;
}

/** Heartbeat payload size (bytes) used while load is at or below the threshold. */
const MIN_HEARTBEAT_SIZE = 512;
/** Upper bound on heartbeat payload size, in bytes (configured in KB). */
const MAX_HEARTBEAT_SIZE =
  1024 *
  parseEnvNumber(process.env.MAX_HEARTBEAT_SIZE_KB, 1024, (v) =>
    parseInt(v, 10)
  );
/** Delay between heartbeats, in milliseconds (configured in seconds, >= 1). */
const HEARTBEAT_INTERVAL =
  1000 *
  parseEnvNumber(
    process.env.HEARTBEAT_INTERVAL_SEC,
    5,
    (v) => parseInt(v, 10),
    1
  );
/** Proxy load above which heartbeat payloads start to grow. */
const LOAD_THRESHOLD = parseEnvNumber(
  process.env.LOAD_THRESHOLD,
  50,
  parseFloat
);
/** Multiplier applied to the excess load before it is squared. */
const PAYLOAD_SCALE_FACTOR = parseEnvNumber(
  process.env.PAYLOAD_SCALE_FACTOR,
  6,
  parseFloat
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns an identifier for a request. This is used to determine if a
|
* Returns an identifier for a request. This is used to determine if a
|
||||||
|
@ -93,31 +103,21 @@ export function enqueue(req: Request) {
|
||||||
if (!res.headersSent) {
|
if (!res.headersSent) {
|
||||||
initStreaming(req);
|
initStreaming(req);
|
||||||
}
|
}
|
||||||
req.heartbeatInterval = setInterval(() => {
|
registerHeartbeat(req);
|
||||||
if (process.env.NODE_ENV === "production") {
|
} else if (getProxyLoad() > LOAD_THRESHOLD) {
|
||||||
if (!req.query.badSseParser) req.res!.write(": queue heartbeat\n\n");
|
throw new Error(
|
||||||
} else {
|
"Due to heavy traffic on this proxy, you must enable streaming for your request."
|
||||||
req.log.info(`Sending heartbeat to request in queue.`);
|
);
|
||||||
const partition = getPartitionForRequest(req);
|
|
||||||
const avgWait = Math.round(getEstimatedWaitTime(partition) / 1000);
|
|
||||||
const currentDuration = Math.round((Date.now() - req.startTime) / 1000);
|
|
||||||
const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`;
|
|
||||||
req.res!.write(buildFakeSse("heartbeat", debugMsg, req));
|
|
||||||
}
|
|
||||||
}, 10000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register a handler to remove the request from the queue if the connection
|
|
||||||
// is aborted or closed before it is dequeued.
|
|
||||||
const removeFromQueue = () => {
|
const removeFromQueue = () => {
|
||||||
req.log.info(`Removing aborted request from queue.`);
|
req.log.info(`Removing aborted request from queue.`);
|
||||||
const index = queue.indexOf(req);
|
const index = queue.indexOf(req);
|
||||||
if (index !== -1) {
|
if (index !== -1) {
|
||||||
queue.splice(index, 1);
|
queue.splice(index, 1);
|
||||||
}
|
}
|
||||||
if (req.heartbeatInterval) {
|
if (req.heartbeatInterval) clearInterval(req.heartbeatInterval);
|
||||||
clearInterval(req.heartbeatInterval);
|
if (req.monitorInterval) clearInterval(req.monitorInterval);
|
||||||
}
|
|
||||||
};
|
};
|
||||||
req.onAborted = removeFromQueue;
|
req.onAborted = removeFromQueue;
|
||||||
req.res!.once("close", removeFromQueue);
|
req.res!.once("close", removeFromQueue);
|
||||||
|
@ -188,9 +188,8 @@ export function dequeue(partition: ModelFamily): Request | undefined {
|
||||||
req.onAborted = undefined;
|
req.onAborted = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (req.heartbeatInterval) {
|
if (req.heartbeatInterval) clearInterval(req.heartbeatInterval);
|
||||||
clearInterval(req.heartbeatInterval);
|
if (req.monitorInterval) clearInterval(req.monitorInterval);
|
||||||
}
|
|
||||||
|
|
||||||
// Track the time leaving the queue now, but don't add it to the wait times
|
// Track the time leaving the queue now, but don't add it to the wait times
|
||||||
// yet because we don't know if the request will succeed or fail. We track
|
// yet because we don't know if the request will succeed or fail. We track
|
||||||
|
@ -385,6 +384,7 @@ export function createQueueMiddleware({
|
||||||
function killQueuedRequest(req: Request) {
|
function killQueuedRequest(req: Request) {
|
||||||
if (!req.res || req.res.writableEnded) {
|
if (!req.res || req.res.writableEnded) {
|
||||||
req.log.warn(`Attempted to terminate request that has already ended.`);
|
req.log.warn(`Attempted to terminate request that has already ended.`);
|
||||||
|
queue.splice(queue.indexOf(req), 1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const res = req.res;
|
const res = req.res;
|
||||||
|
@ -469,3 +469,84 @@ function removeProxyMiddlewareEventListeners(req: Request) {
|
||||||
req.removeListener("error", reqOnError as any);
|
req.removeListener("error", reqOnError as any);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Starts the periodic heartbeat for a queued request and its companion
 * monitor.
 *
 * Every `HEARTBEAT_INTERVAL` ms a heartbeat payload is written to the
 * response. If a write reports backpressure (`write` returns false), further
 * heartbeats are skipped until the response emits "drain". Skipped ticks are
 * counted cumulatively; once the count reaches three the response is
 * destroyed.
 *
 * The timer handle is stored on `req.heartbeatInterval`.
 */
export function registerHeartbeat(req: Request) {
  const res = req.res!;

  let awaitingDrain = false;
  let skippedBeats = 0;

  const sendHeartbeat = () => {
    if (!awaitingDrain) {
      const payload = getHeartbeatPayload();
      const flushed = res.write(payload);
      if (flushed) return;

      // Backpressure: hold off until the socket buffer has emptied.
      awaitingDrain = true;
      res.once("drain", () => {
        awaitingDrain = false;
      });
      return;
    }

    skippedBeats += 1;
    if (skippedBeats < 3) {
      req.log.warn(
        { bufferFullCount: skippedBeats },
        "Heartbeat skipped; buffer is full."
      );
      return;
    }

    req.log.error("Heartbeat skipped too many times; killing connection.");
    res.destroy();
  };

  req.heartbeatInterval = setInterval(sendHeartbeat, HEARTBEAT_INTERVAL);
  monitorHeartbeat(req);
}
|
||||||
|
|
||||||
|
/**
 * Periodically checks that a queued request's socket is still accepting
 * heartbeat data.
 *
 * Every two heartbeat intervals, the socket's `bytesWritten` counter is
 * compared with the previous reading. If it grew by less than half the
 * current heartbeat size, the response is destroyed.
 *
 * The timer handle is stored on `req.monitorInterval`.
 */
function monitorHeartbeat(req: Request) {
  const res = req.res!;

  let lastBytesSent = 0;
  const monitor = setInterval(() => {
    // A missing socket is treated as zero bytes written.
    const bytesSent = res.socket?.bytesWritten ?? 0;
    const bytesSinceLast = bytesSent - lastBytesSent;
    req.log.debug(
      {
        previousBytesSent: lastBytesSent,
        currentBytesSent: bytesSent,
      },
      "Heartbeat monitor check."
    );
    lastBytesSent = bytesSent;

    const minBytes = Math.floor(getHeartbeatSize() / 2);
    if (bytesSinceLast < minBytes) {
      req.log.warn(
        { minBytes, bytesSinceLast },
        "Queued request is not processing heartbeats fast enough or server is overloaded; killing connection."
      );
      res.destroy();
      // Nothing is left to monitor once the response is destroyed; stop here
      // so the check does not keep firing against a dead socket.
      clearInterval(monitor);
    }
  }, HEARTBEAT_INTERVAL * 2);
  req.monitorInterval = monitor;
}
|
||||||
|
|
||||||
|
/**
 * Sends larger heartbeats when the queue is overloaded.
 *
 * At or below `LOAD_THRESHOLD` the size is `MIN_HEARTBEAT_SIZE`. Above it the
 * size grows with the square of the scaled excess load, capped at
 * `MAX_HEARTBEAT_SIZE`.
 *
 * @returns The heartbeat payload size in bytes, floored to a whole number.
 */
function getHeartbeatSize() {
  const load = getProxyLoad();
  if (load <= LOAD_THRESHOLD) return MIN_HEARTBEAT_SIZE;

  const excessLoad = load - LOAD_THRESHOLD;
  // Floor the result: it is used as a byte count, and the threshold and
  // scale factor may be fractional.
  const size = Math.floor(
    MIN_HEARTBEAT_SIZE + Math.pow(excessLoad * PAYLOAD_SCALE_FACTOR, 2)
  );

  // NaN (from an invalid threshold or scale factor) would slip past the cap
  // below, so fall back to the minimum size.
  if (Number.isNaN(size)) return MIN_HEARTBEAT_SIZE;

  return size > MAX_HEARTBEAT_SIZE ? MAX_HEARTBEAT_SIZE : size;
}
|
||||||
|
|
||||||
|
/**
 * Builds the SSE comment line sent as a queue heartbeat.
 *
 * In production the comment is padded with `size` characters of random
 * base64 text; otherwise it carries a short note stating the size.
 *
 * @returns A complete SSE comment terminated by a blank line.
 */
function getHeartbeatPayload() {
  const size = getHeartbeatSize();
  const data =
    process.env.NODE_ENV === "production"
      ? // Base64 keeps the padding free of CR/LF, which would end the SSE
        // comment line early, and gives one byte per character so the
        // payload length matches `size`. 3 random bytes yield 4 characters.
        crypto
          .randomBytes(Math.ceil((size * 3) / 4))
          .toString("base64")
          .slice(0, size)
      : `payload size: ${size}`;
  return `: queue heartbeat ${data}\n\n`;
}
|
||||||
|
|
||||||
|
/**
 * Returns the current proxy load: the larger of the unique-IP count and the
 * number of requests in the queue.
 */
function getProxyLoad() {
  const uniqueIps = getUniqueIps();
  const queued = queue.length;
  return Math.max(uniqueIps, queued);
}
|
||||||
|
|
|
@ -39,11 +39,7 @@ export function copySseResponseHeaders(
|
||||||
* that the request is being proxied to. Used to send error messages to the
|
* that the request is being proxied to. Used to send error messages to the
|
||||||
* client in the middle of a streaming request.
|
* client in the middle of a streaming request.
|
||||||
*/
|
*/
|
||||||
export function buildFakeSse(
|
export function buildFakeSse(type: string, string: string, req: Request) {
|
||||||
type: string,
|
|
||||||
string: string,
|
|
||||||
req: Request
|
|
||||||
) {
|
|
||||||
let fakeEvent;
|
let fakeEvent;
|
||||||
const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
|
const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
|
||||||
|
|
||||||
|
@ -54,7 +50,7 @@ export function buildFakeSse(
|
||||||
object: "chat.completion.chunk",
|
object: "chat.completion.chunk",
|
||||||
created: Date.now(),
|
created: Date.now(),
|
||||||
model: req.body?.model,
|
model: req.body?.model,
|
||||||
choices: [{ delta: { content }, index: 0, finish_reason: type }]
|
choices: [{ delta: { content }, index: 0, finish_reason: type }],
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case "openai-text":
|
case "openai-text":
|
||||||
|
@ -63,9 +59,9 @@ export function buildFakeSse(
|
||||||
object: "text_completion",
|
object: "text_completion",
|
||||||
created: Date.now(),
|
created: Date.now(),
|
||||||
choices: [
|
choices: [
|
||||||
{ text: content, index: 0, logprobs: null, finish_reason: type }
|
{ text: content, index: 0, logprobs: null, finish_reason: type },
|
||||||
],
|
],
|
||||||
model: req.body?.model
|
model: req.body?.model,
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
|
@ -75,7 +71,7 @@ export function buildFakeSse(
|
||||||
truncated: false, // I've never seen this be true
|
truncated: false, // I've never seen this be true
|
||||||
stop: null,
|
stop: null,
|
||||||
model: req.body?.model,
|
model: req.body?.model,
|
||||||
log_id: "proxy-req-" + req.id
|
log_id: "proxy-req-" + req.id,
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case "google-palm":
|
case "google-palm":
|
||||||
|
@ -86,10 +82,10 @@ export function buildFakeSse(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (req.inboundApi === "anthropic") {
|
if (req.inboundApi === "anthropic") {
|
||||||
return [
|
return (
|
||||||
"event: completion",
|
["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") +
|
||||||
`data: ${JSON.stringify(fakeEvent)}`,
|
"\n\n"
|
||||||
].join("\n") + "\n\n";
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
|
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
|
||||||
|
|
|
@ -22,6 +22,7 @@ declare global {
|
||||||
onAborted?: () => void;
|
onAborted?: () => void;
|
||||||
proceed: () => void;
|
proceed: () => void;
|
||||||
heartbeatInterval?: NodeJS.Timeout;
|
heartbeatInterval?: NodeJS.Timeout;
|
||||||
|
monitorInterval?: NodeJS.Timeout;
|
||||||
promptTokens?: number;
|
promptTokens?: number;
|
||||||
outputTokens?: number;
|
outputTokens?: number;
|
||||||
tokenizerInfo: Record<string, any>;
|
tokenizerInfo: Record<string, any>;
|
||||||
|
|
Loading…
Reference in New Issue