improves error handling for SillyTavern

nai-degen 2024-03-04 22:54:21 -06:00
parent 068e7a834f
commit 03c5c473e1
14 changed files with 499 additions and 240 deletions

View File

@ -1,4 +1,4 @@
import { Request, RequestHandler, Router } from "express";
import { Request, Response, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
@ -17,6 +17,7 @@ import {
createOnProxyResHandler,
} from "./middleware/response";
import { HttpError } from "../shared/errors";
import { sendErrorToClient } from "./middleware/response/error-generator";
const CLAUDE_3_COMPAT_MODEL =
process.env.CLAUDE_3_COMPAT_MODEL || "claude-3-sonnet-20240229";
@ -251,16 +252,19 @@ anthropicRouter.post(
"/v1/claude-3/complete",
ipLimiter,
handleCompatibilityRequest,
createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "anthropic" },
{
beforeTransform: [(req) => void (req.body.model = CLAUDE_3_COMPAT_MODEL)],
}
),
createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-chat",
service: "anthropic",
}),
anthropicProxy
);
export function handleCompatibilityRequest(req: Request, res: any, next: any) {
export function handleCompatibilityRequest(
req: Request,
res: Response,
next: any
) {
const alreadyInChatFormat = Boolean(req.body.messages);
const alreadyUsingClaude3 = req.body.model?.includes("claude-3");
if (!alreadyInChatFormat && !alreadyUsingClaude3) {
@ -268,18 +272,24 @@ export function handleCompatibilityRequest(req: Request, res: any, next: any) {
}
if (alreadyInChatFormat) {
throw new HttpError(
400,
"Your request is already using the new API format and does not need the compatibility endpoint. Use the /proxy/anthropic endpoint instead."
);
sendErrorToClient({
req,
res,
options: {
title: "Proxy error (incompatible request for endpoint)",
message:
"Your request is already using the new API format and does not need to use the compatibility endpoint.\n\nUse the /proxy/anthropic endpoint instead.",
format: "unknown",
statusCode: 400,
reqId: req.id,
},
});
return;
}
if (alreadyUsingClaude3) {
throw new HttpError(
400,
"Your request already includes the new model identifier and does not need the compatibility endpoint. Use the /proxy/anthropic endpoint instead."
);
if (!alreadyUsingClaude3) {
req.body.model = CLAUDE_3_COMPAT_MODEL;
}
next();
}
function maybeReassignModel(req: Request) {

View File

@ -2,9 +2,9 @@ import { Request, Response } from "express";
import httpProxy from "http-proxy";
import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error";
import { makeCompletionSSE } from "../../shared/streaming";
import { assertNever } from "../../shared/utils";
import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
import { buildSpoofedSSE, sendErrorToClient } from "./response/error-generator";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";
@ -40,7 +40,7 @@ export function isEmbeddingsRequest(req: Request) {
);
}
export function writeErrorResponse(
export function sendProxyError(
req: Request,
res: Response,
statusCode: number,
@ -52,29 +52,22 @@ export function writeErrorResponse(
? `The proxy encountered an error while trying to process your prompt.`
: `The proxy encountered an error while trying to send your prompt to the upstream service.`;
// If we're mid-SSE stream, send a data event with the error payload and end
// the stream. Otherwise just send a normal error response.
if (
res.headersSent ||
String(res.getHeader("content-type")).startsWith("text/event-stream")
) {
const event = makeCompletionSSE({
if (req.tokenizerInfo && typeof errorPayload.error === "object") {
errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
}
sendErrorToClient({
options: {
format: req.inboundApi,
title: `Proxy error (HTTP ${statusCode} ${statusMessage})`,
message: `${msg} Further technical details are provided below.`,
obj: errorPayload,
reqId: req.id,
model: req.body?.model,
});
res.write(event);
res.write(`data: [DONE]\n\n`);
res.end();
} else {
if (req.tokenizerInfo && typeof errorPayload.error === "object") {
errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
}
res.status(statusCode).json(errorPayload);
}
},
req,
res,
});
}
export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => {
@ -90,7 +83,7 @@ export const classifyErrorAndSend = (
try {
const { statusCode, statusMessage, userMessage, ...errorDetails } =
classifyError(err);
writeErrorResponse(req, res, statusCode, statusMessage, {
sendProxyError(req, res, statusCode, statusMessage, {
error: { message: userMessage, ...errorDetails },
});
} catch (error) {

View File

@ -122,6 +122,7 @@ const handleTestMessage: RequestHandler = (req, res) => {
object: "chat.completion",
created: Date.now(),
model: body.model,
// openai chat
choices: [
{
message: { role: "assistant", content: "Hello!" },
@ -129,6 +130,10 @@ const handleTestMessage: RequestHandler = (req, res) => {
index: 0,
},
],
// anthropic text
completion: "Hello!",
// anthropic chat
content: [{ type: "text", text: "Hello!" }],
proxy_note:
"This response was generated by the proxy's test message handler and did not go to the API.",
});

View File

@ -2,7 +2,7 @@ import { Request } from "express";
import { config } from "../../../../config";
import { assertNever } from "../../../../shared/utils";
import { RequestPreprocessor } from "../index";
import { UserInputError } from "../../../../shared/errors";
import { BadRequestError } from "../../../../shared/errors";
import {
MistralAIChatMessage,
OpenAIChatMessage,
@ -46,7 +46,7 @@ export const languageFilter: RequestPreprocessor = async (req) => {
req.res!.once("close", resolve);
setTimeout(resolve, delay);
});
throw new UserInputError(config.rejectMessage);
throw new BadRequestError(config.rejectMessage);
}
};

View File

@ -0,0 +1,352 @@
import express from "express";
import { APIFormat } from "../../../shared/key-management";
import { assertNever } from "../../../shared/utils";
import { initializeSseStream } from "../../../shared/streaming";
function getMessageContent({
title,
message,
obj,
}: {
title: string;
message: string;
obj?: Record<string, any>;
}) {
/*
Constructs a Markdown-formatted message that renders semi-nicely in most chat
frontends. For example:
**Proxy error (HTTP 404 Not Found)**
The proxy encountered an error while trying to send your prompt to the upstream service. Further technical details are provided below.
***
*The requested Claude model might not exist, or the key might not be provisioned for it.*
```
{
"type": "error",
"error": {
"type": "not_found_error",
"message": "model: some-invalid-model-id",
"proxy_tokenizer": {
"tokenizer": "@anthropic-ai/tokenizer",
"token_count": 6104,
"tokenization_duration_ms": 4.0765,
"prompt_tokens": 6104,
"completion_tokens": 30,
"max_model_tokens": 200000,
"max_proxy_tokens": 9007199254740991
}
},
"proxy_note": "The requested Claude model might not exist, or the key might not be provisioned for it."
}
```
*/
const friendlyMessage = obj?.proxy_note
? `${message}\n\n***\n\n*${obj.proxy_note}*`
: message;
const details = JSON.parse(JSON.stringify(obj ?? {}));
let stack = "";
if (details.stack) {
stack = `\n\nInclude this trace when reporting an issue.\n\`\`\`\n${details.stack}\n\`\`\``;
delete details.stack;
}
return `\n\n**${title}**\n${friendlyMessage}${
obj ? `\n\`\`\`\n${JSON.stringify(details, null, 2)}\n\`\`\`\n${stack}` : ""
}`;
}
type ErrorGeneratorOptions = {
format: APIFormat | "unknown";
title: string;
message: string;
obj?: object;
reqId: string | number | object;
model?: string;
statusCode?: number;
};
export function tryInferFormat(body: any): APIFormat | "unknown" {
if (typeof body !== "object" || !body.model) {
return "unknown";
}
if (body.model.includes("gpt")) {
return "openai";
}
if (body.model.includes("mistral")) {
return "mistral-ai";
}
if (body.model.includes("claude")) {
return body.messages?.length ? "anthropic-chat" : "anthropic-text";
}
if (body.model.includes("gemini")) {
return "google-ai";
}
return "unknown";
}
export function sendErrorToClient({
options,
req,
res,
}: {
options: ErrorGeneratorOptions;
req: express.Request;
res: express.Response;
}) {
const { format: inputFormat } = options;
// "unknown" means the error was thrown before the request's API format was
// determined, so try to infer it from the body in order to respond in the
// format the client expects.
const format =
inputFormat === "unknown" ? tryInferFormat(req.body) : inputFormat;
if (format === "unknown") {
return res.status(options.statusCode || 400).json({
error: options.message,
details: options.obj,
});
}
const completion = buildSpoofedCompletion({ ...options, format });
const event = buildSpoofedSSE({ ...options, format });
const isStreaming =
req.isStreaming || req.body.stream === true || req.body.stream === "true";
if (isStreaming) {
if (!res.headersSent) {
initializeSseStream(res);
}
res.write(event);
res.write(`data: [DONE]\n\n`);
res.end();
} else {
res.status(200).json(completion);
}
}
/**
* Returns a non-streaming completion object that looks like it came from the
* service that the request is being proxied to. Used to send error messages to
* the client and have them look like normal responses, for clients with poor
* error handling.
*/
export function buildSpoofedCompletion({
format,
title,
message,
obj,
reqId,
model = "unknown",
}: ErrorGeneratorOptions & { format: Exclude<APIFormat, "unknown"> }) {
const id = String(reqId);
const content = getMessageContent({ title, message, obj });
switch (format) {
case "openai":
case "mistral-ai":
return {
id: "error-" + id,
object: "chat.completion",
created: Date.now(),
model,
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
choices: [
{
message: { role: "assistant", content },
finish_reason: title,
index: 0,
},
],
};
case "openai-text":
return {
id: "error-" + id,
object: "text_completion",
created: Date.now(),
model,
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
choices: [
{ text: content, index: 0, logprobs: null, finish_reason: title },
],
};
case "anthropic-text":
return {
id: "error-" + id,
type: "completion",
completion: content,
stop_reason: title,
stop: null,
model,
};
case "anthropic-chat":
return {
id: "error-" + id,
type: "message",
role: "assistant",
content: [{ type: "text", text: content }],
model,
stop_reason: title,
stop_sequence: null,
};
case "google-ai":
// TODO: Native Google AI non-streaming responses are not supported; this
// is an untested guess at what the response should look like.
return {
id: "error-" + id,
object: "chat.completion",
created: Date.now(),
model,
candidates: [
{
content: { parts: [{ text: content }], role: "model" },
finishReason: title,
index: 0,
tokenCount: null,
safetyRatings: [],
},
],
};
case "openai-image":
throw new Error(
`Spoofed completions not supported for ${format} requests`
);
default:
assertNever(format);
}
}
/**
* Returns an SSE message that looks like a completion event for the service
* that the request is being proxied to. Used to send error messages to the
* client in the middle of a streaming request.
*/
export function buildSpoofedSSE({
format,
title,
message,
obj,
reqId,
model = "unknown",
}: ErrorGeneratorOptions & { format: Exclude<APIFormat, "unknown"> }) {
const id = String(reqId);
const content = getMessageContent({ title, message, obj });
let event;
switch (format) {
case "openai":
case "mistral-ai":
event = {
id: "chatcmpl-" + id,
object: "chat.completion.chunk",
created: Date.now(),
model,
choices: [{ delta: { content }, index: 0, finish_reason: title }],
};
break;
case "openai-text":
event = {
id: "cmpl-" + id,
object: "text_completion",
created: Date.now(),
choices: [
{ text: content, index: 0, logprobs: null, finish_reason: title },
],
model,
};
break;
case "anthropic-text":
event = {
completion: content,
stop_reason: title,
truncated: false,
stop: null,
model,
log_id: "proxy-req-" + id,
};
break;
case "anthropic-chat":
event = {
type: "content_block_delta",
index: 0,
delta: { type: "text_delta", text: content },
};
break;
case "google-ai":
return JSON.stringify({
candidates: [
{
content: { parts: [{ text: content }], role: "model" },
finishReason: title,
index: 0,
tokenCount: null,
safetyRatings: [],
},
],
});
case "openai-image":
throw new Error(`SSE not supported for ${format} requests`);
default:
assertNever(format);
}
if (format === "anthropic-text") {
return (
["event: completion", `data: ${JSON.stringify(event)}`].join("\n") +
"\n\n"
);
}
// ugh. Anthropic chat streams expect a full message_start/content_block/
// message_stop event sequence wrapped around the error delta.
if (format === "anthropic-chat") {
return (
[
[
"event: message_start",
`data: ${JSON.stringify({
type: "message_start",
message: {
id: "error-" + id,
type: "message",
role: "assistant",
content: [],
model,
},
})}`,
].join("\n"),
[
"event: content_block_start",
`data: ${JSON.stringify({
type: "content_block_start",
index: 0,
content_block: { type: "text", text: "" },
})}`,
].join("\n"),
["event: content_block_delta", `data: ${JSON.stringify(event)}`].join(
"\n"
),
[
"event: content_block_stop",
`data: ${JSON.stringify({ type: "content_block_stop", index: 0 })}`,
].join("\n"),
[
"event: message_delta",
`data: ${JSON.stringify({
type: "message_delta",
delta: { stop_reason: title, stop_sequence: null, usage: null },
})}`,
].join("\n"),
[
"event: message_stop",
`data: ${JSON.stringify({ type: "message_stop" })}`,
].join("\n"),
].join("\n\n") + "\n\n"
);
}
return `data: ${JSON.stringify(event)}\n\n`;
}
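
For reference, a minimal usage sketch of the new helper (not part of this commit): the route path, messages, and request id below are illustrative assumptions, and the import path mirrors how the proxy router imports the module.

import express from "express";
import { sendErrorToClient } from "./middleware/response/error-generator";

const exampleRouter = express.Router();

// Hypothetical endpoint, for illustration only.
exampleRouter.post("/v1/example", (req, res) => {
  // Send the error as a spoofed completion in whatever API format the client
  // appears to be using, so frontends with poor error handling still show it.
  sendErrorToClient({
    req,
    res,
    options: {
      title: "Proxy error (example)",
      message: "This is an illustrative error message.",
      format: "unknown", // falls back to tryInferFormat(req.body)
      statusCode: 400, // only used if the format cannot be inferred
      reqId: "example-request-id", // the proxy normally passes req.id here
      obj: { proxy_note: "Illustrative only." },
    },
  });
});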

View File

@ -6,7 +6,7 @@ import { APIFormat, keyPool } from "../../../shared/key-management";
import {
copySseResponseHeaders,
initializeSseStream,
makeCompletionSSE,
} from "../../../shared/streaming";
import type { logger } from "../../../logger";
import { enqueue } from "../../queue";
@ -15,6 +15,7 @@ import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
import { EventAggregator } from "./streaming/event-aggregator";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { buildSpoofedSSE } from "./error-generator";
const pipelineAsync = promisify(pipeline);
@ -111,7 +112,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
} else {
const { message, stack, lastEvent } = err;
const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined";
const errorEvent = makeCompletionSSE({
const errorEvent = buildSpoofedSSE({
format: req.inboundApi,
title: "Proxy stream error",
message: "An unexpected error occurred while streaming the response.",

View File

@ -18,7 +18,7 @@ import {
getCompletionFromBody,
isImageGenerationRequest,
isTextGenerationRequest,
writeErrorResponse,
sendProxyError,
} from "../common";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
@ -192,13 +192,13 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
// as it was never a problem.
body = await decoder(body);
} else {
const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
req.log.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
writeErrorResponse(req, res, 500, "Internal Server Error", {
error: errorMessage,
const error = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
req.log.warn({ contentEncoding, key: req.key?.hash }, error);
sendProxyError(req, res, 500, "Internal Server Error", {
error,
contentEncoding,
});
return reject(errorMessage);
return reject(error);
}
}
@ -208,13 +208,11 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
return resolve(json);
}
return resolve(body.toString());
} catch (error: any) {
const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
req.log.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
writeErrorResponse(req, res, 500, "Internal Server Error", {
error: errorMessage,
});
return reject(errorMessage);
} catch (e) {
const msg = `Proxy received response with invalid JSON: ${e.message}`;
req.log.warn({ error: e.stack, key: req.key?.hash }, msg);
sendProxyError(req, res, 500, "Internal Server Error", { error: msg });
return reject(msg);
}
});
});
@ -267,7 +265,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service.`,
};
writeErrorResponse(req, res, statusCode, statusMessage, errorObject);
sendProxyError(req, res, statusCode, statusMessage, errorObject);
throw new HttpError(statusCode, parseError.message);
}
@ -412,7 +410,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
);
}
writeErrorResponse(req, res, statusCode, statusMessage, errorPayload);
sendProxyError(req, res, statusCode, statusMessage, errorPayload);
// This is bubbled up to onProxyRes's handler for logging but will not trigger
// a write to the response as `sendProxyError` has just done that.
throw new HttpError(statusCode, errorPayload.error?.message);
};

View File

@ -2,8 +2,8 @@ import pino from "pino";
import { Transform, TransformOptions } from "stream";
import { Message } from "@smithy/eventstream-codec";
import { APIFormat } from "../../../../shared/key-management";
import { makeCompletionSSE } from "../../../../shared/streaming";
import { RetryableError } from "../index";
import { buildSpoofedSSE } from "../error-generator";
type SSEStreamAdapterOptions = TransformOptions & {
contentType?: string;
@ -75,7 +75,7 @@ export class SSEStreamAdapter extends Transform {
throw new RetryableError("AWS request throttled mid-stream");
default:
this.log.error({ message, type }, "Received bad AWS stream event");
return makeCompletionSSE({
return buildSpoofedSSE({
format: "anthropic-text",
title: "Proxy stream error",
message:
@ -103,7 +103,7 @@ export class SSEStreamAdapter extends Transform {
return `data: ${JSON.stringify(data)}`;
} else {
this.log.error({ event: data }, "Received bad Google AI event");
return `data: ${makeCompletionSSE({
return `data: ${buildSpoofedSSE({
format: "google-ai",
title: "Proxy stream error",
message:

View File

@ -13,17 +13,19 @@
import crypto from "crypto";
import type { Handler, Request } from "express";
import { BadRequestError, TooManyRequestsError } from "../shared/errors";
import { keyPool } from "../shared/key-management";
import {
getModelFamilyForRequest,
MODEL_FAMILIES,
ModelFamily,
} from "../shared/models";
import { makeCompletionSSE, initializeSseStream } from "../shared/streaming";
import { initializeSseStream } from "../shared/streaming";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
import { handleProxyError } from "./middleware/common";
import { sendErrorToClient } from "./middleware/response/error-generator";
const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@ -80,10 +82,14 @@ export async function enqueue(req: Request) {
// Re-enqueued requests are not counted towards the limit since they
// already made it through the queue once.
if (req.retryCount === 0) {
throw new Error("Too many agnai.chat requests are already queued");
throw new TooManyRequestsError(
"Too many agnai.chat requests are already queued"
);
}
} else {
throw new Error("Your IP or token already has a request in the queue");
throw new TooManyRequestsError(
"Your IP or user token already has another request in the queue."
);
}
}
@ -101,8 +107,8 @@ export async function enqueue(req: Request) {
}
registerHeartbeat(req);
} else if (getProxyLoad() > LOAD_THRESHOLD) {
throw new Error(
"Due to heavy traffic on this proxy, you must enable streaming for your request."
throw new BadRequestError(
"Due to heavy traffic on this proxy, you must enable streaming in your chat client to use this endpoint."
);
}
@ -354,11 +360,20 @@ export function createQueueMiddleware({
try {
await enqueue(req);
} catch (err: any) {
req.res!.status(429).json({
type: "proxy_error",
message: err.message,
stack: err.stack,
proxy_note: `Only one request can be queued at a time. If you don't have another request queued, your IP or user token might be in use by another request.`,
const title =
err.status === 429
? "Proxy queue error (too many concurrent requests)"
: "Proxy queue error (streaming required)";
sendErrorToClient({
options: {
title,
message: err.message,
format: req.inboundApi,
reqId: req.id,
model: req.body?.model,
},
req,
res,
});
}
};
@ -373,20 +388,17 @@ function killQueuedRequest(req: Request) {
const res = req.res;
try {
const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`;
if (res.headersSent) {
const event = makeCompletionSSE({
format: req.inboundApi,
title: "Proxy queue error",
sendErrorToClient({
options: {
title: "Proxy queue error (request killed)",
message,
reqId: String(req.id),
format: req.inboundApi,
reqId: req.id,
model: req.body?.model,
});
res.write(event);
res.write(`data: [DONE]\n\n`);
res.end();
} else {
res.status(500).json({ error: message });
}
},
req,
res,
});
} catch (e) {
req.log.error(e, `Error killing stalled request.`);
}

View File

@ -8,6 +8,7 @@ import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { azure } from "./azure";
import { sendErrorToClient } from "./middleware/response/error-generator";
const proxyRouter = express.Router();
proxyRouter.use((req, _res, next) => {
@ -46,8 +47,22 @@ proxyRouter.get("*", (req, res, next) => {
}
});
// Handle 404s.
proxyRouter.use((_req, res) => {
res.status(404).json({ error: "Not found" });
proxyRouter.use((req, res) => {
sendErrorToClient({
req,
res,
options: {
title: "Proxy error (HTTP 404 Not Found)",
message: "The requested proxy endpoint does not exist.",
model: req.body?.model,
reqId: req.id,
format: "unknown",
obj: {
proxy_note: "Your chat client is using the wrong endpoint. Please check your configuration.",
requested_url: req.url,
},
},
});
});
export { proxyRouter as proxyRouter };

View File

@ -19,6 +19,7 @@ import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./shared/users/user-store";
import { init as initTokenizers } from "./shared/tokenization";
import { checkOrigin } from "./proxy/check-origin";
import { sendErrorToClient } from "./proxy/middleware/response/error-generator";
const PORT = config.port;
const BIND_ADDRESS = config.bindAddress;
@ -74,21 +75,27 @@ if (config.staticServiceInfo) {
app.use("/", infoPageRouter);
}
app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => {
if (err.status) {
res.status(err.status).json({ error: err.message });
} else {
logger.error(err);
res.status(500).json({
error: {
type: "proxy_error",
message: err.message,
stack: err.stack,
proxy_note: `Reverse proxy encountered an internal server error.`,
app.use(
(err: any, req: express.Request, res: express.Response, _next: unknown) => {
if (!err.status) {
logger.error(err, "Unhandled error in request");
}
sendErrorToClient({
req,
res,
options: {
title: `Proxy error (HTTP ${err.status})`,
message:
"Reverse proxy encountered an unexpected error while processing your request.",
reqId: req.id,
statusCode: err.status,
obj: { error: err.message, stack: err.stack },
format: "unknown",
},
});
}
});
);
app.use((_req: unknown, res: express.Response) => {
res.status(404).json({ error: "Not found" });
});

View File

@ -4,7 +4,7 @@ export class HttpError extends Error {
}
}
export class UserInputError extends HttpError {
export class BadRequestError extends HttpError {
constructor(message: string) {
super(400, message);
}
@ -21,3 +21,9 @@ export class NotFoundError extends HttpError {
super(404, message);
}
}
export class TooManyRequestsError extends HttpError {
constructor(message: string) {
super(429, message);
}
}

View File

@ -1,7 +1,5 @@
import { Response } from "express";
import { IncomingMessage } from "http";
import { assertNever } from "./utils";
import { APIFormat } from "./key-management";
export function initializeSseStream(res: Response) {
res.statusCode = 200;
@ -35,143 +33,3 @@ export function copySseResponseHeaders(
}
}
/**
* Returns an SSE message that looks like a completion event for the service
* that the request is being proxied to. Used to send error messages to the
* client in the middle of a streaming request.
*/
export function makeCompletionSSE({
format,
title,
message,
obj,
reqId,
model = "unknown",
}: {
format: APIFormat;
title: string;
message: string;
obj?: object;
reqId: string | number | object;
model?: string;
}) {
const id = String(reqId);
const content = `\n\n**${title}**\n${message}${
obj ? `\n\`\`\`\n${JSON.stringify(obj, null, 2)}\n\`\`\`\n` : ""
}`;
let event;
switch (format) {
case "openai":
case "mistral-ai":
event = {
id: "chatcmpl-" + id,
object: "chat.completion.chunk",
created: Date.now(),
model,
choices: [{ delta: { content }, index: 0, finish_reason: title }],
};
break;
case "openai-text":
event = {
id: "cmpl-" + id,
object: "text_completion",
created: Date.now(),
choices: [
{ text: content, index: 0, logprobs: null, finish_reason: title },
],
model,
};
break;
case "anthropic-text":
event = {
completion: content,
stop_reason: title,
truncated: false,
stop: null,
model,
log_id: "proxy-req-" + id,
};
break;
case "anthropic-chat":
event = {
type: "content_block_delta",
index: 0,
delta: { type: "text_delta", text: content },
};
break;
case "google-ai":
return JSON.stringify({
candidates: [
{
content: { parts: [{ text: content }], role: "model" },
finishReason: title,
index: 0,
tokenCount: null,
safetyRatings: [],
},
],
});
case "openai-image":
throw new Error(`SSE not supported for ${format} requests`);
default:
assertNever(format);
}
if (format === "anthropic-text") {
return (
["event: completion", `data: ${JSON.stringify(event)}`].join("\n") +
"\n\n"
);
}
// ugh.
if (format === "anthropic-chat") {
return (
[
[
"event: message_start",
`data: ${JSON.stringify({
type: "message_start",
message: {
id: "error-" + id,
type: "message",
role: "assistant",
content: [],
model,
},
})}`,
].join("\n"),
[
"event: content_block_start",
`data: ${JSON.stringify({
type: "content_block_start",
index: 0,
content_block: { type: "text", text: "" },
})}`,
].join("\n"),
["event: content_block_delta", `data: ${JSON.stringify(event)}`].join(
"\n"
),
[
"event: content_block_stop",
`data: ${JSON.stringify({ type: "content_block_stop", index: 0 })}`,
].join("\n"),
[
"event: message_delta",
`data: ${JSON.stringify({
type: "message_delta",
delta: { stop_reason: title, stop_sequence: null, usage: null },
})}`,
],
[
"event: message_stop",
`data: ${JSON.stringify({ type: "message_stop" })}`,
].join("\n"),
].join("\n\n") + "\n\n"
);
}
return `data: ${JSON.stringify(event)}\n\n`;
}

View File

@ -1,7 +1,7 @@
import { Router } from "express";
import { UserPartialSchema } from "../../shared/users/schema";
import * as userStore from "../../shared/users/user-store";
import { ForbiddenError, UserInputError } from "../../shared/errors";
import { ForbiddenError, BadRequestError } from "../../shared/errors";
import { sanitizeAndTrim } from "../../shared/utils";
import { config } from "../../config";
@ -62,7 +62,7 @@ router.post("/edit-nickname", (req, res) => {
const result = schema.safeParse(req.body);
if (!result.success) {
throw new UserInputError(result.error.message);
throw new BadRequestError(result.error.message);
}
const newNickname = result.data.nickname || null;