improves clarity of errors sent back to streaming clients

nai-degen 2023-10-03 19:45:15 -05:00
parent ba0b20617e
commit 5033d00444
6 changed files with 110 additions and 119 deletions

View File

@@ -1,6 +1,8 @@
 import { Request, Response } from "express";
 import httpProxy from "http-proxy";
 import { ZodError } from "zod";
+import { generateErrorMessage } from "zod-error";
+import { buildFakeSse } from "../../shared/streaming";
 import { assertNever } from "../../shared/utils";
 import { QuotaExceededError } from "./request/apply-quota-limits";
@@ -44,16 +46,9 @@ export function writeErrorResponse(
     res.headersSent ||
     String(res.getHeader("content-type")).startsWith("text/event-stream")
   ) {
-    const errorContent =
-      statusCode === 403
-        ? JSON.stringify(errorPayload)
-        : JSON.stringify(errorPayload, null, 2);
-    const msg = buildFakeSseMessage(
-      `${errorSource} error (${statusCode})`,
-      errorContent,
-      req
-    );
+    const errorTitle = `${errorSource} error (${statusCode})`;
+    const errorContent = JSON.stringify(errorPayload, null, 2);
+    const msg = buildFakeSse(errorTitle, errorContent, req);
     res.write(msg);
     res.write(`data: [DONE]\n\n`);
     res.end();
@@ -66,110 +61,98 @@ export function writeErrorResponse(
 }
 
 export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => {
-  req.log.error({ err }, `Error during proxy request middleware`);
-  handleInternalError(err, req as Request, res as Response);
+  req.log.error(err, `Error during http-proxy-middleware request`);
+  classifyErrorAndSend(err, req as Request, res as Response);
 };
 
-export const handleInternalError = (
+export const classifyErrorAndSend = (
   err: Error,
   req: Request,
   res: Response
 ) => {
   try {
-    if (err instanceof ZodError) {
-      writeErrorResponse(req, res, 400, {
-        error: {
-          type: "proxy_validation_error",
-          proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`,
-          issues: err.issues,
-          stack: err.stack,
-          message: err.message,
-        },
-      });
-    } else if (err.name === "ForbiddenError") {
-      // Spoofs a vaguely threatening OpenAI error message. Only invoked by the
-      // block-zoomers rewriter to scare off tiktokers.
-      writeErrorResponse(req, res, 403, {
-        error: {
-          type: "organization_account_disabled",
-          code: "policy_violation",
-          param: null,
-          message: err.message,
-        },
-      });
-    } else if (err instanceof QuotaExceededError) {
-      writeErrorResponse(req, res, 429, {
-        error: {
-          type: "proxy_quota_exceeded",
-          code: "quota_exceeded",
-          message: `You've exceeded your token quota for this model type.`,
-          info: err.quotaInfo,
-          stack: err.stack,
-        },
-      });
-    } else {
-      writeErrorResponse(req, res, 500, {
-        error: {
-          type: "proxy_internal_error",
-          proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
-          message: err.message,
-          stack: err.stack,
-        },
-      });
-    }
-  } catch (e) {
-    req.log.error(
-      { error: e },
-      `Error writing error response headers, giving up.`
-    );
+    const { status, userMessage, ...errorDetails } = classifyError(err);
+    writeErrorResponse(req, res, status, {
+      error: { message: userMessage, ...errorDetails },
+    });
+  } catch (error) {
+    req.log.error(error, `Error writing error response headers, giving up.`);
   }
 };
 
-export function buildFakeSseMessage(
-  type: string,
-  string: string,
-  req: Request
-) {
-  let fakeEvent;
-  const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
-
-  switch (req.inboundApi) {
-    case "openai":
-      fakeEvent = {
-        id: "chatcmpl-" + req.id,
-        object: "chat.completion.chunk",
-        created: Date.now(),
-        model: req.body?.model,
-        choices: [{ delta: { content }, index: 0, finish_reason: type }],
-      };
-      break;
-    case "openai-text":
-      fakeEvent = {
-        id: "cmpl-" + req.id,
-        object: "text_completion",
-        created: Date.now(),
-        choices: [
-          { text: content, index: 0, logprobs: null, finish_reason: type },
-        ],
-        model: req.body?.model,
-      };
-      break;
-    case "anthropic":
-      fakeEvent = {
-        completion: content,
-        stop_reason: type,
-        truncated: false, // I've never seen this be true
-        stop: null,
-        model: req.body?.model,
-        log_id: "proxy-req-" + req.id,
-      };
-      break;
-    case "google-palm":
-      throw new Error("PaLM not supported as an inbound API format");
-    default:
-      assertNever(req.inboundApi);
-  }
-  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
-}
+function classifyError(err: Error): {
+  /** HTTP status code returned to the client. */
+  status: number;
+  /** Message displayed to the user. */
+  userMessage: string;
+  /** Short error type, e.g. "proxy_validation_error". */
+  type: string;
+} & Record<string, any> {
+  const defaultError = {
+    status: 500,
+    userMessage: `Reverse proxy encountered an unexpected error. (${err.message})`,
+    type: "proxy_internal_error",
+    stack: err.stack,
+  };
+
+  switch (err.constructor.name) {
+    case "ZodError":
+      const userMessage = generateErrorMessage((err as ZodError).issues, {
+        prefix: "Request validation failed. ",
+        path: { enabled: true, label: null, type: "breadcrumbs" },
+        code: { enabled: false },
+        maxErrors: 3,
+        transform: ({ issue, ...rest }) => {
+          return `At '${rest.pathComponent}', ${issue.message}`;
+        },
+      });
+      return { status: 400, userMessage, type: "proxy_validation_error" };
+    case "ForbiddenError":
+      // Mimics a ban notice from OpenAI, thrown when blockZoomerOrigins blocks
+      // a request.
+      return {
+        status: 403,
+        userMessage: `Your account has been disabled for violating our terms of service.`,
+        type: "organization_account_disabled",
+        code: "policy_violation",
+      };
+    case "QuotaExceededError":
+      return {
+        status: 429,
+        userMessage: `You've exceeded your token quota for this model type.`,
+        type: "proxy_quota_exceeded",
+        info: (err as QuotaExceededError).quotaInfo,
+      };
+    case "Error":
+      if ("code" in err) {
+        switch (err.code) {
+          case "ENOTFOUND":
+            return {
+              status: 502,
+              userMessage: `Reverse proxy encountered a DNS error while trying to connect to the upstream service.`,
+              type: "proxy_network_error",
+              code: err.code,
+            };
+          case "ECONNREFUSED":
+            return {
+              status: 502,
+              userMessage: `Reverse proxy couldn't connect to the upstream service.`,
+              type: "proxy_network_error",
+              code: err.code,
+            };
+          case "ECONNRESET":
+            return {
+              status: 504,
+              userMessage: `Reverse proxy timed out while waiting for the upstream service to respond.`,
+              type: "proxy_network_error",
+              code: err.code,
+            };
        }
+      }
+      return defaultError;
+    default:
+      return defaultError;
+  }
+}
 
 export function getCompletionFromBody(req: Request, body: Record<string, any>) {
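
For reference, streaming clients now receive errors as a fake completion event. buildFakeSse itself lives in shared/streaming and is not shown in this commit; assuming it keeps the shape of the removed buildFakeSseMessage, the OpenAI chat case would put something like the following on the wire. This is a rough sketch, with all identifiers and values illustrative:

import { randomUUID } from "crypto";

// Approximates the removed buildFakeSseMessage for the "openai" inbound API.
// The real buildFakeSse takes the Request and derives the id/model from it.
function fakeOpenAiErrorEvent(title: string, message: string): string {
  // The error is wrapped in a fenced block so chat frontends render it verbatim.
  const content = "```\n" + `[${title}: ${message}]` + "\n```\n";
  const event = {
    id: "chatcmpl-" + randomUUID(),
    object: "chat.completion.chunk",
    created: Date.now(),
    model: "gpt-4", // normally req.body?.model
    choices: [{ delta: { content }, index: 0, finish_reason: title }],
  };
  return `data: ${JSON.stringify(event)}\n\n`;
}

// writeErrorResponse writes the event, then terminates the stream:
process.stdout.write(fakeOpenAiErrorEvent("openai error (502)", "..."));
process.stdout.write("data: [DONE]\n\n");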

View File

@@ -1,5 +1,6 @@
 import { RequestHandler } from "express";
-import { handleInternalError } from "../common";
+import { initializeSseStream } from "../../../shared/streaming";
+import { classifyErrorAndSend } from "../common";
 import {
   RequestPreprocessor,
   validateContextSize,
@@ -66,6 +67,13 @@ async function executePreprocessors(
     next();
   } catch (error) {
     req.log.error(error, "Error while executing request preprocessor");
-    handleInternalError(error as Error, req, res);
+    // If the request has opted into streaming, the client probably won't
+    // handle a non-eventstream response, but we haven't initialized the SSE
+    // stream yet, as that is typically done later by the request queue. We'll
+    // do that here and then call classifyErrorAndSend to use the streaming
+    // error handler.
+    initializeSseStream(res);
+    classifyErrorAndSend(error as Error, req, res);
   }
 }
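
Note that initializeSseStream's implementation is not part of this diff. A minimal version would set the eventstream headers and flush them, which is also what makes the writeErrorResponse branch above detect an in-progress stream. A sketch under that assumption; the real helper may do more:

import type { Response } from "express";

// Illustrative stand-in for shared/streaming's initializeSseStream.
function initializeSseStreamSketch(res: Response) {
  res.statusCode = 200;
  res.setHeader("Content-Type", "text/event-stream; charset=utf-8");
  res.setHeader("Cache-Control", "no-cache");
  res.setHeader("Connection", "keep-alive");
  res.flushHeaders(); // commit headers so later writes go out as SSE events
}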

View File

@@ -67,12 +67,14 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
   }
   const finalMax = Math.min(proxyMax, modelMax);
-  z.number()
-    .int()
-    .max(finalMax, {
-      message: `Your request exceeds the context size limit for this model or proxy. (max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`,
-    })
-    .parse(contextTokens);
+  z.object({
+    tokens: z
+      .number()
+      .int()
+      .max(finalMax, {
+        message: `Your request exceeds the context size limit. (max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`,
+      }),
+  }).parse({ tokens: contextTokens });
   req.log.debug(
     { promptTokens, outputTokens, contextTokens, modelMax, proxyMax },
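
Wrapping the value in z.object({ tokens: ... }) gives the resulting ZodError issue a non-empty path, so the breadcrumb formatting that classifyError applies via generateErrorMessage has a path component to report. A self-contained sketch of the interaction, with the limit and token counts hypothetical:

import { z, ZodError } from "zod";
import { generateErrorMessage } from "zod-error";

const finalMax = 4096; // hypothetical proxy/model limit
const contextTokens = 5000; // hypothetical request size

try {
  z.object({
    tokens: z.number().int().max(finalMax, {
      message: `Your request exceeds the context size limit. (max: ${finalMax} tokens)`,
    }),
  }).parse({ tokens: contextTokens });
} catch (err) {
  if (err instanceof ZodError) {
    // Same options classifyError uses; pathComponent is now "tokens"
    // instead of an empty string.
    const msg = generateErrorMessage(err.issues, {
      prefix: "Request validation failed. ",
      path: { enabled: true, label: null, type: "breadcrumbs" },
      code: { enabled: false },
      maxErrors: 3,
      transform: ({ issue, ...rest }) =>
        `At '${rest.pathComponent}', ${issue.message}`,
    });
    console.log(msg);
  }
}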

View File

@@ -1,14 +1,14 @@
 import { pipeline } from "stream";
 import { promisify } from "util";
-import { buildFakeSseMessage } from "../common";
+import {
+  buildFakeSse,
+  copySseResponseHeaders,
+  initializeSseStream,
+} from "../../../shared/streaming";
 import { decodeResponseBody, RawResponseBodyHandler } from ".";
 import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
 import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
 import { EventAggregator } from "./streaming/event-aggregator";
-import {
-  copySseResponseHeaders,
-  initializeSseStream,
-} from "../../../shared/streaming";
 
 const pipelineAsync = promisify(pipeline);
@@ -79,7 +79,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
     res.end();
     return aggregator.getFinalResponse();
   } catch (err) {
-    const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
+    const errorEvent = buildFakeSse("stream-error", err.message, req);
     res.write(`${errorEvent}data: [DONE]\n\n`);
     res.end();
     throw err;

View File

@@ -28,4 +28,3 @@ export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
 export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
 export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
 export { mergeEventsForAnthropic } from "./aggregators/anthropic";

View File

@@ -23,11 +23,10 @@ import {
   getOpenAIModelFamily,
   ModelFamily,
 } from "../shared/models";
-import { initializeSseStream } from "../shared/streaming";
+import { buildFakeSse, initializeSseStream } from "../shared/streaming";
 import { assertNever } from "../shared/utils";
 import { logger } from "../logger";
 import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
-import { buildFakeSseMessage } from "./middleware/common";
 
 const queue: Request[] = [];
 const log = logger.child({ module: "request-queue" });
@@ -109,7 +108,7 @@ export function enqueue(req: Request) {
       const avgWait = Math.round(getEstimatedWaitTime(partition) / 1000);
       const currentDuration = Math.round((Date.now() - req.startTime) / 1000);
       const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`;
-      req.res!.write(buildFakeSseMessage("heartbeat", debugMsg, req));
+      req.res!.write(buildFakeSse("heartbeat", debugMsg, req));
     }
   }, 10000);
 }
@@ -337,7 +336,7 @@ function killQueuedRequest(req: Request) {
   try {
     const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. The queue is currently ${queue.length} requests long.`;
     if (res.headersSent) {
-      const fakeErrorEvent = buildFakeSseMessage(
+      const fakeErrorEvent = buildFakeSse(
        "proxy queue error",
        message,
        req
@@ -363,7 +362,7 @@ function initStreaming(req: Request) {
     return;
   }
 
-  res.write(": joining queue\n\n");
+  res.write(`: joining queue at position ${queue.length}\n\n`);
 }
 
 /**
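
The ": joining queue" write works because SSE treats lines beginning with a colon as comments: spec-compliant clients discard them, so the proxy can acknowledge the queue position (and keep intermediaries from timing out the connection) without emitting a real event. A minimal sketch, with the position value illustrative:

import type { Response } from "express";

// SSE comment lines are ignored by EventSource-style parsers, making them
// safe to use as acknowledgements or keep-alives while a request is queued.
function announceQueuePosition(res: Response, position: number) {
  res.write(`: joining queue at position ${position}\n\n`);
}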