minor refactoring of response middleware handlers

This commit is contained in:
nai-degen 2024-03-17 22:20:39 -05:00
parent 59107af3d6
commit 7c64d9209e
9 changed files with 150 additions and 139 deletions

View File

@ -3,10 +3,10 @@ import http from "http";
import httpProxy from "http-proxy";
import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error";
import { HttpError } from "../../shared/errors";
import { assertNever } from "../../shared/utils";
import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
import { sendErrorToClient } from "./response/error-generator";
import { HttpError } from "../../shared/errors";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";

View File

@ -0,0 +1,76 @@
import util from "util";
import zlib from "zlib";
import { sendProxyError } from "../common";
import type { RawResponseBodyHandler } from "./index";
/**
 * Decompression routines keyed by the `content-encoding` value they handle.
 * Each entry is the promisified zlib function for that encoding.
 */
const DECODER_MAP = {
  gzip: util.promisify(zlib.gunzip),
  deflate: util.promisify(zlib.inflate),
  br: util.promisify(zlib.brotliDecompress),
};

/**
 * Type guard that reports whether `contentEncoding` is one of the keys of
 * `DECODER_MAP`.
 *
 * Only the map's own keys are accepted. The `in` operator would also match
 * names inherited from `Object.prototype` (e.g. "constructor", "toString"),
 * which would let a bogus header value through and hand back something that
 * is not a decoder.
 *
 * @param contentEncoding - Raw value of the `content-encoding` header.
 * @returns `true` when a decoder is registered for exactly that value.
 */
const isSupportedContentEncoding = (
  contentEncoding: string
): contentEncoding is keyof typeof DECODER_MAP => {
  return Object.prototype.hasOwnProperty.call(DECODER_MAP, contentEncoding);
};
/**
 * Handles the response from the upstream service and decodes the body if
 * necessary. If the response is JSON, it will be parsed and returned as an
 * object. Otherwise, it will be returned as a string. Does not handle streaming
 * responses.
 *
 * Failures detected while reading, decompressing or parsing the body are
 * logged, reported to the client with a 500 via `sendProxyError`, and then
 * surfaced by rejecting the returned promise with an `Error`.
 *
 * @param proxyRes - Upstream response whose body is buffered and decoded.
 * @param req - Proxied client request; must not be a streaming request.
 * @param res - Client response, used only to report failures.
 * @returns The decoded body: parsed JSON when the content-type includes
 * `application/json`, otherwise the body as a string.
 * @throws {Error} Called for a streaming request, upstream stream error,
 * unsupported or undecodable content-encoding, or invalid application/json
 * body.
 */
export const handleBlockingResponse: RawResponseBodyHandler = async (
  proxyRes,
  req,
  res
) => {
  if (req.isStreaming) {
    const err = new Error(
      "handleBlockingResponse called for a streaming request."
    );
    req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
    throw err;
  }

  return new Promise<string>((resolve, reject) => {
    const chunks: Buffer[] = [];

    // Logs the failure, reports it to the client, and rejects with a real
    // Error (not a bare string) so callers can rely on `.message`/`.stack`.
    const fail = (
      message: string,
      logDetails: Record<string, unknown>,
      clientDetails: Record<string, unknown> = {}
    ) => {
      req.log.warn({ ...logDetails, key: req.key?.hash }, message);
      sendProxyError(req, res, 500, "Internal Server Error", {
        error: message,
        ...clientDetails,
      });
      reject(new Error(message));
    };

    proxyRes.on("data", (chunk) => chunks.push(chunk));

    // Without this listener a failure of the upstream stream would leave the
    // promise pending forever, because "end" is never emitted.
    proxyRes.on("error", (err: Error) => {
      fail(`Proxy failed to read upstream response: ${err.message}`, {
        error: err.stack,
      });
    });

    proxyRes.on("end", async () => {
      let body = Buffer.concat(chunks);
      const contentEncoding = proxyRes.headers["content-encoding"];

      if (contentEncoding) {
        if (!isSupportedContentEncoding(contentEncoding)) {
          return fail(
            `Proxy received response with unsupported content-encoding: ${contentEncoding}`,
            { contentEncoding },
            { contentEncoding }
          );
        }
        // Decompression runs inside an async event listener: if it threw
        // here uncaught, the rejection would be unhandled and the outer
        // promise would never settle.
        try {
          const decoder = DECODER_MAP[contentEncoding];
          // Started failing type-checking after a TypeScript upgrade; the
          // call itself has never been a problem at runtime.
          // @ts-ignore
          body = await decoder(body);
        } catch (e) {
          const reason = e instanceof Error ? e.message : String(e);
          return fail(
            `Proxy received response with undecodable ${contentEncoding} body: ${reason}`,
            { contentEncoding, error: e instanceof Error ? e.stack : reason },
            { contentEncoding }
          );
        }
      }

      try {
        if (proxyRes.headers["content-type"]?.includes("application/json")) {
          const json = JSON.parse(body.toString());
          return resolve(json);
        }
        return resolve(body.toString());
      } catch (e) {
        const reason = e instanceof Error ? e.message : String(e);
        return fail(`Proxy received response with invalid JSON: ${reason}`, {
          error: e instanceof Error ? e.stack : reason,
        });
      }
    });
  });
};

