Refactor handleStreamingResponse to make it less shit (khanon/oai-reverse-proxy!46)

khanon 2023-10-03 06:14:19 +00:00
parent 6a3d753f0d
commit ecf897e685
20 changed files with 778 additions and 438 deletions


@ -42,7 +42,7 @@ export function writeErrorResponse(
  // the stream. Otherwise just send a normal error response.
  if (
    res.headersSent ||
-    res.getHeader("content-type") === "text/event-stream"
+    String(res.getHeader("content-type")).startsWith("text/event-stream")
  ) {
    const errorContent =
      statusCode === 403
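For context on the loosened comparison: the shared SSE helper added later in this commit sets a charset parameter on the header, so an exact-equality check would no longer match. Illustrative snippet, not part of the diff:

res.setHeader("Content-Type", "text/event-stream; charset=utf-8"); // as set by initializeSseStream
String(res.getHeader("content-type")).startsWith("text/event-stream"); // true
res.getHeader("content-type") === "text/event-stream"; // false, the old check would miss it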


@ -166,12 +166,7 @@ function openaiToAnthropic(req: Request) {
    throw result.error;
  }

-  // Anthropic has started versioning their API, indicated by an HTTP header
-  // `anthropic-version`. The new June 2023 version is not backwards compatible
-  // with our OpenAI-to-Anthropic transformations so we need to explicitly
-  // request the older version for now. 2023-01-01 will be removed in September.
-  // https://docs.anthropic.com/claude/reference/versioning
-  req.headers["anthropic-version"] = "2023-01-01";
+  req.headers["anthropic-version"] = "2023-06-01";

  const { messages, ...rest } = result.data;
  const prompt = openAIMessagesToClaudePrompt(messages);


@ -1,44 +1,16 @@
-import { Request, Response } from "express";
-import * as http from "http";
+import { pipeline } from "stream";
+import { promisify } from "util";
import { buildFakeSseMessage } from "../common";
-import { RawResponseBodyHandler, decodeResponseBody } from ".";
-import { assertNever } from "../../../shared/utils";
-import { ServerSentEventStreamAdapter } from "./sse-stream-adapter";
-
-type OpenAiChatCompletionResponse = {
-  id: string;
-  object: string;
-  created: number;
-  model: string;
-  choices: {
-    message: { role: string; content: string };
-    finish_reason: string | null;
-    index: number;
-  }[];
-};
-
-type OpenAiTextCompletionResponse = {
-  id: string;
-  object: string;
-  created: number;
-  model: string;
-  choices: {
-    text: string;
-    finish_reason: string | null;
-    index: number;
-    logprobs: null;
-  }[];
-};
-
-type AnthropicCompletionResponse = {
-  completion: string;
-  stop_reason: string;
-  truncated: boolean;
-  stop: any;
-  model: string;
-  log_id: string;
-  exception: null;
-};
+import { decodeResponseBody, RawResponseBodyHandler } from ".";
+import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
+import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
+import { EventAggregator } from "./streaming/event-aggregator";
+import {
+  copySseResponseHeaders,
+  initializeSseStream,
+} from "../../../shared/streaming";
+
+const pipelineAsync = promisify(pipeline);

/**
 * Consume the SSE stream and forward events to the client. Once the stream is
@ -49,370 +21,67 @@ type AnthropicCompletionResponse = {
 * in the event a streamed request results in a non-200 response, we need to
 * fall back to the non-streaming response handler so that the error handler
 * can inspect the error response.
- *
- * Currently most frontends don't support Anthropic streaming, so users can opt
- * to send requests for Claude models via an endpoint that accepts OpenAI-
- * compatible requests and translates the received Anthropic SSE events into
- * OpenAI ones, essentially pretending to be an OpenAI streaming API.
 */
export const handleStreamedResponse: RawResponseBodyHandler = async (
  proxyRes,
  req,
  res
) => {
-  // If these differ, the user is using the OpenAI-compatibile endpoint, so
-  // we need to translate the SSE events into OpenAI completion events for their
-  // frontend.
+  const { hash } = req.key!;
  if (!req.isStreaming) {
-    const err = new Error(
-      "handleStreamedResponse called for non-streaming request."
-    );
-    req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
-    throw err;
+    throw new Error("handleStreamedResponse called for non-streaming request.");
  }

-  const key = req.key!;
-  if (proxyRes.statusCode !== 200) {
-    // Ensure we use the non-streaming middleware stack since we won't be
-    // getting any events.
-    req.isStreaming = false;
+  if (proxyRes.statusCode! > 201) {
+    req.isStreaming = false; // Forces non-streaming response handler to execute
    req.log.warn(
-      { statusCode: proxyRes.statusCode, key: key.hash },
+      { statusCode: proxyRes.statusCode, key: hash },
      `Streaming request returned error status code. Falling back to non-streaming response handler.`
    );
    return decodeResponseBody(proxyRes, req, res);
  }

  req.log.debug(
-    { headers: proxyRes.headers, key: key.hash },
-    `Received SSE headers.`
+    { headers: proxyRes.headers, key: hash },
+    `Starting to proxy SSE stream.`
  );

-  return new Promise((resolve, reject) => {
-    req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`);
+  // Users waiting in the queue already have a SSE connection open for the
+  // heartbeat, so we can't always send the stream headers.
+  if (!res.headersSent) {
+    copySseResponseHeaders(proxyRes, res);
+    initializeSseStream(res);
+  }

-    // Queued streaming requests will already have a connection open and headers
-    // sent due to the heartbeat handler. In that case we can just start
-    // streaming the response without sending headers.
-    if (!res.headersSent) {
-      res.setHeader("Content-Type", "text/event-stream");
-      res.setHeader("Cache-Control", "no-cache");
-      res.setHeader("Connection", "keep-alive");
-      res.setHeader("X-Accel-Buffering", "no");
-      copyHeaders(proxyRes, res);
-      res.flushHeaders();
-    }
+  const prefersNativeEvents = req.inboundApi === req.outboundApi;
+  const contentType = proxyRes.headers["content-type"];

-    const adapter = new ServerSentEventStreamAdapter({
-      isAwsStream:
-        proxyRes.headers["content-type"] ===
-        "application/vnd.amazon.eventstream",
-    });
+  const adapter = new SSEStreamAdapter({ contentType });
+  const aggregator = new EventAggregator({ format: req.outboundApi });
+  const transformer = new SSEMessageTransformer({
+    inputFormat: req.outboundApi, // outbound from the request's perspective
+    inputApiVersion: String(req.headers["anthropic-version"]),
+    logger: req.log,
+    requestId: String(req.id),
+    requestedModel: req.body.model,
+  })
+    .on("originalMessage", (msg: string) => {
+      if (prefersNativeEvents) res.write(msg);
+    })
+    .on("data", (msg) => {
+      if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
+      aggregator.addEvent(msg);
+    });

-    const events: string[] = [];
-    let lastPosition = 0;
-    let eventCount = 0;
-
-    proxyRes.pipe(adapter);
-
-    adapter.on("data", (chunk: any) => {
-      try {
-        const { event, position } = transformEvent({
-          data: chunk.toString(),
-          requestApi: req.inboundApi,
-          responseApi: req.outboundApi,
-          lastPosition,
-          index: eventCount++,
-        });
-        events.push(event);
-        lastPosition = position;
-        res.write(event + "\n\n");
-      } catch (err) {
-        adapter.emit("error", err);
-      }
-    });
-
-    adapter.on("end", () => {
-      try {
-        req.log.info({ key: key.hash }, `Finished proxying SSE stream.`);
-        const finalBody = convertEventsToFinalResponse(events, req);
-        res.end();
-        resolve(finalBody);
-      } catch (err) {
-        adapter.emit("error", err);
-      }
-    });
-
-    adapter.on("error", (err) => {
-      req.log.error({ error: err, key: key.hash }, `Mid-stream error.`);
-      const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
-      res.write(`data: ${JSON.stringify(errorEvent)}\n\ndata: [DONE]\n\n`);
-      res.end();
-      reject(err);
-    });
-  });
+  try {
+    await pipelineAsync(proxyRes, adapter, transformer);
+    req.log.debug({ key: hash }, `Finished proxying SSE stream.`);
+    res.end();
+    return aggregator.getFinalResponse();
+  } catch (err) {
+    const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
+    res.write(`${errorEvent}data: [DONE]\n\n`);
+    res.end();
+    throw err;
+  }
};

type SSETransformationArgs = {
data: string;
requestApi: string;
responseApi: string;
lastPosition: number;
index: number;
};
/**
* Transforms SSE events from the given response API into events compatible with
* the API requested by the client.
*/
function transformEvent(params: SSETransformationArgs) {
const { data, requestApi, responseApi } = params;
if (requestApi === responseApi) {
return { position: -1, event: data };
}
const trans = `${requestApi}->${responseApi}`;
switch (trans) {
case "openai->openai-text":
return transformOpenAITextEventToOpenAIChat(params);
case "openai->anthropic":
// TODO: handle new anthropic streaming format
return transformV1AnthropicEventToOpenAI(params);
default:
throw new Error(`Unsupported streaming API transformation. ${trans}`);
}
}
function transformOpenAITextEventToOpenAIChat(params: SSETransformationArgs) {
const { data, index } = params;
if (!data.startsWith("data:")) return { position: -1, event: data };
if (data.startsWith("data: [DONE]")) return { position: -1, event: data };
const event = JSON.parse(data.slice("data: ".length));
// The very first event must be a role assignment with no content.
const createEvent = () => ({
id: event.id,
object: "chat.completion.chunk",
created: event.created,
model: event.model,
choices: [
{
message: { role: "", content: "" } as {
role?: string;
content: string;
},
index: 0,
finish_reason: null,
},
],
});
let buffer = "";
if (index === 0) {
const initialEvent = createEvent();
initialEvent.choices[0].message.role = "assistant";
buffer = `data: ${JSON.stringify(initialEvent)}\n\n`;
}
const newEvent = {
...event,
choices: [
{
...event.choices[0],
delta: { content: event.choices[0].text },
text: undefined,
},
],
};
buffer += `data: ${JSON.stringify(newEvent)}`;
return { position: -1, event: buffer };
}
function transformV1AnthropicEventToOpenAI(params: SSETransformationArgs) {
const { data, lastPosition } = params;
// Anthropic sends the full completion so far with each event whereas OpenAI
// only sends the delta. To make the SSE events compatible, we remove
// everything before `lastPosition` from the completion.
if (!data.startsWith("data:")) {
return { position: lastPosition, event: data };
}
if (data.startsWith("data: [DONE]")) {
return { position: lastPosition, event: data };
}
const event = JSON.parse(data.slice("data: ".length));
const newEvent = {
id: "ant-" + event.log_id,
object: "chat.completion.chunk",
created: Date.now(),
model: event.model,
choices: [
{
index: 0,
delta: { content: event.completion?.slice(lastPosition) },
finish_reason: event.stop_reason,
},
],
};
return {
position: event.completion.length,
event: `data: ${JSON.stringify(newEvent)}`,
};
}
/** Copy headers, excluding ones we're already setting for the SSE response. */
function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
const toOmit = [
"content-length",
"content-encoding",
"transfer-encoding",
"content-type",
"connection",
"cache-control",
];
for (const [key, value] of Object.entries(proxyRes.headers)) {
if (!toOmit.includes(key) && value) {
res.setHeader(key, value);
}
}
}
/**
* Converts the list of incremental SSE events into an object that resembles a
* full, non-streamed response from the API so that subsequent middleware can
* operate on it as if it were a normal response.
* Events are expected to be in the format they were received from the API.
*/
function convertEventsToFinalResponse(events: string[], req: Request) {
switch (req.outboundApi) {
case "openai": {
let merged: OpenAiChatCompletionResponse = {
id: "",
object: "",
created: 0,
model: "",
choices: [],
};
merged = events.reduce((acc, event, i) => {
if (!event.startsWith("data: ")) return acc;
if (event === "data: [DONE]") return acc;
const data = JSON.parse(event.slice("data: ".length));
// The first chat chunk only contains the role assignment and metadata
if (i === 0) {
return {
id: data.id,
object: data.object,
created: data.created,
model: data.model,
choices: [
{
message: { role: data.choices[0].delta.role, content: "" },
index: 0,
finish_reason: null,
},
],
};
}
if (data.choices[0].delta.content) {
acc.choices[0].message.content += data.choices[0].delta.content;
}
acc.choices[0].finish_reason = data.choices[0].finish_reason;
return acc;
}, merged);
return merged;
}
case "openai-text": {
let merged: OpenAiTextCompletionResponse = {
id: "",
object: "",
created: 0,
model: "",
choices: [],
// TODO: merge logprobs
};
merged = events.reduce((acc, event) => {
if (!event.startsWith("data: ")) return acc;
if (event === "data: [DONE]") return acc;
const data = JSON.parse(event.slice("data: ".length));
return {
id: data.id,
object: data.object,
created: data.created,
model: data.model,
choices: [
{
text: acc.choices[0]?.text + data.choices[0].text,
index: 0,
finish_reason: data.choices[0].finish_reason,
logprobs: null,
},
],
};
}, merged);
return merged;
}
case "anthropic": {
if (req.headers["anthropic-version"] === "2023-01-01") {
return convertAnthropicV1(events, req);
}
let merged: AnthropicCompletionResponse = {
completion: "",
stop_reason: "",
truncated: false,
stop: null,
model: req.body.model,
log_id: "",
exception: null,
}
merged = events.reduce((acc, event) => {
if (!event.startsWith("data: ")) return acc;
if (event === "data: [DONE]") return acc;
const data = JSON.parse(event.slice("data: ".length));
return {
completion: acc.completion + data.completion,
stop_reason: data.stop_reason,
truncated: data.truncated,
stop: data.stop,
log_id: data.log_id,
exception: data.exception,
model: acc.model,
};
}, merged);
return merged;
}
case "google-palm": {
throw new Error("PaLM streaming not yet supported.");
}
default:
assertNever(req.outboundApi);
}
}
/** Older Anthropic streaming format which sent full completion each time. */
function convertAnthropicV1(
events: string[],
req: Request
) {
const lastEvent = events[events.length - 2].toString();
const data = JSON.parse(
lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length)
);
const final: AnthropicCompletionResponse = { ...data, log_id: req.id };
return final;
}


@ -4,13 +4,16 @@ import * as http from "http";
import util from "util";
import zlib from "zlib";
import { logger } from "../../../logger";
+import { enqueue, trackWaitTime } from "../../queue";
+import { HttpError } from "../../../shared/errors";
import { keyPool } from "../../../shared/key-management";
import { getOpenAIModelFamily } from "../../../shared/models";
-import { enqueue, trackWaitTime } from "../../queue";
+import { countTokens } from "../../../shared/tokenization";
import {
  incrementPromptCount,
  incrementTokenCount,
} from "../../../shared/users/user-store";
+import { assertNever } from "../../../shared/utils";
import {
  getCompletionFromBody,
  isCompletionRequest,
@ -18,8 +21,6 @@ import {
} from "../common";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
-import { countTokens } from "../../../shared/tokenization";
-import { assertNever } from "../../../shared/utils";

const DECODER_MAP = {
  gzip: util.promisify(zlib.gunzip),
@ -83,7 +84,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
    ? handleStreamedResponse
    : decodeResponseBody;

-  let lastMiddlewareName = initialHandler.name;
+  let lastMiddleware = initialHandler.name;

  try {
    const body = await initialHandler(proxyRes, req, res);
@ -112,37 +113,38 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
      }

      for (const middleware of middlewareStack) {
-        lastMiddlewareName = middleware.name;
+        lastMiddleware = middleware.name;
        await middleware(proxyRes, req, res, body);
      }

      trackWaitTime(req);
-    } catch (error: any) {
+    } catch (error) {
      // Hack: if the error is a retryable rate-limit error, the request has
      // been re-enqueued and we can just return without doing anything else.
      if (error instanceof RetryableError) {
        return;
      }

-      const errorData = {
-        error: error.stack,
-        thrownBy: lastMiddlewareName,
-        key: req.key?.hash,
-      };
-      const message = `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`;
-      if (res.headersSent) {
-        req.log.error(errorData, message);
-        // This should have already been handled by the error handler, but
-        // just in case...
-        if (!res.writableEnded) {
-          res.end();
-        }
-        return;
-      }
-      logger.error(errorData, message);
-      res
-        .status(500)
-        .json({ error: "Internal server error", proxy_note: message });
+      // Already logged and responded to the client by handleUpstreamErrors
+      if (error instanceof HttpError) {
+        if (!res.writableEnded) res.end();
+        return;
+      }
+
+      const { stack, message } = error;
+      const info = { stack, lastMiddleware, key: req.key?.hash };
+      const description = `Error while executing proxy response middleware: ${lastMiddleware} (${message})`;
+      if (res.headersSent) {
+        req.log.error(info, description);
+        if (!res.writableEnded) res.end();
+        return;
+      } else {
+        req.log.error(info, description);
+        res
+          .status(500)
+          .json({ error: "Internal server error", proxy_note: description });
+      }
    }
  };
};
@ -203,7 +205,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
        return resolve(body.toString());
      } catch (error: any) {
        const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
-        logger.warn({ error, key: req.key?.hash }, errorMessage);
+        logger.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
        writeErrorResponse(req, res, 500, { error: errorMessage });
        return reject(errorMessage);
      }
@ -223,7 +225,7 @@ type ProxiedErrorPayload = {
 * an error to stop the middleware stack.
 * On 429 errors, if request queueing is enabled, the request will be silently
 * re-enqueued. Otherwise, the request will be rejected with an error payload.
- * @throws {Error} On HTTP error status code from upstream service
+ * @throws {HttpError} On HTTP error status code from upstream service
 */
const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  proxyRes,
@ -258,7 +260,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
      proxy_note: `This is likely a temporary error with the upstream service.`,
    };
    writeErrorResponse(req, res, statusCode, errorObject);
-    throw new Error(parseError.message);
+    throw new HttpError(statusCode, parseError.message);
  }

  const errorType =
@ -371,7 +373,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  }

  writeErrorResponse(req, res, statusCode, errorPayload);
-  throw new Error(errorPayload.error?.message);
+  throw new HttpError(statusCode, errorPayload.error?.message);
};

/**


@ -0,0 +1,48 @@
import { OpenAIChatCompletionStreamEvent } from "../index";
export type AnthropicCompletionResponse = {
completion: string;
stop_reason: string;
truncated: boolean;
stop: any;
model: string;
log_id: string;
exception: null;
};
/**
* Given a list of OpenAI chat completion events, compiles them into a single
* finalized Anthropic completion response so that non-streaming middleware
* can operate on it as if it were a blocking response.
*/
export function mergeEventsForAnthropic(
events: OpenAIChatCompletionStreamEvent[]
): AnthropicCompletionResponse {
let merged: AnthropicCompletionResponse = {
log_id: "",
exception: null,
model: "",
completion: "",
stop_reason: "",
truncated: false,
stop: null,
};
merged = events.reduce((acc, event, i) => {
// The first event will only contain role assignment and response metadata
if (i === 0) {
acc.log_id = event.id;
acc.model = event.model;
acc.completion = "";
acc.stop_reason = "";
return acc;
}
acc.stop_reason = event.choices[0].finish_reason ?? "";
if (event.choices[0].delta.content) {
acc.completion += event.choices[0].delta.content;
}
return acc;
}, merged);
return merged;
}


@ -0,0 +1,58 @@
import { OpenAIChatCompletionStreamEvent } from "../index";
export type OpenAiChatCompletionResponse = {
id: string;
object: string;
created: number;
model: string;
choices: {
message: { role: string; content: string };
finish_reason: string | null;
index: number;
}[];
};
/**
* Given a list of OpenAI chat completion events, compiles them into a single
* finalized OpenAI chat completion response so that non-streaming middleware
* can operate on it as if it were a blocking response.
*/
export function mergeEventsForOpenAIChat(
events: OpenAIChatCompletionStreamEvent[]
): OpenAiChatCompletionResponse {
let merged: OpenAiChatCompletionResponse = {
id: "",
object: "",
created: 0,
model: "",
choices: [],
};
merged = events.reduce((acc, event, i) => {
// The first event will only contain role assignment and response metadata
if (i === 0) {
acc.id = event.id;
acc.object = event.object;
acc.created = event.created;
acc.model = event.model;
acc.choices = [
{
index: 0,
message: {
role: event.choices[0].delta.role ?? "assistant",
content: "",
},
finish_reason: null,
},
];
return acc;
}
acc.choices[0].finish_reason = event.choices[0].finish_reason;
if (event.choices[0].delta.content) {
acc.choices[0].message.content += event.choices[0].delta.content;
}
return acc;
}, merged);
return merged;
}


@ -0,0 +1,57 @@
import { OpenAIChatCompletionStreamEvent } from "../index";
export type OpenAiTextCompletionResponse = {
id: string;
object: string;
created: number;
model: string;
choices: {
text: string;
finish_reason: string | null;
index: number;
logprobs: null;
}[];
};
/**
* Given a list of OpenAI chat completion events, compiles them into a single
* finalized OpenAI text completion response so that non-streaming middleware
* can operate on it as if it were a blocking response.
*/
export function mergeEventsForOpenAIText(
events: OpenAIChatCompletionStreamEvent[]
): OpenAiTextCompletionResponse {
let merged: OpenAiTextCompletionResponse = {
id: "",
object: "",
created: 0,
model: "",
choices: [],
};
merged = events.reduce((acc, event, i) => {
// The first event will only contain role assignment and response metadata
if (i === 0) {
acc.id = event.id;
acc.object = event.object;
acc.created = event.created;
acc.model = event.model;
acc.choices = [
{
text: "",
index: 0,
finish_reason: null,
logprobs: null,
},
];
return acc;
}
acc.choices[0].finish_reason = event.choices[0].finish_reason;
if (event.choices[0].delta.content) {
acc.choices[0].text += event.choices[0].delta.content;
}
return acc;
}, merged);
return merged;
}


@ -0,0 +1,41 @@
import { APIFormat } from "../../../../shared/key-management";
import { assertNever } from "../../../../shared/utils";
import {
mergeEventsForAnthropic,
mergeEventsForOpenAIChat,
mergeEventsForOpenAIText,
OpenAIChatCompletionStreamEvent
} from "./index";
/**
* Collects SSE events containing incremental chat completion responses and
* compiles them into a single finalized response for downstream middleware.
*/
export class EventAggregator {
private readonly format: APIFormat;
private readonly events: OpenAIChatCompletionStreamEvent[];
constructor({ format }: { format: APIFormat }) {
this.events = [];
this.format = format;
}
addEvent(event: OpenAIChatCompletionStreamEvent) {
this.events.push(event);
}
getFinalResponse() {
switch (this.format) {
case "openai":
return mergeEventsForOpenAIChat(this.events);
case "openai-text":
return mergeEventsForOpenAIText(this.events);
case "anthropic":
return mergeEventsForAnthropic(this.events);
case "google-palm":
throw new Error("Google PaLM API does not support streaming responses");
default:
assertNever(this.format);
}
}
}
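A usage sketch of the aggregator (event values invented for illustration): the first chunk carries only the role and metadata, later chunks carry content deltas and the finish reason.

import { EventAggregator } from "./streaming/event-aggregator";

const aggregator = new EventAggregator({ format: "openai" });
aggregator.addEvent({
  id: "chatcmpl-123",
  object: "chat.completion.chunk",
  created: 1696300000,
  model: "gpt-4",
  choices: [{ index: 0, delta: { role: "assistant" }, finish_reason: null }],
});
aggregator.addEvent({
  id: "chatcmpl-123",
  object: "chat.completion.chunk",
  created: 1696300000,
  model: "gpt-4",
  choices: [{ index: 0, delta: { content: "Hello" }, finish_reason: "stop" }],
});
// getFinalResponse() now returns a blocking-style chat.completion body:
// choices[0].message => { role: "assistant", content: "Hello" }, finish_reason "stop"
const finalBody = aggregator.getFinalResponse();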


@ -0,0 +1,31 @@
export type SSEResponseTransformArgs = {
data: string;
lastPosition: number;
index: number;
fallbackId: string;
fallbackModel: string;
};
export type OpenAIChatCompletionStreamEvent = {
id: string;
object: "chat.completion.chunk";
created: number;
model: string;
choices: {
index: number;
delta: { role?: string; content?: string };
finish_reason: string | null;
}[];
}
export type StreamingCompletionTransformer = (
params: SSEResponseTransformArgs
) => { position: number; event?: OpenAIChatCompletionStreamEvent };
export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropic } from "./aggregators/anthropic";


@ -0,0 +1,29 @@
export type ServerSentEvent = { id?: string; type?: string; data: string };
/** Given a string of SSE data, parse it into a `ServerSentEvent` object. */
export function parseEvent(event: string) {
const buffer: ServerSentEvent = { data: "" };
return event.split(/\r?\n/).reduce(parseLine, buffer)
}
function parseLine(event: ServerSentEvent, line: string) {
const separator = line.indexOf(":");
const field = separator === -1 ? line : line.slice(0,separator);
const value = separator === -1 ? "" : line.slice(separator + 1);
switch (field) {
case 'id':
event.id = value.trim()
break
case 'event':
event.type = value.trim()
break
case 'data':
event.data += value.trimStart()
break
default:
break
}
return event
}
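A quick sketch of how the parser behaves (sample message invented; the import path assumes a sibling module, as in the files above):

import { parseEvent } from "./parse-sse";

const message = 'event: completion\ndata: {"completion":" Hello","stop_reason":null}';
const parsed = parseEvent(message);
// parsed => { type: "completion", data: '{"completion":" Hello","stop_reason":null}' }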


@ -0,0 +1,123 @@
import { Transform, TransformOptions } from "stream";
import { logger } from "../../../../logger";
import { APIFormat } from "../../../../shared/key-management";
import { assertNever } from "../../../../shared/utils";
import {
anthropicV1ToOpenAI,
anthropicV2ToOpenAI,
OpenAIChatCompletionStreamEvent,
openAITextToOpenAIChat,
StreamingCompletionTransformer,
} from "./index";
import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
const genlog = logger.child({ module: "sse-transformer" });
type SSEMessageTransformerOptions = TransformOptions & {
requestedModel: string;
requestId: string;
inputFormat: APIFormat;
inputApiVersion?: string;
logger?: typeof logger;
};
/**
* Transforms SSE messages from one API format to OpenAI chat.completion.chunks.
* Emits the original string SSE message as an "originalMessage" event.
*/
export class SSEMessageTransformer extends Transform {
private lastPosition: number;
private msgCount: number;
private readonly transformFn: StreamingCompletionTransformer;
private readonly log;
private readonly fallbackId: string;
private readonly fallbackModel: string;
constructor(options: SSEMessageTransformerOptions) {
super({ ...options, readableObjectMode: true });
this.log = options.logger?.child({ module: "sse-transformer" }) ?? genlog;
this.lastPosition = 0;
this.msgCount = 0;
this.transformFn = getTransformer(
options.inputFormat,
options.inputApiVersion
);
this.fallbackId = options.requestId;
this.fallbackModel = options.requestedModel;
this.log.debug(
{
fn: this.transformFn.name,
format: options.inputFormat,
version: options.inputApiVersion,
},
"Selected SSE transformer"
);
}
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
try {
const originalMessage = chunk.toString();
const { event: transformedMessage, position: newPosition } =
this.transformFn({
data: originalMessage,
lastPosition: this.lastPosition,
index: this.msgCount++,
fallbackId: this.fallbackId,
fallbackModel: this.fallbackModel,
});
this.lastPosition = newPosition;
this.emit("originalMessage", originalMessage);
// Some events may not be transformed, e.g. ping events
if (!transformedMessage) return callback();
if (this.msgCount === 1) {
this.push(createInitialMessage(transformedMessage));
}
this.push(transformedMessage);
callback();
} catch (err) {
this.log.error(err, "Error transforming SSE message");
callback(err);
}
}
}
function getTransformer(
responseApi: APIFormat,
version?: string
): StreamingCompletionTransformer {
switch (responseApi) {
case "openai":
return passthroughToOpenAI;
case "openai-text":
return openAITextToOpenAIChat;
case "anthropic":
return version === "2023-01-01"
? anthropicV1ToOpenAI
: anthropicV2ToOpenAI;
case "google-palm":
throw new Error("Google PaLM does not support streaming responses");
default:
assertNever(responseApi);
}
}
/**
* OpenAI streaming chat completions start with an event that contains only the
* metadata and role (always 'assistant') for the response. To simulate this
* for APIs where the first event contains actual content, we create a fake
* initial event with no content but correct metadata.
*/
function createInitialMessage(
event: OpenAIChatCompletionStreamEvent
): OpenAIChatCompletionStreamEvent {
return {
...event,
choices: event.choices.map((choice) => ({
...choice,
delta: { role: "assistant", content: "" },
})),
};
}
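A minimal wiring sketch mirroring handle-streamed-response.ts, assuming the relative paths used above; the upstream readable and the sample Anthropic event are invented for illustration:

import { pipeline, Readable } from "stream";
import { SSEStreamAdapter } from "./sse-stream-adapter";
import { SSEMessageTransformer } from "./sse-message-transformer";

const upstream = Readable.from([
  'event: completion\ndata: {"completion":"Hi","stop_reason":null}\n\n',
]);
const adapter = new SSEStreamAdapter({ contentType: "text/event-stream" });
const transformer = new SSEMessageTransformer({
  inputFormat: "anthropic",
  inputApiVersion: "2023-06-01",
  requestId: "req-1", // fallback id if the upstream event has no log_id
  requestedModel: "claude-2", // fallback model if the event omits one
});
transformer.on("data", (chunk) => {
  // chunk is an OpenAI chat.completion.chunk object (readableObjectMode)
});
pipeline(upstream, adapter, transformer, (err) => {
  if (err) console.error("stream failed", err);
});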


@ -1,11 +1,11 @@
import { Transform, TransformOptions } from "stream";
// @ts-ignore
import { Parser } from "lifion-aws-event-stream";
-import { logger } from "../../../logger";
+import { logger } from "../../../../logger";

const log = logger.child({ module: "sse-stream-adapter" });

-type SSEStreamAdapterOptions = TransformOptions & { isAwsStream?: boolean };
+type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
type AwsEventStreamMessage = {
  headers: { ":message-type": "event" | "exception" };
  payload: { message?: string /** base64 encoded */; bytes?: string };
@ -15,24 +15,25 @@ type AwsEventStreamMessage = {
 * Receives either text chunks or AWS binary event stream chunks and emits
 * full SSE events.
 */
-export class ServerSentEventStreamAdapter extends Transform {
+export class SSEStreamAdapter extends Transform {
  private readonly isAwsStream;
  private parser = new Parser();
  private partialMessage = "";

  constructor(options?: SSEStreamAdapterOptions) {
    super(options);
-    this.isAwsStream = options?.isAwsStream || false;
+    this.isAwsStream =
+      options?.contentType === "application/vnd.amazon.eventstream";

    this.parser.on("data", (data: AwsEventStreamMessage) => {
      const message = this.processAwsEvent(data);
      if (message) {
-        this.push(Buffer.from(message, "utf8"));
+        this.push(Buffer.from(message + "\n\n"), "utf8");
      }
    });
  }

-  processAwsEvent(event: AwsEventStreamMessage): string | null {
+  protected processAwsEvent(event: AwsEventStreamMessage): string | null {
    const { payload, headers } = event;
    if (headers[":message-type"] === "exception" || !payload.bytes) {
      log.error(
@ -42,7 +43,14 @@ export class ServerSentEventStreamAdapter extends Transform {
      const message = JSON.stringify(event);
      return getFakeErrorCompletion("proxy AWS error", message);
    } else {
-      return `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`;
+      const { bytes } = payload;
+      // technically this is a transformation but we don't really distinguish
+      // between aws claude and anthropic claude at the APIFormat level, so
+      // these will short circuit the message transformer
+      return [
+        "event: completion",
+        `data: ${Buffer.from(bytes, "base64").toString("utf8")}`,
+      ].join("\n");
    }
  }
@ -55,11 +63,15 @@ export class ServerSentEventStreamAdapter extends Transform {
      // so we need to buffer and emit separate stream events for full
      // messages so we can parse/transform them properly.
      const str = chunk.toString("utf8");
      const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
      this.partialMessage = fullMessages.pop() || "";
      for (const message of fullMessages) {
-        this.push(message);
+        // Mixing line endings will break some clients and our request queue
+        // will have already sent \n for heartbeats, so we need to normalize
+        // to \n.
+        this.push(message.replace(/\r\n/g, "\n") + "\n\n");
      }
    }
    callback();
@ -72,7 +84,7 @@ export class ServerSentEventStreamAdapter extends Transform {
function getFakeErrorCompletion(type: string, message: string) {
  const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
-  const fakeEvent = {
+  const fakeEvent = JSON.stringify({
    log_id: "aws-proxy-sse-message",
    stop_reason: type,
    completion:
@ -80,6 +92,6 @@ function getFakeErrorCompletion(type: string, message: string) {
    truncated: false,
    stop: null,
    model: "",
-  };
-  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
+  });
+  return ["event: completion", `data: ${fakeEvent}\n\n`].join("\n");
}


@ -0,0 +1,67 @@
import { StreamingCompletionTransformer } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "anthropic-v1-to-openai",
});
type AnthropicV1StreamEvent = {
log_id?: string;
model?: string;
completion: string;
stop_reason: string;
};
/**
* Transforms an incoming Anthropic SSE (2023-01-01 API) to an equivalent
* OpenAI chat.completion.chunk SSE.
*/
export const anthropicV1ToOpenAI: StreamingCompletionTransformer = (params) => {
const { data, lastPosition } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: lastPosition };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: lastPosition };
}
// Anthropic sends the full completion so far with each event whereas OpenAI
// only sends the delta. To make the SSE events compatible, we remove
// everything before `lastPosition` from the completion.
const newEvent = {
id: "ant-" + (completionEvent.log_id ?? params.fallbackId),
object: "chat.completion.chunk" as const,
created: Date.now(),
model: completionEvent.model ?? params.fallbackModel,
choices: [
{
index: 0,
delta: { content: completionEvent.completion?.slice(lastPosition) },
finish_reason: completionEvent.stop_reason,
},
],
};
return { position: completionEvent.completion.length, event: newEvent };
};
function asCompletion(event: ServerSentEvent): AnthropicV1StreamEvent | null {
try {
const parsed = JSON.parse(event.data);
if (parsed.completion !== undefined && parsed.stop_reason !== undefined) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid event");
}
return null;
}
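Two hypothetical calls illustrating how lastPosition turns Anthropic's cumulative completions into OpenAI-style deltas (all values invented):

const first = anthropicV1ToOpenAI({
  data: 'data: {"completion":"Hello","stop_reason":null,"log_id":"123","model":"claude-v1"}',
  lastPosition: 0,
  index: 0,
  fallbackId: "req-1",
  fallbackModel: "claude-v1",
});
// first.event?.choices[0].delta.content === "Hello", first.position === 5

const second = anthropicV1ToOpenAI({
  data: 'data: {"completion":"Hello world","stop_reason":null,"log_id":"123","model":"claude-v1"}',
  lastPosition: first.position, // 5
  index: 1,
  fallbackId: "req-1",
  fallbackModel: "claude-v1",
});
// second.event?.choices[0].delta.content === " world", second.position === 11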


@ -0,0 +1,66 @@
import { StreamingCompletionTransformer } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "anthropic-v2-to-openai",
});
type AnthropicV2StreamEvent = {
log_id?: string;
model?: string;
completion: string;
stop_reason: string;
};
/**
* Transforms an incoming Anthropic SSE (2023-06-01 API) to an equivalent
* OpenAI chat.completion.chunk SSE.
*/
export const anthropicV2ToOpenAI: StreamingCompletionTransformer = (params) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
const newEvent = {
id: "ant-" + (completionEvent.log_id ?? params.fallbackId),
object: "chat.completion.chunk" as const,
created: Date.now(),
model: completionEvent.model ?? params.fallbackModel,
choices: [
{
index: 0,
delta: { content: completionEvent.completion },
finish_reason: completionEvent.stop_reason,
},
],
};
return { position: completionEvent.completion.length, event: newEvent };
};
function asCompletion(event: ServerSentEvent): AnthropicV2StreamEvent | null {
if (event.type === "ping") return null;
try {
const parsed = JSON.parse(event.data);
if (parsed.completion !== undefined && parsed.stop_reason !== undefined) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid event");
}
return null;
}


@ -0,0 +1,68 @@
import { SSEResponseTransformArgs } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "openai-text-to-openai",
});
type OpenAITextCompletionStreamEvent = {
id: string;
object: "text_completion";
created: number;
choices: {
text: string;
index: number;
logprobs: null;
finish_reason: string | null;
}[];
model: string;
};
export const openAITextToOpenAIChat = (params: SSEResponseTransformArgs) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
const newEvent = {
id: completionEvent.id,
object: "chat.completion.chunk" as const,
created: completionEvent.created,
model: completionEvent.model,
choices: [
{
index: completionEvent.choices[0].index,
delta: { content: completionEvent.choices[0].text },
finish_reason: completionEvent.choices[0].finish_reason,
},
],
};
return { position: -1, event: newEvent };
};
function asCompletion(
event: ServerSentEvent
): OpenAITextCompletionStreamEvent | null {
try {
const parsed = JSON.parse(event.data);
if (Array.isArray(parsed.choices) && parsed.choices[0].text !== undefined) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid data event");
}
return null;
}
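For illustration, one hypothetical text_completion chunk and the chat-format chunk this transformer produces (values invented):

const result = openAITextToOpenAIChat({
  data: 'data: {"id":"cmpl-1","object":"text_completion","created":1696300000,"model":"gpt-3.5-turbo-instruct","choices":[{"text":"Hello","index":0,"logprobs":null,"finish_reason":null}]}',
  lastPosition: -1,
  index: 0,
  fallbackId: "req-1",
  fallbackModel: "gpt-3.5-turbo-instruct",
});
// result.event => { id: "cmpl-1", object: "chat.completion.chunk", created: 1696300000,
//   model: "gpt-3.5-turbo-instruct",
//   choices: [{ index: 0, delta: { content: "Hello" }, finish_reason: null }] }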


@ -0,0 +1,38 @@
import {
OpenAIChatCompletionStreamEvent,
SSEResponseTransformArgs,
} from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "openai-to-openai",
});
export const passthroughToOpenAI = (params: SSEResponseTransformArgs) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
return { position: -1, event: completionEvent };
};
function asCompletion(
event: ServerSentEvent
): OpenAIChatCompletionStreamEvent | null {
try {
return JSON.parse(event.data);
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid event");
}
return null;
}


@ -23,10 +23,11 @@ import {
  getOpenAIModelFamily,
  ModelFamily,
} from "../shared/models";
+import { initializeSseStream } from "../shared/streaming";
+import { assertNever } from "../shared/utils";
import { logger } from "../logger";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
import { buildFakeSseMessage } from "./middleware/common";
-import { assertNever } from "../shared/utils";

const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@ -352,14 +353,8 @@ function killQueuedRequest(req: Request) {
}

function initStreaming(req: Request) {
-  req.log.info(`Initiating streaming for new queued request.`);
  const res = req.res!;
-  res.statusCode = 200;
-  res.setHeader("Content-Type", "text/event-stream");
-  res.setHeader("Cache-Control", "no-cache");
-  res.setHeader("Connection", "keep-alive");
-  res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix
-  res.flushHeaders();
+  initializeSseStream(res);

  if (req.query.badSseParser) {
    // Some clients have a broken SSE parser that doesn't handle comments
@ -368,7 +363,6 @@ function initStreaming(req: Request) {
    return;
  }

-  res.write("\n");
  res.write(": joining queue\n\n");
}


@ -7,6 +7,14 @@ import { googlePalm } from "./palm";
import { aws } from "./aws";

const proxyRouter = express.Router();
+proxyRouter.use((req, _res, next) => {
+  if (req.headers.expect) {
+    // node-http-proxy does not like it when clients send `expect: 100-continue`
+    // and will stall. none of the upstream APIs use this header anyway.
+    delete req.headers.expect;
+  }
+  next();
+});
proxyRouter.use(
  express.json({ limit: "1536kb" }),
  express.urlencoded({ extended: true, limit: "1536kb" })

src/shared/streaming.ts (new file, 34 lines)

@ -0,0 +1,34 @@
import { Response } from "express";
import { IncomingMessage } from "http";
export function initializeSseStream(res: Response) {
res.statusCode = 200;
res.setHeader("Content-Type", "text/event-stream; charset=utf-8");
res.setHeader("Cache-Control", "no-cache");
res.setHeader("Connection", "keep-alive");
res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix
res.flushHeaders();
}
/**
* Copies headers received from upstream API to the SSE response, excluding
* ones we need to set ourselves for SSE to work.
*/
export function copySseResponseHeaders(
proxyRes: IncomingMessage,
res: Response
) {
const toOmit = [
"content-length",
"content-encoding",
"transfer-encoding",
"content-type",
"connection",
"cache-control",
];
for (const [key, value] of Object.entries(proxyRes.headers)) {
if (!toOmit.includes(key) && value) {
res.setHeader(key, value);
}
}
}
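A hypothetical caller showing the intended usage, mirroring how queue.ts and the streaming response handler now share these helpers:

import { Response } from "express";
import { initializeSseStream } from "../shared/streaming";

function sendQueueHeartbeat(res: Response) {
  if (!res.headersSent) initializeSseStream(res); // 200 + SSE headers, flushed once
  res.write(": queue heartbeat\n\n"); // SSE comment, ignored by conforming clients
}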


@ -277,7 +277,7 @@ function cleanupExpiredTokens() {
      deleted++;
    }
  }
-  log.debug({ disabled, deleted }, "Expired tokens cleaned up.");
+  log.trace({ disabled, deleted }, "Expired tokens cleaned up.");
}

function refreshAllQuotas() {