improves clarity of errors sent back to streaming clients

This commit is contained in:
nai-degen 2023-10-03 19:45:15 -05:00
parent ba0b20617e
commit 5033d00444
6 changed files with 110 additions and 119 deletions

View File

@ -1,6 +1,8 @@
import { Request, Response } from "express"; import { Request, Response } from "express";
import httpProxy from "http-proxy"; import httpProxy from "http-proxy";
import { ZodError } from "zod"; import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error";
import { buildFakeSse } from "../../shared/streaming";
import { assertNever } from "../../shared/utils"; import { assertNever } from "../../shared/utils";
import { QuotaExceededError } from "./request/apply-quota-limits"; import { QuotaExceededError } from "./request/apply-quota-limits";
@ -44,16 +46,9 @@ export function writeErrorResponse(
res.headersSent || res.headersSent ||
String(res.getHeader("content-type")).startsWith("text/event-stream") String(res.getHeader("content-type")).startsWith("text/event-stream")
) { ) {
const errorContent = const errorTitle = `${errorSource} error (${statusCode})`;
statusCode === 403 const errorContent = JSON.stringify(errorPayload, null, 2);
? JSON.stringify(errorPayload) const msg = buildFakeSse(errorTitle, errorContent, req);
: JSON.stringify(errorPayload, null, 2);
const msg = buildFakeSseMessage(
`${errorSource} error (${statusCode})`,
errorContent,
req
);
res.write(msg); res.write(msg);
res.write(`data: [DONE]\n\n`); res.write(`data: [DONE]\n\n`);
res.end(); res.end();
@ -66,110 +61,98 @@ export function writeErrorResponse(
} }
export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => { export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => {
req.log.error({ err }, `Error during proxy request middleware`); req.log.error(err, `Error during http-proxy-middleware request`);
handleInternalError(err, req as Request, res as Response); classifyErrorAndSend(err, req as Request, res as Response);
}; };
export const handleInternalError = ( export const classifyErrorAndSend = (
err: Error, err: Error,
req: Request, req: Request,
res: Response res: Response
) => { ) => {
try { try {
if (err instanceof ZodError) { const { status, userMessage, ...errorDetails } = classifyError(err);
writeErrorResponse(req, res, 400, { writeErrorResponse(req, res, status, {
error: { error: { message: userMessage, ...errorDetails },
type: "proxy_validation_error", });
proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`, } catch (error) {
issues: err.issues, req.log.error(error, `Error writing error response headers, giving up.`);
stack: err.stack,
message: err.message,
},
});
} else if (err.name === "ForbiddenError") {
// Spoofs a vaguely threatening OpenAI error message. Only invoked by the
// block-zoomers rewriter to scare off tiktokers.
writeErrorResponse(req, res, 403, {
error: {
type: "organization_account_disabled",
code: "policy_violation",
param: null,
message: err.message,
},
});
} else if (err instanceof QuotaExceededError) {
writeErrorResponse(req, res, 429, {
error: {
type: "proxy_quota_exceeded",
code: "quota_exceeded",
message: `You've exceeded your token quota for this model type.`,
info: err.quotaInfo,
stack: err.stack,
},
});
} else {
writeErrorResponse(req, res, 500, {
error: {
type: "proxy_internal_error",
proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
message: err.message,
stack: err.stack,
},
});
}
} catch (e) {
req.log.error(
{ error: e },
`Error writing error response headers, giving up.`
);
} }
}; };
export function buildFakeSseMessage( function classifyError(err: Error): {
type: string, /** HTTP status code returned to the client. */
string: string, status: number;
req: Request /** Message displayed to the user. */
) { userMessage: string;
let fakeEvent; /** Short error type, e.g. "proxy_validation_error". */
const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`; type: string;
} & Record<string, any> {
const defaultError = {
status: 500,
userMessage: `Reverse proxy encountered an unexpected error. (${err.message})`,
type: "proxy_internal_error",
stack: err.stack,
};
switch (req.inboundApi) { switch (err.constructor.name) {
case "openai": case "ZodError":
fakeEvent = { const userMessage = generateErrorMessage((err as ZodError).issues, {
id: "chatcmpl-" + req.id, prefix: "Request validation failed. ",
object: "chat.completion.chunk", path: { enabled: true, label: null, type: "breadcrumbs" },
created: Date.now(), code: { enabled: false },
model: req.body?.model, maxErrors: 3,
choices: [{ delta: { content }, index: 0, finish_reason: type }], transform: ({ issue, ...rest }) => {
return `At '${rest.pathComponent}', ${issue.message}`;
},
});
return { status: 400, userMessage, type: "proxy_validation_error" };
case "ForbiddenError":
// Mimics a ban notice from OpenAI, thrown when blockZoomerOrigins blocks
// a request.
return {
status: 403,
userMessage: `Your account has been disabled for violating our terms of service.`,
type: "organization_account_disabled",
code: "policy_violation",
}; };
break; case "QuotaExceededError":
case "openai-text": return {
fakeEvent = { status: 429,
id: "cmpl-" + req.id, userMessage: `You've exceeded your token quota for this model type.`,
object: "text_completion", type: "proxy_quota_exceeded",
created: Date.now(), info: (err as QuotaExceededError).quotaInfo,
choices: [
{ text: content, index: 0, logprobs: null, finish_reason: type },
],
model: req.body?.model,
}; };
break; case "Error":
case "anthropic": if ("code" in err) {
fakeEvent = { switch (err.code) {
completion: content, case "ENOTFOUND":
stop_reason: type, return {
truncated: false, // I've never seen this be true status: 502,
stop: null, userMessage: `Reverse proxy encountered a DNS error while trying to connect to the upstream service.`,
model: req.body?.model, type: "proxy_network_error",
log_id: "proxy-req-" + req.id, code: err.code,
}; };
break; case "ECONNREFUSED":
case "google-palm": return {
throw new Error("PaLM not supported as an inbound API format"); status: 502,
userMessage: `Reverse proxy couldn't connect to the upstream service.`,
type: "proxy_network_error",
code: err.code,
};
case "ECONNRESET":
return {
status: 504,
userMessage: `Reverse proxy timed out while waiting for the upstream service to respond.`,
type: "proxy_network_error",
code: err.code,
};
}
}
return defaultError;
default: default:
assertNever(req.inboundApi); return defaultError;
} }
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
} }
export function getCompletionFromBody(req: Request, body: Record<string, any>) { export function getCompletionFromBody(req: Request, body: Record<string, any>) {

View File

@ -1,5 +1,6 @@
import { RequestHandler } from "express"; import { RequestHandler } from "express";
import { handleInternalError } from "../common"; import { initializeSseStream } from "../../../shared/streaming";
import { classifyErrorAndSend } from "../common";
import { import {
RequestPreprocessor, RequestPreprocessor,
validateContextSize, validateContextSize,
@ -66,6 +67,13 @@ async function executePreprocessors(
next(); next();
} catch (error) { } catch (error) {
req.log.error(error, "Error while executing request preprocessor"); req.log.error(error, "Error while executing request preprocessor");
handleInternalError(error as Error, req, res);
// If the requested has opted into streaming, the client probably won't
// handle a non-eventstream response, but we haven't initialized the SSE
// stream yet as that is typically done later by the request queue. We'll
// do that here and then call classifyErrorAndSend to use the streaming
// error handler.
initializeSseStream(res)
classifyErrorAndSend(error as Error, req, res);
} }
} }

View File

@ -67,12 +67,14 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
} }
const finalMax = Math.min(proxyMax, modelMax); const finalMax = Math.min(proxyMax, modelMax);
z.number() z.object({
.int() tokens: z
.max(finalMax, { .number()
message: `Your request exceeds the context size limit for this model or proxy. (max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`, .int()
}) .max(finalMax, {
.parse(contextTokens); message: `Your request exceeds the context size limit. (max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`,
}),
}).parse({ tokens: contextTokens });
req.log.debug( req.log.debug(
{ promptTokens, outputTokens, contextTokens, modelMax, proxyMax }, { promptTokens, outputTokens, contextTokens, modelMax, proxyMax },

View File

@ -1,14 +1,14 @@
import { pipeline } from "stream"; import { pipeline } from "stream";
import { promisify } from "util"; import { promisify } from "util";
import { buildFakeSseMessage } from "../common"; import {
buildFakeSse,
copySseResponseHeaders,
initializeSseStream
} from "../../../shared/streaming";
import { decodeResponseBody, RawResponseBodyHandler } from "."; import { decodeResponseBody, RawResponseBodyHandler } from ".";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter"; import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer"; import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { EventAggregator } from "./streaming/event-aggregator"; import { EventAggregator } from "./streaming/event-aggregator";
import {
copySseResponseHeaders,
initializeSseStream,
} from "../../../shared/streaming";
const pipelineAsync = promisify(pipeline); const pipelineAsync = promisify(pipeline);
@ -79,7 +79,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
res.end(); res.end();
return aggregator.getFinalResponse(); return aggregator.getFinalResponse();
} catch (err) { } catch (err) {
const errorEvent = buildFakeSseMessage("stream-error", err.message, req); const errorEvent = buildFakeSse("stream-error", err.message, req);
res.write(`${errorEvent}data: [DONE]\n\n`); res.write(`${errorEvent}data: [DONE]\n\n`);
res.end(); res.end();
throw err; throw err;

View File

@ -28,4 +28,3 @@ export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat"; export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text"; export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropic } from "./aggregators/anthropic"; export { mergeEventsForAnthropic } from "./aggregators/anthropic";

View File

@ -23,11 +23,10 @@ import {
getOpenAIModelFamily, getOpenAIModelFamily,
ModelFamily, ModelFamily,
} from "../shared/models"; } from "../shared/models";
import { initializeSseStream } from "../shared/streaming"; import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import { assertNever } from "../shared/utils"; import { assertNever } from "../shared/utils";
import { logger } from "../logger"; import { logger } from "../logger";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit"; import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
import { buildFakeSseMessage } from "./middleware/common";
const queue: Request[] = []; const queue: Request[] = [];
const log = logger.child({ module: "request-queue" }); const log = logger.child({ module: "request-queue" });
@ -109,7 +108,7 @@ export function enqueue(req: Request) {
const avgWait = Math.round(getEstimatedWaitTime(partition) / 1000); const avgWait = Math.round(getEstimatedWaitTime(partition) / 1000);
const currentDuration = Math.round((Date.now() - req.startTime) / 1000); const currentDuration = Math.round((Date.now() - req.startTime) / 1000);
const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`; const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`;
req.res!.write(buildFakeSseMessage("heartbeat", debugMsg, req)); req.res!.write(buildFakeSse("heartbeat", debugMsg, req));
} }
}, 10000); }, 10000);
} }
@ -337,7 +336,7 @@ function killQueuedRequest(req: Request) {
try { try {
const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. The queue is currently ${queue.length} requests long.`; const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. The queue is currently ${queue.length} requests long.`;
if (res.headersSent) { if (res.headersSent) {
const fakeErrorEvent = buildFakeSseMessage( const fakeErrorEvent = buildFakeSse(
"proxy queue error", "proxy queue error",
message, message,
req req
@ -363,7 +362,7 @@ function initStreaming(req: Request) {
return; return;
} }
res.write(": joining queue\n\n"); res.write(`: joining queue at position ${queue.length}\n\n`);
} }
/** /**