From ecf897e685e65bec6887958f370df7149d80ff56 Mon Sep 17 00:00:00 2001 From: khanon Date: Tue, 3 Oct 2023 06:14:19 +0000 Subject: [PATCH] Refactor handleStreamingResponse to make it less shit (khanon/oai-reverse-proxy!46) --- src/proxy/middleware/common.ts | 2 +- .../request/transform-outbound-payload.ts | 7 +- .../response/handle-streamed-response.ts | 435 +++--------------- src/proxy/middleware/response/index.ts | 56 +-- .../streaming/aggregators/anthropic.ts | 48 ++ .../streaming/aggregators/openai-chat.ts | 58 +++ .../streaming/aggregators/openai-text.ts | 57 +++ .../response/streaming/event-aggregator.ts | 41 ++ .../middleware/response/streaming/index.ts | 31 ++ .../response/streaming/parse-sse.ts | 29 ++ .../streaming/sse-message-transformer.ts | 123 +++++ .../{ => streaming}/sse-stream-adapter.ts | 34 +- .../transformers/anthropic-v1-to-openai.ts | 67 +++ .../transformers/anthropic-v2-to-openai.ts | 66 +++ .../transformers/openai-text-to-openai.ts | 68 +++ .../transformers/passthrough-to-openai.ts | 38 ++ src/proxy/queue.ts | 12 +- src/proxy/routes.ts | 8 + src/shared/streaming.ts | 34 ++ src/shared/users/user-store.ts | 2 +- 20 files changed, 778 insertions(+), 438 deletions(-) create mode 100644 src/proxy/middleware/response/streaming/aggregators/anthropic.ts create mode 100644 src/proxy/middleware/response/streaming/aggregators/openai-chat.ts create mode 100644 src/proxy/middleware/response/streaming/aggregators/openai-text.ts create mode 100644 src/proxy/middleware/response/streaming/event-aggregator.ts create mode 100644 src/proxy/middleware/response/streaming/index.ts create mode 100644 src/proxy/middleware/response/streaming/parse-sse.ts create mode 100644 src/proxy/middleware/response/streaming/sse-message-transformer.ts rename src/proxy/middleware/response/{ => streaming}/sse-stream-adapter.ts (66%) create mode 100644 src/proxy/middleware/response/streaming/transformers/anthropic-v1-to-openai.ts create mode 100644 src/proxy/middleware/response/streaming/transformers/anthropic-v2-to-openai.ts create mode 100644 src/proxy/middleware/response/streaming/transformers/openai-text-to-openai.ts create mode 100644 src/proxy/middleware/response/streaming/transformers/passthrough-to-openai.ts create mode 100644 src/shared/streaming.ts diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index f5c8bd1..a50728e 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -42,7 +42,7 @@ export function writeErrorResponse( // the stream. Otherwise just send a normal error response. if ( res.headersSent || - res.getHeader("content-type") === "text/event-stream" + String(res.getHeader("content-type")).startsWith("text/event-stream") ) { const errorContent = statusCode === 403 diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/transform-outbound-payload.ts index 7659ab6..7c1a3c0 100644 --- a/src/proxy/middleware/request/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/transform-outbound-payload.ts @@ -166,12 +166,7 @@ function openaiToAnthropic(req: Request) { throw result.error; } - // Anthropic has started versioning their API, indicated by an HTTP header - // `anthropic-version`. The new June 2023 version is not backwards compatible - // with our OpenAI-to-Anthropic transformations so we need to explicitly - // request the older version for now. 2023-01-01 will be removed in September. - // https://docs.anthropic.com/claude/reference/versioning - req.headers["anthropic-version"] = "2023-01-01"; + req.headers["anthropic-version"] = "2023-06-01"; const { messages, ...rest } = result.data; const prompt = openAIMessagesToClaudePrompt(messages); diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index 605a5c5..7696ca0 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -1,44 +1,16 @@ -import { Request, Response } from "express"; -import * as http from "http"; +import { pipeline } from "stream"; +import { promisify } from "util"; import { buildFakeSseMessage } from "../common"; -import { RawResponseBodyHandler, decodeResponseBody } from "."; -import { assertNever } from "../../../shared/utils"; -import { ServerSentEventStreamAdapter } from "./sse-stream-adapter"; +import { decodeResponseBody, RawResponseBodyHandler } from "."; +import { SSEStreamAdapter } from "./streaming/sse-stream-adapter"; +import { SSEMessageTransformer } from "./streaming/sse-message-transformer"; +import { EventAggregator } from "./streaming/event-aggregator"; +import { + copySseResponseHeaders, + initializeSseStream, +} from "../../../shared/streaming"; -type OpenAiChatCompletionResponse = { - id: string; - object: string; - created: number; - model: string; - choices: { - message: { role: string; content: string }; - finish_reason: string | null; - index: number; - }[]; -}; - -type OpenAiTextCompletionResponse = { - id: string; - object: string; - created: number; - model: string; - choices: { - text: string; - finish_reason: string | null; - index: number; - logprobs: null; - }[]; -}; - -type AnthropicCompletionResponse = { - completion: string; - stop_reason: string; - truncated: boolean; - stop: any; - model: string; - log_id: string; - exception: null; -}; +const pipelineAsync = promisify(pipeline); /** * Consume the SSE stream and forward events to the client. Once the stream is @@ -49,370 +21,67 @@ type AnthropicCompletionResponse = { * in the event a streamed request results in a non-200 response, we need to * fall back to the non-streaming response handler so that the error handler * can inspect the error response. - * - * Currently most frontends don't support Anthropic streaming, so users can opt - * to send requests for Claude models via an endpoint that accepts OpenAI- - * compatible requests and translates the received Anthropic SSE events into - * OpenAI ones, essentially pretending to be an OpenAI streaming API. */ export const handleStreamedResponse: RawResponseBodyHandler = async ( proxyRes, req, res ) => { - // If these differ, the user is using the OpenAI-compatibile endpoint, so - // we need to translate the SSE events into OpenAI completion events for their - // frontend. + const { hash } = req.key!; if (!req.isStreaming) { - const err = new Error( - "handleStreamedResponse called for non-streaming request." - ); - req.log.error({ stack: err.stack, api: req.inboundApi }, err.message); - throw err; + throw new Error("handleStreamedResponse called for non-streaming request."); } - const key = req.key!; - if (proxyRes.statusCode !== 200) { - // Ensure we use the non-streaming middleware stack since we won't be - // getting any events. - req.isStreaming = false; + if (proxyRes.statusCode! > 201) { + req.isStreaming = false; // Forces non-streaming response handler to execute req.log.warn( - { statusCode: proxyRes.statusCode, key: key.hash }, + { statusCode: proxyRes.statusCode, key: hash }, `Streaming request returned error status code. Falling back to non-streaming response handler.` ); return decodeResponseBody(proxyRes, req, res); } req.log.debug( - { headers: proxyRes.headers, key: key.hash }, - `Received SSE headers.` + { headers: proxyRes.headers, key: hash }, + `Starting to proxy SSE stream.` ); - return new Promise((resolve, reject) => { - req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`); + // Users waiting in the queue already have a SSE connection open for the + // heartbeat, so we can't always send the stream headers. + if (!res.headersSent) { + copySseResponseHeaders(proxyRes, res); + initializeSseStream(res); + } - // Queued streaming requests will already have a connection open and headers - // sent due to the heartbeat handler. In that case we can just start - // streaming the response without sending headers. - if (!res.headersSent) { - res.setHeader("Content-Type", "text/event-stream"); - res.setHeader("Cache-Control", "no-cache"); - res.setHeader("Connection", "keep-alive"); - res.setHeader("X-Accel-Buffering", "no"); - copyHeaders(proxyRes, res); - res.flushHeaders(); - } + const prefersNativeEvents = req.inboundApi === req.outboundApi; + const contentType = proxyRes.headers["content-type"]; - const adapter = new ServerSentEventStreamAdapter({ - isAwsStream: - proxyRes.headers["content-type"] === - "application/vnd.amazon.eventstream", + const adapter = new SSEStreamAdapter({ contentType }); + const aggregator = new EventAggregator({ format: req.outboundApi }); + const transformer = new SSEMessageTransformer({ + inputFormat: req.outboundApi, // outbound from the request's perspective + inputApiVersion: String(req.headers["anthropic-version"]), + logger: req.log, + requestId: String(req.id), + requestedModel: req.body.model, + }) + .on("originalMessage", (msg: string) => { + if (prefersNativeEvents) res.write(msg); + }) + .on("data", (msg) => { + if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`); + aggregator.addEvent(msg); }); - const events: string[] = []; - let lastPosition = 0; - let eventCount = 0; - - proxyRes.pipe(adapter); - - adapter.on("data", (chunk: any) => { - try { - const { event, position } = transformEvent({ - data: chunk.toString(), - requestApi: req.inboundApi, - responseApi: req.outboundApi, - lastPosition, - index: eventCount++, - }); - events.push(event); - lastPosition = position; - res.write(event + "\n\n"); - } catch (err) { - adapter.emit("error", err); - } - }); - - adapter.on("end", () => { - try { - req.log.info({ key: key.hash }, `Finished proxying SSE stream.`); - const finalBody = convertEventsToFinalResponse(events, req); - res.end(); - resolve(finalBody); - } catch (err) { - adapter.emit("error", err); - } - }); - - adapter.on("error", (err) => { - req.log.error({ error: err, key: key.hash }, `Mid-stream error.`); - const errorEvent = buildFakeSseMessage("stream-error", err.message, req); - res.write(`data: ${JSON.stringify(errorEvent)}\n\ndata: [DONE]\n\n`); - res.end(); - reject(err); - }); - }); + try { + await pipelineAsync(proxyRes, adapter, transformer); + req.log.debug({ key: hash }, `Finished proxying SSE stream.`); + res.end(); + return aggregator.getFinalResponse(); + } catch (err) { + const errorEvent = buildFakeSseMessage("stream-error", err.message, req); + res.write(`${errorEvent}data: [DONE]\n\n`); + res.end(); + throw err; + } }; - -type SSETransformationArgs = { - data: string; - requestApi: string; - responseApi: string; - lastPosition: number; - index: number; -}; - -/** - * Transforms SSE events from the given response API into events compatible with - * the API requested by the client. - */ -function transformEvent(params: SSETransformationArgs) { - const { data, requestApi, responseApi } = params; - if (requestApi === responseApi) { - return { position: -1, event: data }; - } - - const trans = `${requestApi}->${responseApi}`; - switch (trans) { - case "openai->openai-text": - return transformOpenAITextEventToOpenAIChat(params); - case "openai->anthropic": - // TODO: handle new anthropic streaming format - return transformV1AnthropicEventToOpenAI(params); - default: - throw new Error(`Unsupported streaming API transformation. ${trans}`); - } -} - -function transformOpenAITextEventToOpenAIChat(params: SSETransformationArgs) { - const { data, index } = params; - - if (!data.startsWith("data:")) return { position: -1, event: data }; - if (data.startsWith("data: [DONE]")) return { position: -1, event: data }; - - const event = JSON.parse(data.slice("data: ".length)); - - // The very first event must be a role assignment with no content. - - const createEvent = () => ({ - id: event.id, - object: "chat.completion.chunk", - created: event.created, - model: event.model, - choices: [ - { - message: { role: "", content: "" } as { - role?: string; - content: string; - }, - index: 0, - finish_reason: null, - }, - ], - }); - - let buffer = ""; - - if (index === 0) { - const initialEvent = createEvent(); - initialEvent.choices[0].message.role = "assistant"; - buffer = `data: ${JSON.stringify(initialEvent)}\n\n`; - } - - const newEvent = { - ...event, - choices: [ - { - ...event.choices[0], - delta: { content: event.choices[0].text }, - text: undefined, - }, - ], - }; - - buffer += `data: ${JSON.stringify(newEvent)}`; - - return { position: -1, event: buffer }; -} - -function transformV1AnthropicEventToOpenAI(params: SSETransformationArgs) { - const { data, lastPosition } = params; - // Anthropic sends the full completion so far with each event whereas OpenAI - // only sends the delta. To make the SSE events compatible, we remove - // everything before `lastPosition` from the completion. - if (!data.startsWith("data:")) { - return { position: lastPosition, event: data }; - } - - if (data.startsWith("data: [DONE]")) { - return { position: lastPosition, event: data }; - } - - const event = JSON.parse(data.slice("data: ".length)); - const newEvent = { - id: "ant-" + event.log_id, - object: "chat.completion.chunk", - created: Date.now(), - model: event.model, - choices: [ - { - index: 0, - delta: { content: event.completion?.slice(lastPosition) }, - finish_reason: event.stop_reason, - }, - ], - }; - return { - position: event.completion.length, - event: `data: ${JSON.stringify(newEvent)}`, - }; -} - -/** Copy headers, excluding ones we're already setting for the SSE response. */ -function copyHeaders(proxyRes: http.IncomingMessage, res: Response) { - const toOmit = [ - "content-length", - "content-encoding", - "transfer-encoding", - "content-type", - "connection", - "cache-control", - ]; - for (const [key, value] of Object.entries(proxyRes.headers)) { - if (!toOmit.includes(key) && value) { - res.setHeader(key, value); - } - } -} - -/** - * Converts the list of incremental SSE events into an object that resembles a - * full, non-streamed response from the API so that subsequent middleware can - * operate on it as if it were a normal response. - * Events are expected to be in the format they were received from the API. - */ -function convertEventsToFinalResponse(events: string[], req: Request) { - switch (req.outboundApi) { - case "openai": { - let merged: OpenAiChatCompletionResponse = { - id: "", - object: "", - created: 0, - model: "", - choices: [], - }; - merged = events.reduce((acc, event, i) => { - if (!event.startsWith("data: ")) return acc; - if (event === "data: [DONE]") return acc; - - const data = JSON.parse(event.slice("data: ".length)); - - // The first chat chunk only contains the role assignment and metadata - if (i === 0) { - return { - id: data.id, - object: data.object, - created: data.created, - model: data.model, - choices: [ - { - message: { role: data.choices[0].delta.role, content: "" }, - index: 0, - finish_reason: null, - }, - ], - }; - } - - if (data.choices[0].delta.content) { - acc.choices[0].message.content += data.choices[0].delta.content; - } - acc.choices[0].finish_reason = data.choices[0].finish_reason; - return acc; - }, merged); - return merged; - } - case "openai-text": { - let merged: OpenAiTextCompletionResponse = { - id: "", - object: "", - created: 0, - model: "", - choices: [], - // TODO: merge logprobs - }; - merged = events.reduce((acc, event) => { - if (!event.startsWith("data: ")) return acc; - if (event === "data: [DONE]") return acc; - - const data = JSON.parse(event.slice("data: ".length)); - - return { - id: data.id, - object: data.object, - created: data.created, - model: data.model, - choices: [ - { - text: acc.choices[0]?.text + data.choices[0].text, - index: 0, - finish_reason: data.choices[0].finish_reason, - logprobs: null, - }, - ], - }; - }, merged); - return merged; - } - case "anthropic": { - if (req.headers["anthropic-version"] === "2023-01-01") { - return convertAnthropicV1(events, req); - } - - let merged: AnthropicCompletionResponse = { - completion: "", - stop_reason: "", - truncated: false, - stop: null, - model: req.body.model, - log_id: "", - exception: null, - } - - merged = events.reduce((acc, event) => { - if (!event.startsWith("data: ")) return acc; - if (event === "data: [DONE]") return acc; - - const data = JSON.parse(event.slice("data: ".length)); - - return { - completion: acc.completion + data.completion, - stop_reason: data.stop_reason, - truncated: data.truncated, - stop: data.stop, - log_id: data.log_id, - exception: data.exception, - model: acc.model, - }; - }, merged); - return merged; - } - case "google-palm": { - throw new Error("PaLM streaming not yet supported."); - } - default: - assertNever(req.outboundApi); - } -} - -/** Older Anthropic streaming format which sent full completion each time. */ -function convertAnthropicV1( - events: string[], - req: Request -) { - const lastEvent = events[events.length - 2].toString(); - const data = JSON.parse( - lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length) - ); - const final: AnthropicCompletionResponse = { ...data, log_id: req.id }; - return final; -} diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index 530a6b5..c79f3df 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -4,13 +4,16 @@ import * as http from "http"; import util from "util"; import zlib from "zlib"; import { logger } from "../../../logger"; +import { enqueue, trackWaitTime } from "../../queue"; +import { HttpError } from "../../../shared/errors"; import { keyPool } from "../../../shared/key-management"; import { getOpenAIModelFamily } from "../../../shared/models"; -import { enqueue, trackWaitTime } from "../../queue"; +import { countTokens } from "../../../shared/tokenization"; import { incrementPromptCount, incrementTokenCount, } from "../../../shared/users/user-store"; +import { assertNever } from "../../../shared/utils"; import { getCompletionFromBody, isCompletionRequest, @@ -18,8 +21,6 @@ import { } from "../common"; import { handleStreamedResponse } from "./handle-streamed-response"; import { logPrompt } from "./log-prompt"; -import { countTokens } from "../../../shared/tokenization"; -import { assertNever } from "../../../shared/utils"; const DECODER_MAP = { gzip: util.promisify(zlib.gunzip), @@ -83,7 +84,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => { ? handleStreamedResponse : decodeResponseBody; - let lastMiddlewareName = initialHandler.name; + let lastMiddleware = initialHandler.name; try { const body = await initialHandler(proxyRes, req, res); @@ -112,37 +113,38 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => { } for (const middleware of middlewareStack) { - lastMiddlewareName = middleware.name; + lastMiddleware = middleware.name; await middleware(proxyRes, req, res, body); } trackWaitTime(req); - } catch (error: any) { + } catch (error) { // Hack: if the error is a retryable rate-limit error, the request has // been re-enqueued and we can just return without doing anything else. if (error instanceof RetryableError) { return; } - const errorData = { - error: error.stack, - thrownBy: lastMiddlewareName, - key: req.key?.hash, - }; - const message = `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`; - if (res.headersSent) { - req.log.error(errorData, message); - // This should have already been handled by the error handler, but - // just in case... - if (!res.writableEnded) { - res.end(); - } + // Already logged and responded to the client by handleUpstreamErrors + if (error instanceof HttpError) { + if (!res.writableEnded) res.end(); return; } - logger.error(errorData, message); - res - .status(500) - .json({ error: "Internal server error", proxy_note: message }); + + const { stack, message } = error; + const info = { stack, lastMiddleware, key: req.key?.hash }; + const description = `Error while executing proxy response middleware: ${lastMiddleware} (${message})`; + + if (res.headersSent) { + req.log.error(info, description); + if (!res.writableEnded) res.end(); + return; + } else { + req.log.error(info, description); + res + .status(500) + .json({ error: "Internal server error", proxy_note: description }); + } } }; }; @@ -203,7 +205,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( return resolve(body.toString()); } catch (error: any) { const errorMessage = `Proxy received response with invalid JSON: ${error.message}`; - logger.warn({ error, key: req.key?.hash }, errorMessage); + logger.warn({ error: error.stack, key: req.key?.hash }, errorMessage); writeErrorResponse(req, res, 500, { error: errorMessage }); return reject(errorMessage); } @@ -223,7 +225,7 @@ type ProxiedErrorPayload = { * an error to stop the middleware stack. * On 429 errors, if request queueing is enabled, the request will be silently * re-enqueued. Otherwise, the request will be rejected with an error payload. - * @throws {Error} On HTTP error status code from upstream service + * @throws {HttpError} On HTTP error status code from upstream service */ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( proxyRes, @@ -258,7 +260,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( proxy_note: `This is likely a temporary error with the upstream service.`, }; writeErrorResponse(req, res, statusCode, errorObject); - throw new Error(parseError.message); + throw new HttpError(statusCode, parseError.message); } const errorType = @@ -371,7 +373,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( } writeErrorResponse(req, res, statusCode, errorPayload); - throw new Error(errorPayload.error?.message); + throw new HttpError(statusCode, errorPayload.error?.message); }; /** diff --git a/src/proxy/middleware/response/streaming/aggregators/anthropic.ts b/src/proxy/middleware/response/streaming/aggregators/anthropic.ts new file mode 100644 index 0000000..56f229e --- /dev/null +++ b/src/proxy/middleware/response/streaming/aggregators/anthropic.ts @@ -0,0 +1,48 @@ +import { OpenAIChatCompletionStreamEvent } from "../index"; + +export type AnthropicCompletionResponse = { + completion: string; + stop_reason: string; + truncated: boolean; + stop: any; + model: string; + log_id: string; + exception: null; +}; + +/** + * Given a list of OpenAI chat completion events, compiles them into a single + * finalized Anthropic completion response so that non-streaming middleware + * can operate on it as if it were a blocking response. + */ +export function mergeEventsForAnthropic( + events: OpenAIChatCompletionStreamEvent[] +): AnthropicCompletionResponse { + let merged: AnthropicCompletionResponse = { + log_id: "", + exception: null, + model: "", + completion: "", + stop_reason: "", + truncated: false, + stop: null, + }; + merged = events.reduce((acc, event, i) => { + // The first event will only contain role assignment and response metadata + if (i === 0) { + acc.log_id = event.id; + acc.model = event.model; + acc.completion = ""; + acc.stop_reason = ""; + return acc; + } + + acc.stop_reason = event.choices[0].finish_reason ?? ""; + if (event.choices[0].delta.content) { + acc.completion += event.choices[0].delta.content; + } + + return acc; + }, merged); + return merged; +} diff --git a/src/proxy/middleware/response/streaming/aggregators/openai-chat.ts b/src/proxy/middleware/response/streaming/aggregators/openai-chat.ts new file mode 100644 index 0000000..f1a1bd4 --- /dev/null +++ b/src/proxy/middleware/response/streaming/aggregators/openai-chat.ts @@ -0,0 +1,58 @@ +import { OpenAIChatCompletionStreamEvent } from "../index"; + +export type OpenAiChatCompletionResponse = { + id: string; + object: string; + created: number; + model: string; + choices: { + message: { role: string; content: string }; + finish_reason: string | null; + index: number; + }[]; +}; + +/** + * Given a list of OpenAI chat completion events, compiles them into a single + * finalized OpenAI chat completion response so that non-streaming middleware + * can operate on it as if it were a blocking response. + */ +export function mergeEventsForOpenAIChat( + events: OpenAIChatCompletionStreamEvent[] +): OpenAiChatCompletionResponse { + let merged: OpenAiChatCompletionResponse = { + id: "", + object: "", + created: 0, + model: "", + choices: [], + }; + merged = events.reduce((acc, event, i) => { + // The first event will only contain role assignment and response metadata + if (i === 0) { + acc.id = event.id; + acc.object = event.object; + acc.created = event.created; + acc.model = event.model; + acc.choices = [ + { + index: 0, + message: { + role: event.choices[0].delta.role ?? "assistant", + content: "", + }, + finish_reason: null, + }, + ]; + return acc; + } + + acc.choices[0].finish_reason = event.choices[0].finish_reason; + if (event.choices[0].delta.content) { + acc.choices[0].message.content += event.choices[0].delta.content; + } + + return acc; + }, merged); + return merged; +} diff --git a/src/proxy/middleware/response/streaming/aggregators/openai-text.ts b/src/proxy/middleware/response/streaming/aggregators/openai-text.ts new file mode 100644 index 0000000..f343934 --- /dev/null +++ b/src/proxy/middleware/response/streaming/aggregators/openai-text.ts @@ -0,0 +1,57 @@ +import { OpenAIChatCompletionStreamEvent } from "../index"; + +export type OpenAiTextCompletionResponse = { + id: string; + object: string; + created: number; + model: string; + choices: { + text: string; + finish_reason: string | null; + index: number; + logprobs: null; + }[]; +}; + +/** + * Given a list of OpenAI chat completion events, compiles them into a single + * finalized OpenAI text completion response so that non-streaming middleware + * can operate on it as if it were a blocking response. + */ +export function mergeEventsForOpenAIText( + events: OpenAIChatCompletionStreamEvent[] +): OpenAiTextCompletionResponse { + let merged: OpenAiTextCompletionResponse = { + id: "", + object: "", + created: 0, + model: "", + choices: [], + }; + merged = events.reduce((acc, event, i) => { + // The first event will only contain role assignment and response metadata + if (i === 0) { + acc.id = event.id; + acc.object = event.object; + acc.created = event.created; + acc.model = event.model; + acc.choices = [ + { + text: "", + index: 0, + finish_reason: null, + logprobs: null, + }, + ]; + return acc; + } + + acc.choices[0].finish_reason = event.choices[0].finish_reason; + if (event.choices[0].delta.content) { + acc.choices[0].text += event.choices[0].delta.content; + } + + return acc; + }, merged); + return merged; +} diff --git a/src/proxy/middleware/response/streaming/event-aggregator.ts b/src/proxy/middleware/response/streaming/event-aggregator.ts new file mode 100644 index 0000000..55f0fb3 --- /dev/null +++ b/src/proxy/middleware/response/streaming/event-aggregator.ts @@ -0,0 +1,41 @@ +import { APIFormat } from "../../../../shared/key-management"; +import { assertNever } from "../../../../shared/utils"; +import { + mergeEventsForAnthropic, + mergeEventsForOpenAIChat, + mergeEventsForOpenAIText, + OpenAIChatCompletionStreamEvent +} from "./index"; + +/** + * Collects SSE events containing incremental chat completion responses and + * compiles them into a single finalized response for downstream middleware. + */ +export class EventAggregator { + private readonly format: APIFormat; + private readonly events: OpenAIChatCompletionStreamEvent[]; + + constructor({ format }: { format: APIFormat }) { + this.events = []; + this.format = format; + } + + addEvent(event: OpenAIChatCompletionStreamEvent) { + this.events.push(event); + } + + getFinalResponse() { + switch (this.format) { + case "openai": + return mergeEventsForOpenAIChat(this.events); + case "openai-text": + return mergeEventsForOpenAIText(this.events); + case "anthropic": + return mergeEventsForAnthropic(this.events); + case "google-palm": + throw new Error("Google PaLM API does not support streaming responses"); + default: + assertNever(this.format); + } + } +} \ No newline at end of file diff --git a/src/proxy/middleware/response/streaming/index.ts b/src/proxy/middleware/response/streaming/index.ts new file mode 100644 index 0000000..6e232da --- /dev/null +++ b/src/proxy/middleware/response/streaming/index.ts @@ -0,0 +1,31 @@ +export type SSEResponseTransformArgs = { + data: string; + lastPosition: number; + index: number; + fallbackId: string; + fallbackModel: string; +}; + +export type OpenAIChatCompletionStreamEvent = { + id: string; + object: "chat.completion.chunk"; + created: number; + model: string; + choices: { + index: number; + delta: { role?: string; content?: string }; + finish_reason: string | null; + }[]; +} + +export type StreamingCompletionTransformer = ( + params: SSEResponseTransformArgs +) => { position: number; event?: OpenAIChatCompletionStreamEvent }; + +export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai"; +export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai"; +export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai"; +export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat"; +export { mergeEventsForOpenAIText } from "./aggregators/openai-text"; +export { mergeEventsForAnthropic } from "./aggregators/anthropic"; + diff --git a/src/proxy/middleware/response/streaming/parse-sse.ts b/src/proxy/middleware/response/streaming/parse-sse.ts new file mode 100644 index 0000000..d161fb5 --- /dev/null +++ b/src/proxy/middleware/response/streaming/parse-sse.ts @@ -0,0 +1,29 @@ +export type ServerSentEvent = { id?: string; type?: string; data: string }; + +/** Given a string of SSE data, parse it into a `ServerSentEvent` object. */ +export function parseEvent(event: string) { + const buffer: ServerSentEvent = { data: "" }; + return event.split(/\r?\n/).reduce(parseLine, buffer) +} + +function parseLine(event: ServerSentEvent, line: string) { + const separator = line.indexOf(":"); + const field = separator === -1 ? line : line.slice(0,separator); + const value = separator === -1 ? "" : line.slice(separator + 1); + + switch (field) { + case 'id': + event.id = value.trim() + break + case 'event': + event.type = value.trim() + break + case 'data': + event.data += value.trimStart() + break + default: + break + } + + return event +} \ No newline at end of file diff --git a/src/proxy/middleware/response/streaming/sse-message-transformer.ts b/src/proxy/middleware/response/streaming/sse-message-transformer.ts new file mode 100644 index 0000000..33d2882 --- /dev/null +++ b/src/proxy/middleware/response/streaming/sse-message-transformer.ts @@ -0,0 +1,123 @@ +import { Transform, TransformOptions } from "stream"; +import { logger } from "../../../../logger"; +import { APIFormat } from "../../../../shared/key-management"; +import { assertNever } from "../../../../shared/utils"; +import { + anthropicV1ToOpenAI, + anthropicV2ToOpenAI, + OpenAIChatCompletionStreamEvent, + openAITextToOpenAIChat, + StreamingCompletionTransformer, +} from "./index"; +import { passthroughToOpenAI } from "./transformers/passthrough-to-openai"; + +const genlog = logger.child({ module: "sse-transformer" }); + +type SSEMessageTransformerOptions = TransformOptions & { + requestedModel: string; + requestId: string; + inputFormat: APIFormat; + inputApiVersion?: string; + logger?: typeof logger; +}; + +/** + * Transforms SSE messages from one API format to OpenAI chat.completion.chunks. + * Emits the original string SSE message as an "originalMessage" event. + */ +export class SSEMessageTransformer extends Transform { + private lastPosition: number; + private msgCount: number; + private readonly transformFn: StreamingCompletionTransformer; + private readonly log; + private readonly fallbackId: string; + private readonly fallbackModel: string; + + constructor(options: SSEMessageTransformerOptions) { + super({ ...options, readableObjectMode: true }); + this.log = options.logger?.child({ module: "sse-transformer" }) ?? genlog; + this.lastPosition = 0; + this.msgCount = 0; + this.transformFn = getTransformer( + options.inputFormat, + options.inputApiVersion + ); + this.fallbackId = options.requestId; + this.fallbackModel = options.requestedModel; + this.log.debug( + { + fn: this.transformFn.name, + format: options.inputFormat, + version: options.inputApiVersion, + }, + "Selected SSE transformer" + ); + } + + _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) { + try { + const originalMessage = chunk.toString(); + const { event: transformedMessage, position: newPosition } = + this.transformFn({ + data: originalMessage, + lastPosition: this.lastPosition, + index: this.msgCount++, + fallbackId: this.fallbackId, + fallbackModel: this.fallbackModel, + }); + this.lastPosition = newPosition; + + this.emit("originalMessage", originalMessage); + + // Some events may not be transformed, e.g. ping events + if (!transformedMessage) return callback(); + + if (this.msgCount === 1) { + this.push(createInitialMessage(transformedMessage)); + } + this.push(transformedMessage); + callback(); + } catch (err) { + this.log.error(err, "Error transforming SSE message"); + callback(err); + } + } +} + +function getTransformer( + responseApi: APIFormat, + version?: string +): StreamingCompletionTransformer { + switch (responseApi) { + case "openai": + return passthroughToOpenAI; + case "openai-text": + return openAITextToOpenAIChat; + case "anthropic": + return version === "2023-01-01" + ? anthropicV1ToOpenAI + : anthropicV2ToOpenAI; + case "google-palm": + throw new Error("Google PaLM does not support streaming responses"); + default: + assertNever(responseApi); + } +} + +/** + * OpenAI streaming chat completions start with an event that contains only the + * metadata and role (always 'assistant') for the response. To simulate this + * for APIs where the first event contains actual content, we create a fake + * initial event with no content but correct metadata. + */ +function createInitialMessage( + event: OpenAIChatCompletionStreamEvent +): OpenAIChatCompletionStreamEvent { + return { + ...event, + choices: event.choices.map((choice) => ({ + ...choice, + delta: { role: "assistant", content: "" }, + })), + }; +} diff --git a/src/proxy/middleware/response/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts similarity index 66% rename from src/proxy/middleware/response/sse-stream-adapter.ts rename to src/proxy/middleware/response/streaming/sse-stream-adapter.ts index bbe473c..7ae1996 100644 --- a/src/proxy/middleware/response/sse-stream-adapter.ts +++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts @@ -1,11 +1,11 @@ import { Transform, TransformOptions } from "stream"; // @ts-ignore import { Parser } from "lifion-aws-event-stream"; -import { logger } from "../../../logger"; +import { logger } from "../../../../logger"; const log = logger.child({ module: "sse-stream-adapter" }); -type SSEStreamAdapterOptions = TransformOptions & { isAwsStream?: boolean }; +type SSEStreamAdapterOptions = TransformOptions & { contentType?: string }; type AwsEventStreamMessage = { headers: { ":message-type": "event" | "exception" }; payload: { message?: string /** base64 encoded */; bytes?: string }; @@ -15,24 +15,25 @@ type AwsEventStreamMessage = { * Receives either text chunks or AWS binary event stream chunks and emits * full SSE events. */ -export class ServerSentEventStreamAdapter extends Transform { +export class SSEStreamAdapter extends Transform { private readonly isAwsStream; private parser = new Parser(); private partialMessage = ""; constructor(options?: SSEStreamAdapterOptions) { super(options); - this.isAwsStream = options?.isAwsStream || false; + this.isAwsStream = + options?.contentType === "application/vnd.amazon.eventstream"; this.parser.on("data", (data: AwsEventStreamMessage) => { const message = this.processAwsEvent(data); if (message) { - this.push(Buffer.from(message, "utf8")); + this.push(Buffer.from(message + "\n\n"), "utf8"); } }); } - processAwsEvent(event: AwsEventStreamMessage): string | null { + protected processAwsEvent(event: AwsEventStreamMessage): string | null { const { payload, headers } = event; if (headers[":message-type"] === "exception" || !payload.bytes) { log.error( @@ -42,7 +43,14 @@ export class ServerSentEventStreamAdapter extends Transform { const message = JSON.stringify(event); return getFakeErrorCompletion("proxy AWS error", message); } else { - return `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`; + const { bytes } = payload; + // technically this is a transformation but we don't really distinguish + // between aws claude and anthropic claude at the APIFormat level, so + // these will short circuit the message transformer + return [ + "event: completion", + `data: ${Buffer.from(bytes, "base64").toString("utf8")}`, + ].join("\n"); } } @@ -55,11 +63,15 @@ export class ServerSentEventStreamAdapter extends Transform { // so we need to buffer and emit separate stream events for full // messages so we can parse/transform them properly. const str = chunk.toString("utf8"); + const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/); this.partialMessage = fullMessages.pop() || ""; for (const message of fullMessages) { - this.push(message); + // Mixing line endings will break some clients and our request queue + // will have already sent \n for heartbeats, so we need to normalize + // to \n. + this.push(message.replace(/\r\n/g, "\n") + "\n\n"); } } callback(); @@ -72,7 +84,7 @@ export class ServerSentEventStreamAdapter extends Transform { function getFakeErrorCompletion(type: string, message: string) { const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`; - const fakeEvent = { + const fakeEvent = JSON.stringify({ log_id: "aws-proxy-sse-message", stop_reason: type, completion: @@ -80,6 +92,6 @@ function getFakeErrorCompletion(type: string, message: string) { truncated: false, stop: null, model: "", - }; - return `data: ${JSON.stringify(fakeEvent)}\n\n`; + }); + return ["event: completion", `data: ${fakeEvent}\n\n`].join("\n"); } diff --git a/src/proxy/middleware/response/streaming/transformers/anthropic-v1-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/anthropic-v1-to-openai.ts new file mode 100644 index 0000000..f145290 --- /dev/null +++ b/src/proxy/middleware/response/streaming/transformers/anthropic-v1-to-openai.ts @@ -0,0 +1,67 @@ +import { StreamingCompletionTransformer } from "../index"; +import { parseEvent, ServerSentEvent } from "../parse-sse"; +import { logger } from "../../../../../logger"; + +const log = logger.child({ + module: "sse-transformer", + transformer: "anthropic-v1-to-openai", +}); + +type AnthropicV1StreamEvent = { + log_id?: string; + model?: string; + completion: string; + stop_reason: string; +}; + +/** + * Transforms an incoming Anthropic SSE (2023-01-01 API) to an equivalent + * OpenAI chat.completion.chunk SSE. + */ +export const anthropicV1ToOpenAI: StreamingCompletionTransformer = (params) => { + const { data, lastPosition } = params; + + const rawEvent = parseEvent(data); + if (!rawEvent.data || rawEvent.data === "[DONE]") { + return { position: lastPosition }; + } + + const completionEvent = asCompletion(rawEvent); + if (!completionEvent) { + return { position: lastPosition }; + } + + // Anthropic sends the full completion so far with each event whereas OpenAI + // only sends the delta. To make the SSE events compatible, we remove + // everything before `lastPosition` from the completion. + const newEvent = { + id: "ant-" + (completionEvent.log_id ?? params.fallbackId), + object: "chat.completion.chunk" as const, + created: Date.now(), + model: completionEvent.model ?? params.fallbackModel, + choices: [ + { + index: 0, + delta: { content: completionEvent.completion?.slice(lastPosition) }, + finish_reason: completionEvent.stop_reason, + }, + ], + }; + + return { position: completionEvent.completion.length, event: newEvent }; +}; + +function asCompletion(event: ServerSentEvent): AnthropicV1StreamEvent | null { + try { + const parsed = JSON.parse(event.data); + if (parsed.completion !== undefined && parsed.stop_reason !== undefined) { + return parsed; + } else { + // noinspection ExceptionCaughtLocallyJS + throw new Error("Missing required fields"); + } + } catch (error) { + log.warn({ error: error.stack, event }, "Received invalid event"); + } + return null; +} diff --git a/src/proxy/middleware/response/streaming/transformers/anthropic-v2-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/anthropic-v2-to-openai.ts new file mode 100644 index 0000000..9db1c0f --- /dev/null +++ b/src/proxy/middleware/response/streaming/transformers/anthropic-v2-to-openai.ts @@ -0,0 +1,66 @@ +import { StreamingCompletionTransformer } from "../index"; +import { parseEvent, ServerSentEvent } from "../parse-sse"; +import { logger } from "../../../../../logger"; + +const log = logger.child({ + module: "sse-transformer", + transformer: "anthropic-v2-to-openai", +}); + +type AnthropicV2StreamEvent = { + log_id?: string; + model?: string; + completion: string; + stop_reason: string; +}; + +/** + * Transforms an incoming Anthropic SSE (2023-06-01 API) to an equivalent + * OpenAI chat.completion.chunk SSE. + */ +export const anthropicV2ToOpenAI: StreamingCompletionTransformer = (params) => { + const { data } = params; + + const rawEvent = parseEvent(data); + if (!rawEvent.data || rawEvent.data === "[DONE]") { + return { position: -1 }; + } + + const completionEvent = asCompletion(rawEvent); + if (!completionEvent) { + return { position: -1 }; + } + + const newEvent = { + id: "ant-" + (completionEvent.log_id ?? params.fallbackId), + object: "chat.completion.chunk" as const, + created: Date.now(), + model: completionEvent.model ?? params.fallbackModel, + choices: [ + { + index: 0, + delta: { content: completionEvent.completion }, + finish_reason: completionEvent.stop_reason, + }, + ], + }; + + return { position: completionEvent.completion.length, event: newEvent }; +}; + +function asCompletion(event: ServerSentEvent): AnthropicV2StreamEvent | null { + if (event.type === "ping") return null; + + try { + const parsed = JSON.parse(event.data); + if (parsed.completion !== undefined && parsed.stop_reason !== undefined) { + return parsed; + } else { + // noinspection ExceptionCaughtLocallyJS + throw new Error("Missing required fields"); + } + } catch (error) { + log.warn({ error: error.stack, event }, "Received invalid event"); + } + return null; +} diff --git a/src/proxy/middleware/response/streaming/transformers/openai-text-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/openai-text-to-openai.ts new file mode 100644 index 0000000..c85795a --- /dev/null +++ b/src/proxy/middleware/response/streaming/transformers/openai-text-to-openai.ts @@ -0,0 +1,68 @@ +import { SSEResponseTransformArgs } from "../index"; +import { parseEvent, ServerSentEvent } from "../parse-sse"; +import { logger } from "../../../../../logger"; + +const log = logger.child({ + module: "sse-transformer", + transformer: "openai-text-to-openai", +}); + +type OpenAITextCompletionStreamEvent = { + id: string; + object: "text_completion"; + created: number; + choices: { + text: string; + index: number; + logprobs: null; + finish_reason: string | null; + }[]; + model: string; +}; + +export const openAITextToOpenAIChat = (params: SSEResponseTransformArgs) => { + const { data } = params; + + const rawEvent = parseEvent(data); + if (!rawEvent.data || rawEvent.data === "[DONE]") { + return { position: -1 }; + } + + const completionEvent = asCompletion(rawEvent); + if (!completionEvent) { + return { position: -1 }; + } + + const newEvent = { + id: completionEvent.id, + object: "chat.completion.chunk" as const, + created: completionEvent.created, + model: completionEvent.model, + choices: [ + { + index: completionEvent.choices[0].index, + delta: { content: completionEvent.choices[0].text }, + finish_reason: completionEvent.choices[0].finish_reason, + }, + ], + }; + + return { position: -1, event: newEvent }; +}; + +function asCompletion( + event: ServerSentEvent +): OpenAITextCompletionStreamEvent | null { + try { + const parsed = JSON.parse(event.data); + if (Array.isArray(parsed.choices) && parsed.choices[0].text !== undefined) { + return parsed; + } else { + // noinspection ExceptionCaughtLocallyJS + throw new Error("Missing required fields"); + } + } catch (error) { + log.warn({ error: error.stack, event }, "Received invalid data event"); + } + return null; +} diff --git a/src/proxy/middleware/response/streaming/transformers/passthrough-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/passthrough-to-openai.ts new file mode 100644 index 0000000..2edcf60 --- /dev/null +++ b/src/proxy/middleware/response/streaming/transformers/passthrough-to-openai.ts @@ -0,0 +1,38 @@ +import { + OpenAIChatCompletionStreamEvent, + SSEResponseTransformArgs, +} from "../index"; +import { parseEvent, ServerSentEvent } from "../parse-sse"; +import { logger } from "../../../../../logger"; + +const log = logger.child({ + module: "sse-transformer", + transformer: "openai-to-openai", +}); + +export const passthroughToOpenAI = (params: SSEResponseTransformArgs) => { + const { data } = params; + + const rawEvent = parseEvent(data); + if (!rawEvent.data || rawEvent.data === "[DONE]") { + return { position: -1 }; + } + + const completionEvent = asCompletion(rawEvent); + if (!completionEvent) { + return { position: -1 }; + } + + return { position: -1, event: completionEvent }; +}; + +function asCompletion( + event: ServerSentEvent +): OpenAIChatCompletionStreamEvent | null { + try { + return JSON.parse(event.data); + } catch (error) { + log.warn({ error: error.stack, event }, "Received invalid event"); + } + return null; +} diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index fc98ad3..66aa37c 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -23,10 +23,11 @@ import { getOpenAIModelFamily, ModelFamily, } from "../shared/models"; +import { initializeSseStream } from "../shared/streaming"; +import { assertNever } from "../shared/utils"; import { logger } from "../logger"; import { AGNAI_DOT_CHAT_IP } from "./rate-limit"; import { buildFakeSseMessage } from "./middleware/common"; -import { assertNever } from "../shared/utils"; const queue: Request[] = []; const log = logger.child({ module: "request-queue" }); @@ -352,14 +353,8 @@ function killQueuedRequest(req: Request) { } function initStreaming(req: Request) { - req.log.info(`Initiating streaming for new queued request.`); const res = req.res!; - res.statusCode = 200; - res.setHeader("Content-Type", "text/event-stream"); - res.setHeader("Cache-Control", "no-cache"); - res.setHeader("Connection", "keep-alive"); - res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix - res.flushHeaders(); + initializeSseStream(res); if (req.query.badSseParser) { // Some clients have a broken SSE parser that doesn't handle comments @@ -368,7 +363,6 @@ function initStreaming(req: Request) { return; } - res.write("\n"); res.write(": joining queue\n\n"); } diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index a2ffda6..424f104 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -7,6 +7,14 @@ import { googlePalm } from "./palm"; import { aws } from "./aws"; const proxyRouter = express.Router(); +proxyRouter.use((req, _res, next) => { + if (req.headers.expect) { + // node-http-proxy does not like it when clients send `expect: 100-continue` + // and will stall. none of the upstream APIs use this header anyway. + delete req.headers.expect; + } + next(); +}); proxyRouter.use( express.json({ limit: "1536kb" }), express.urlencoded({ extended: true, limit: "1536kb" }) diff --git a/src/shared/streaming.ts b/src/shared/streaming.ts new file mode 100644 index 0000000..60b8455 --- /dev/null +++ b/src/shared/streaming.ts @@ -0,0 +1,34 @@ +import { Response } from "express"; +import { IncomingMessage } from "http"; + +export function initializeSseStream(res: Response) { + res.statusCode = 200; + res.setHeader("Content-Type", "text/event-stream; charset=utf-8"); + res.setHeader("Cache-Control", "no-cache"); + res.setHeader("Connection", "keep-alive"); + res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix + res.flushHeaders(); +} + +/** + * Copies headers received from upstream API to the SSE response, excluding + * ones we need to set ourselves for SSE to work. + */ +export function copySseResponseHeaders( + proxyRes: IncomingMessage, + res: Response +) { + const toOmit = [ + "content-length", + "content-encoding", + "transfer-encoding", + "content-type", + "connection", + "cache-control", + ]; + for (const [key, value] of Object.entries(proxyRes.headers)) { + if (!toOmit.includes(key) && value) { + res.setHeader(key, value); + } + } +} diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts index 3af078b..8924952 100644 --- a/src/shared/users/user-store.ts +++ b/src/shared/users/user-store.ts @@ -277,7 +277,7 @@ function cleanupExpiredTokens() { deleted++; } } - log.debug({ disabled, deleted }, "Expired tokens cleaned up."); + log.trace({ disabled, deleted }, "Expired tokens cleaned up."); } function refreshAllQuotas() {