View File

@ -3,20 +3,21 @@ import { pipeline, Readable, Transform } from "stream";
import StreamArray from "stream-json/streamers/StreamArray";
import { StringDecoder } from "string_decoder";
import { promisify } from "util";
import type { logger } from "../../../logger";
import { BadRequestError, RetryableError } from "../../../shared/errors";
import { APIFormat, keyPool } from "../../../shared/key-management";
import {
copySseResponseHeaders,
initializeSseStream,
} from "../../../shared/streaming";
import type { logger } from "../../../logger";
import { enqueue } from "../../queue";
import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
import { reenqueueRequest } from "../../queue";
import type { RawResponseBodyHandler } from ".";
import { handleBlockingResponse } from "./handle-blocking-response";
import { buildSpoofedSSE, sendErrorToClient } from "./error-generator";
import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
import { EventAggregator } from "./streaming/event-aggregator";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { buildSpoofedSSE, sendErrorToClient } from "./error-generator";
import { BadRequestError } from "../../../shared/errors";
const pipelineAsync = promisify(pipeline);
@ -50,7 +51,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
{ statusCode: proxyRes.statusCode, key: hash },
`Streaming request returned error status code. Falling back to non-streaming response handler.`
);
return decodeResponseBody(proxyRes, req, res);
return handleBlockingResponse(proxyRes, req, res);
}
req.log.debug({ headers: proxyRes.headers }, `Starting to proxy SSE stream.`);
@ -105,12 +106,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
} catch (err) {
if (err instanceof RetryableError) {
keyPool.markRateLimited(req.key!);
req.log.warn(
{ key: req.key!.hash, retryCount: req.retryCount },
`Re-enqueueing request due to retryable error during streaming response.`
);
req.retryCount++;
await enqueue(req);
await reenqueueRequest(req);
} else if (err instanceof BadRequestError) {
sendErrorToClient({
req,
@ -138,7 +134,17 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
res.write(`data: [DONE]\n\n`);
res.end();
}
throw err;
// At this point the response is closed. If the request resulted in any
// tokens being consumed (suggesting a mid-stream error), we will resolve
// and continue the middleware chain so tokens can be counted.
if (aggregator.hasEvents()) {
return aggregator.getFinalResponse();
} else {
// If there is nothing, then this was a completely failed prompt that
// will not have billed any tokens. Throw to stop the middleware chain.
throw err;
}
}
};

View File

@ -1,10 +1,8 @@
/* This file is fucking horrendous, sorry */
import { Request, Response } from "express";
import * as http from "http";
import util from "util";
import zlib from "zlib";
import { enqueue, trackWaitTime } from "../../queue";
import { HttpError } from "../../../shared/errors";
import { config } from "../../../config";
import { HttpError, RetryableError } from "../../../shared/errors";
import { keyPool } from "../../../shared/key-management";
import { getOpenAIModelFamily } from "../../../shared/models";
import { countTokens } from "../../../shared/tokenization";
@ -13,6 +11,7 @@ import {
incrementTokenCount,
} from "../../../shared/users/user-store";
import { assertNever } from "../../../shared/utils";
import { reenqueueRequest, trackWaitTime } from "../../queue";
import { refundLastAttempt } from "../../rate-limit";
import {
getCompletionFromBody,
@ -20,39 +19,22 @@ import {
isTextGenerationRequest,
sendProxyError,
} from "../common";
import { handleBlockingResponse } from "./handle-blocking-response";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
import { saveImage } from "./save-image";
import { config } from "../../../config";
/**
 * Decompression routines keyed by the `content-encoding` value they handle.
 * Each entry is the promisified zlib function for that encoding.
 */
const DECODER_MAP = {
  gzip: util.promisify(zlib.gunzip),
  deflate: util.promisify(zlib.inflate),
  br: util.promisify(zlib.brotliDecompress),
};

/**
 * Type guard that reports whether `contentEncoding` is one of the keys of
 * `DECODER_MAP`.
 *
 * Only the map's own keys are accepted; the `in` operator would also match
 * names inherited from `Object.prototype` such as "constructor".
 *
 * @param contentEncoding - Raw value of the `content-encoding` header.
 * @returns `true` when a decoder is registered for exactly that value.
 */
const isSupportedContentEncoding = (
  contentEncoding: string
): contentEncoding is keyof typeof DECODER_MAP => {
  return Object.prototype.hasOwnProperty.call(DECODER_MAP, contentEncoding);
};
/**
 * Error subclass tagged with the name "RetryableError" so that callers can
 * tell it apart from other failures (e.g. with `instanceof`).
 */
export class RetryableError extends Error {
  name = "RetryableError";

  /** @param message - Human-readable description of the failure. */
  constructor(message: string) {
    super(message);
  }
}
/**
* Signature shared by the handlers that consume the raw upstream response.
* A handler either decodes or streams the entire response body and then
* resolves with it.
* @param proxyRes - The upstream response whose body is consumed.
* @param req - The client request being proxied.
* @param res - The response being sent back to the client.
* @returns The response body as a string or parsed JSON object depending on the
* response's content-type.
*/
export type RawResponseBodyHandler = (
proxyRes: http.IncomingMessage,
req: Request,
res: Response
) => Promise<string | Record<string, any>>;
export type ProxyResHandlerWithBody = (
proxyRes: http.IncomingMessage,
req: Request,
@ -76,6 +58,10 @@ export type ProxyResMiddleware = ProxyResHandlerWithBody[];
* middleware from executing as it consumes the stream and forwards events to
* the client. Once the stream is closed, the finalized body will be attached
* to res.body and the remaining middleware will execute.
*
* @param apiMiddleware - Custom middleware to execute after the common response
* handlers. These *only* execute for non-streaming responses, so should be used
* to transform non-streaming responses into the desired format.
*/
export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
return async (
@ -83,30 +69,27 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
req: Request,
res: Response
) => {
const initialHandler = req.isStreaming
const initialHandler: RawResponseBodyHandler = req.isStreaming
? handleStreamedResponse
: decodeResponseBody;
: handleBlockingResponse;
let lastMiddleware = initialHandler.name;
try {
const body = await initialHandler(proxyRes, req, res);
const middlewareStack: ProxyResMiddleware = [];
if (req.isStreaming) {
// `handleStreamedResponse` writes to the response and ends it, so
// we can only execute middleware that doesn't write to the response.
// Handlers for streaming requests must never write to the response.
middlewareStack.push(
trackRateLimit,
trackKeyRateLimit,
countResponseTokens,
incrementUsage,
logPrompt
);
} else {
middlewareStack.push(
trackRateLimit,
addProxyInfo,
trackKeyRateLimit,
injectProxyInfo,
handleUpstreamErrors,
countResponseTokens,
incrementUsage,
@ -154,72 +137,6 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
};
};
/**
 * Puts a request back on the queue after a retryable failure.
 *
 * The log entry records the key hash and the retry count as it stood before
 * this attempt; the counter is incremented afterwards and the request is then
 * handed to `enqueue`.
 *
 * @param req - The request to queue again; its `retryCount` is mutated.
 */
async function reenqueueRequest(req: Request) {
  const logContext = { key: req.key?.hash, retryCount: req.retryCount };
  req.log.info(logContext, `Re-enqueueing request due to retryable error`);

  req.retryCount++;
  await enqueue(req);
}
/**
 * Handles the response from the upstream service and decodes the body if
 * necessary. If the response is JSON, it will be parsed and returned as an
 * object. Otherwise, it will be returned as a string.
 *
 * Failures detected while reading, decompressing or parsing the body are
 * logged, reported to the client with a 500 via `sendProxyError`, and then
 * surfaced by rejecting the returned promise with an `Error`.
 *
 * @param proxyRes - Upstream response whose body is buffered and decoded.
 * @param req - Proxied client request; must not be a streaming request.
 * @param res - Client response, used only to report failures.
 * @returns The decoded body: parsed JSON when the content-type includes
 * `application/json`, otherwise the body as a string.
 * @throws {Error} Called for a streaming request, upstream stream error,
 * unsupported or undecodable content-encoding, or invalid application/json
 * body.
 */
export const decodeResponseBody: RawResponseBodyHandler = async (
  proxyRes,
  req,
  res
) => {
  if (req.isStreaming) {
    const err = new Error("decodeResponseBody called for a streaming request.");
    req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
    throw err;
  }

  return new Promise<string>((resolve, reject) => {
    const chunks: Buffer[] = [];

    // Logs the failure, reports it to the client, and rejects with a real
    // Error (not a bare string) so callers can rely on `.message`/`.stack`.
    const fail = (
      message: string,
      logDetails: Record<string, unknown>,
      clientDetails: Record<string, unknown> = {}
    ) => {
      req.log.warn({ ...logDetails, key: req.key?.hash }, message);
      sendProxyError(req, res, 500, "Internal Server Error", {
        error: message,
        ...clientDetails,
      });
      reject(new Error(message));
    };

    proxyRes.on("data", (chunk) => chunks.push(chunk));

    // Without this listener a failure of the upstream stream would leave the
    // promise pending forever, because "end" is never emitted.
    proxyRes.on("error", (err: Error) => {
      fail(`Proxy failed to read upstream response: ${err.message}`, {
        error: err.stack,
      });
    });

    proxyRes.on("end", async () => {
      let body = Buffer.concat(chunks);
      const contentEncoding = proxyRes.headers["content-encoding"];

      if (contentEncoding) {
        if (!isSupportedContentEncoding(contentEncoding)) {
          return fail(
            `Proxy received response with unsupported content-encoding: ${contentEncoding}`,
            { contentEncoding },
            { contentEncoding }
          );
        }
        // Decompression runs inside an async event listener: if it threw
        // here uncaught, the rejection would be unhandled and the outer
        // promise would never settle.
        try {
          const decoder = DECODER_MAP[contentEncoding];
          // Started failing type-checking after a TypeScript upgrade; the
          // call itself has never been a problem at runtime.
          // @ts-ignore
          body = await decoder(body);
        } catch (e) {
          const reason = e instanceof Error ? e.message : String(e);
          return fail(
            `Proxy received response with undecodable ${contentEncoding} body: ${reason}`,
            { contentEncoding, error: e instanceof Error ? e.stack : reason },
            { contentEncoding }
          );
        }
      }

      try {
        if (proxyRes.headers["content-type"]?.includes("application/json")) {
          const json = JSON.parse(body.toString());
          return resolve(json);
        }
        return resolve(body.toString());
      } catch (e) {
        const reason = e instanceof Error ? e.message : String(e);
        return fail(`Proxy received response with invalid JSON: ${reason}`, {
          error: e instanceof Error ? e.stack : reason,
        });
      }
    });
  });
};
type ProxiedErrorPayload = {
error?: Record<string, any>;
message?: string;
@ -242,15 +159,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
) => {
const statusCode = proxyRes.statusCode || 500;
const statusMessage = proxyRes.statusMessage || "Internal Server Error";
if (statusCode < 400) {
return;
}
let errorPayload: ProxiedErrorPayload;
const tryAgainMessage = keyPool.available(req.body?.model)
? `There may be more keys available for this model; try again in a few seconds.`
: "There are no more keys available for this model.";
if (statusCode < 400) return;
try {
assertJsonResponse(body);
@ -303,7 +214,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (errorPayload.error?.code === "billing_hard_limit_reached") {
// For some reason, some models return this 400 error instead of the
// same 429 billing error that other models return.
await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
await handleOpenAIRateLimitError(req, errorPayload);
} else {
errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
}
@ -318,18 +229,18 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (statusCode === 401) {
// Key is invalid or was revoked
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
} else if (statusCode === 403) {
if (service === "anthropic") {
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
return;
}
switch (errorType) {
case "UnrecognizedClientException":
// Key is invalid.
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
break;
case "AccessDeniedException":
const isModelAccessError =
@ -349,7 +260,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (statusCode === 429) {
switch (service) {
case "openai":
await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
await handleOpenAIRateLimitError(req, errorPayload);
break;
case "anthropic":
await handleAnthropicRateLimitError(req, errorPayload);
@ -499,7 +410,6 @@ async function handleAwsRateLimitError(
async function handleOpenAIRateLimitError(
req: Request,
tryAgainMessage: string,
errorPayload: ProxiedErrorPayload
): Promise<Record<string, any>> {
const type = errorPayload.error?.type;
@ -508,17 +418,17 @@ async function handleOpenAIRateLimitError(
case "invalid_request_error": // this is the billing_hard_limit_reached error seen in some cases
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned key's quota has been exceeded. Please try again.`;
break;
case "access_terminated":
// Account banned (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. Please try again.`;
break;
case "billing_not_active":
// Key valid but account billing is delinquent
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. Please try again.`;
break;
case "requests":
case "tokens":
@ -684,7 +594,7 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
}
};
const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
const trackKeyRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
keyPool.updateRateLimits(req.key!, proxyRes.headers);
};
@ -714,7 +624,7 @@ const copyHttpHeaders: ProxyResHandlerWithBody = async (
* or transformed.
* Only used for non-streaming requests.
*/
const addProxyInfo: ProxyResHandlerWithBody = async (
const injectProxyInfo: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,

View File

@ -67,6 +67,10 @@ export class EventAggregator {
assertNever(this.format);
}
}
/** Returns true when `this.events` holds at least one entry. */
hasEvents(): boolean {
  const eventCount = this.events.length;
  return eventCount !== 0;
}
}
function eventIsOpenAIEvent(

View File

@ -2,9 +2,8 @@ import pino from "pino";
import { Transform, TransformOptions } from "stream";
import { Message } from "@smithy/eventstream-codec";
import { APIFormat } from "../../../../shared/key-management";
import { RetryableError } from "../index";
import { buildSpoofedSSE } from "../error-generator";
import { BadRequestError } from "../../../../shared/errors";
import { BadRequestError, RetryableError } from "../../../../shared/errors";
type SSEStreamAdapterOptions = TransformOptions & {
contentType?: string;

View File

@ -70,7 +70,7 @@ export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
}
const handleModelRequest: RequestHandler = (_req, res) => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60){
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return res.status(200).json(modelsCache);
}
const result = generateModelList();

View File

@ -12,7 +12,7 @@
*/
import crypto from "crypto";
import type { Handler, Request } from "express";
import { Handler, Request } from "express";
import { BadRequestError, TooManyRequestsError } from "../shared/errors";
import { keyPool } from "../shared/key-management";
import {
@ -67,7 +67,7 @@ const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
export async function enqueue(req: Request) {
async function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
let isGuest = req.user?.token === undefined;
@ -136,6 +136,15 @@ export async function enqueue(req: Request) {
}
}
/**
 * Puts a request back on the queue after a retryable failure.
 *
 * The log entry records the key hash and the retry count as it stood before
 * this attempt; the counter is incremented afterwards and the request is then
 * handed to `enqueue`.
 *
 * @param req - The request to queue again; its `retryCount` is mutated.
 */
export async function reenqueueRequest(req: Request) {
  const logContext = { key: req.key?.hash, retryCount: req.retryCount };
  req.log.info(logContext, `Re-enqueueing request due to retryable error`);

  req.retryCount++;
  await enqueue(req);
}
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue
.filter((req) => getModelFamilyForRequest(req) === partition)

View File

@ -34,3 +34,10 @@ export class TooManyRequestsError extends HttpError {
super(429, message);
}
}
/**
 * Error subclass tagged with the name "RetryableError" so that callers can
 * tell it apart from other failures (e.g. with `instanceof`).
 */
export class RetryableError extends Error {
  name = "RetryableError";

  /** @param message - Human-readable description of the failure. */
  constructor(message: string) {
    super(message);
  }
}