Refactor handleStreamingResponse to make it less shit (khanon/oai-reverse-proxy!46)

parent 6a3d753f0d
commit ecf897e685
@@ -42,7 +42,7 @@ export function writeErrorResponse(
   // the stream. Otherwise just send a normal error response.
   if (
     res.headersSent ||
-    res.getHeader("content-type") === "text/event-stream"
+    String(res.getHeader("content-type")).startsWith("text/event-stream")
   ) {
     const errorContent =
       statusCode === 403
@@ -166,12 +166,7 @@ function openaiToAnthropic(req: Request)
    throw result.error;
  }

-  // Anthropic has started versioning their API, indicated by an HTTP header
-  // `anthropic-version`. The new June 2023 version is not backwards compatible
-  // with our OpenAI-to-Anthropic transformations so we need to explicitly
-  // request the older version for now. 2023-01-01 will be removed in September.
-  // https://docs.anthropic.com/claude/reference/versioning
-  req.headers["anthropic-version"] = "2023-01-01";
+  req.headers["anthropic-version"] = "2023-06-01";

  const { messages, ...rest } = result.data;
  const prompt = openAIMessagesToClaudePrompt(messages);
@@ -1,44 +1,16 @@
-import { Request, Response } from "express";
-import * as http from "http";
+import { pipeline } from "stream";
+import { promisify } from "util";
 import { buildFakeSseMessage } from "../common";
-import { RawResponseBodyHandler, decodeResponseBody } from ".";
-import { assertNever } from "../../../shared/utils";
-import { ServerSentEventStreamAdapter } from "./sse-stream-adapter";
+import { decodeResponseBody, RawResponseBodyHandler } from ".";
+import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
+import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
+import { EventAggregator } from "./streaming/event-aggregator";
+import {
+  copySseResponseHeaders,
+  initializeSseStream,
+} from "../../../shared/streaming";

-type OpenAiChatCompletionResponse = {
-  id: string;
-  object: string;
-  created: number;
-  model: string;
-  choices: {
-    message: { role: string; content: string };
-    finish_reason: string | null;
-    index: number;
-  }[];
-};
-
-type OpenAiTextCompletionResponse = {
-  id: string;
-  object: string;
-  created: number;
-  model: string;
-  choices: {
-    text: string;
-    finish_reason: string | null;
-    index: number;
-    logprobs: null;
-  }[];
-};
-
-type AnthropicCompletionResponse = {
-  completion: string;
-  stop_reason: string;
-  truncated: boolean;
-  stop: any;
-  model: string;
-  log_id: string;
-  exception: null;
-};
+const pipelineAsync = promisify(pipeline);

 /**
  * Consume the SSE stream and forward events to the client. Once the stream is
@@ -49,370 +21,67 @@ type AnthropicCompletionResponse = {
  * in the event a streamed request results in a non-200 response, we need to
  * fall back to the non-streaming response handler so that the error handler
  * can inspect the error response.
- *
- * Currently most frontends don't support Anthropic streaming, so users can opt
- * to send requests for Claude models via an endpoint that accepts OpenAI-
- * compatible requests and translates the received Anthropic SSE events into
- * OpenAI ones, essentially pretending to be an OpenAI streaming API.
  */
 export const handleStreamedResponse: RawResponseBodyHandler = async (
   proxyRes,
   req,
   res
 ) => {
-  // If these differ, the user is using the OpenAI-compatibile endpoint, so
-  // we need to translate the SSE events into OpenAI completion events for their
-  // frontend.
+  const { hash } = req.key!;
   if (!req.isStreaming) {
-    const err = new Error(
-      "handleStreamedResponse called for non-streaming request."
-    );
-    req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
-    throw err;
+    throw new Error("handleStreamedResponse called for non-streaming request.");
   }

-  const key = req.key!;
-  if (proxyRes.statusCode !== 200) {
-    // Ensure we use the non-streaming middleware stack since we won't be
-    // getting any events.
-    req.isStreaming = false;
+  if (proxyRes.statusCode! > 201) {
+    req.isStreaming = false; // Forces non-streaming response handler to execute
     req.log.warn(
-      { statusCode: proxyRes.statusCode, key: key.hash },
+      { statusCode: proxyRes.statusCode, key: hash },
       `Streaming request returned error status code. Falling back to non-streaming response handler.`
     );
     return decodeResponseBody(proxyRes, req, res);
   }

   req.log.debug(
-    { headers: proxyRes.headers, key: key.hash },
-    `Received SSE headers.`
+    { headers: proxyRes.headers, key: hash },
+    `Starting to proxy SSE stream.`
   );

-  return new Promise((resolve, reject) => {
-    req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`);
-
-    // Queued streaming requests will already have a connection open and headers
-    // sent due to the heartbeat handler. In that case we can just start
-    // streaming the response without sending headers.
-    if (!res.headersSent) {
-      res.setHeader("Content-Type", "text/event-stream");
-      res.setHeader("Cache-Control", "no-cache");
-      res.setHeader("Connection", "keep-alive");
-      res.setHeader("X-Accel-Buffering", "no");
-      copyHeaders(proxyRes, res);
-      res.flushHeaders();
-    }
+  // Users waiting in the queue already have a SSE connection open for the
+  // heartbeat, so we can't always send the stream headers.
+  if (!res.headersSent) {
+    copySseResponseHeaders(proxyRes, res);
+    initializeSseStream(res);
+  }

-    const adapter = new ServerSentEventStreamAdapter({
-      isAwsStream:
-        proxyRes.headers["content-type"] ===
-        "application/vnd.amazon.eventstream",
-    });
+  const prefersNativeEvents = req.inboundApi === req.outboundApi;
+  const contentType = proxyRes.headers["content-type"];

+  const adapter = new SSEStreamAdapter({ contentType });
+  const aggregator = new EventAggregator({ format: req.outboundApi });
+  const transformer = new SSEMessageTransformer({
+    inputFormat: req.outboundApi, // outbound from the request's perspective
+    inputApiVersion: String(req.headers["anthropic-version"]),
+    logger: req.log,
+    requestId: String(req.id),
+    requestedModel: req.body.model,
+  })
+    .on("originalMessage", (msg: string) => {
+      if (prefersNativeEvents) res.write(msg);
+    })
+    .on("data", (msg) => {
+      if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
+      aggregator.addEvent(msg);
+    });

-    const events: string[] = [];
-    let lastPosition = 0;
-    let eventCount = 0;
-
-    proxyRes.pipe(adapter);
-
-    adapter.on("data", (chunk: any) => {
-      try {
-        const { event, position } = transformEvent({
-          data: chunk.toString(),
-          requestApi: req.inboundApi,
-          responseApi: req.outboundApi,
-          lastPosition,
-          index: eventCount++,
-        });
-        events.push(event);
-        lastPosition = position;
-        res.write(event + "\n\n");
-      } catch (err) {
-        adapter.emit("error", err);
-      }
-    });
-
-    adapter.on("end", () => {
-      try {
-        req.log.info({ key: key.hash }, `Finished proxying SSE stream.`);
-        const finalBody = convertEventsToFinalResponse(events, req);
-        res.end();
-        resolve(finalBody);
-      } catch (err) {
-        adapter.emit("error", err);
-      }
-    });
-
-    adapter.on("error", (err) => {
-      req.log.error({ error: err, key: key.hash }, `Mid-stream error.`);
-      const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
-      res.write(`data: ${JSON.stringify(errorEvent)}\n\ndata: [DONE]\n\n`);
-      res.end();
-      reject(err);
-    });
-  });
+  try {
+    await pipelineAsync(proxyRes, adapter, transformer);
+    req.log.debug({ key: hash }, `Finished proxying SSE stream.`);
+    res.end();
+    return aggregator.getFinalResponse();
+  } catch (err) {
+    const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
+    res.write(`${errorEvent}data: [DONE]\n\n`);
+    res.end();
+    throw err;
+  }
 };
-
-type SSETransformationArgs = {
-  data: string;
-  requestApi: string;
-  responseApi: string;
-  lastPosition: number;
-  index: number;
-};
-
-/**
- * Transforms SSE events from the given response API into events compatible with
- * the API requested by the client.
- */
-function transformEvent(params: SSETransformationArgs) {
-  const { data, requestApi, responseApi } = params;
-  if (requestApi === responseApi) {
-    return { position: -1, event: data };
-  }
-
-  const trans = `${requestApi}->${responseApi}`;
-  switch (trans) {
-    case "openai->openai-text":
-      return transformOpenAITextEventToOpenAIChat(params);
-    case "openai->anthropic":
-      // TODO: handle new anthropic streaming format
-      return transformV1AnthropicEventToOpenAI(params);
-    default:
-      throw new Error(`Unsupported streaming API transformation. ${trans}`);
-  }
-}
-
-function transformOpenAITextEventToOpenAIChat(params: SSETransformationArgs) {
-  const { data, index } = params;
-
-  if (!data.startsWith("data:")) return { position: -1, event: data };
-  if (data.startsWith("data: [DONE]")) return { position: -1, event: data };
-
-  const event = JSON.parse(data.slice("data: ".length));
-
-  // The very first event must be a role assignment with no content.
-  const createEvent = () => ({
-    id: event.id,
-    object: "chat.completion.chunk",
-    created: event.created,
-    model: event.model,
-    choices: [
-      {
-        message: { role: "", content: "" } as {
-          role?: string;
-          content: string;
-        },
-        index: 0,
-        finish_reason: null,
-      },
-    ],
-  });
-
-  let buffer = "";
-
-  if (index === 0) {
-    const initialEvent = createEvent();
-    initialEvent.choices[0].message.role = "assistant";
-    buffer = `data: ${JSON.stringify(initialEvent)}\n\n`;
-  }
-
-  const newEvent = {
-    ...event,
-    choices: [
-      {
-        ...event.choices[0],
-        delta: { content: event.choices[0].text },
-        text: undefined,
-      },
-    ],
-  };
-
-  buffer += `data: ${JSON.stringify(newEvent)}`;
-
-  return { position: -1, event: buffer };
-}
-
-function transformV1AnthropicEventToOpenAI(params: SSETransformationArgs) {
-  const { data, lastPosition } = params;
-  // Anthropic sends the full completion so far with each event whereas OpenAI
-  // only sends the delta. To make the SSE events compatible, we remove
-  // everything before `lastPosition` from the completion.
-  if (!data.startsWith("data:")) {
-    return { position: lastPosition, event: data };
-  }
-
-  if (data.startsWith("data: [DONE]")) {
-    return { position: lastPosition, event: data };
-  }
-
-  const event = JSON.parse(data.slice("data: ".length));
-  const newEvent = {
-    id: "ant-" + event.log_id,
-    object: "chat.completion.chunk",
-    created: Date.now(),
-    model: event.model,
-    choices: [
-      {
-        index: 0,
-        delta: { content: event.completion?.slice(lastPosition) },
-        finish_reason: event.stop_reason,
-      },
-    ],
-  };
-  return {
-    position: event.completion.length,
-    event: `data: ${JSON.stringify(newEvent)}`,
-  };
-}
-
-/** Copy headers, excluding ones we're already setting for the SSE response. */
-function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
-  const toOmit = [
-    "content-length",
-    "content-encoding",
-    "transfer-encoding",
-    "content-type",
-    "connection",
-    "cache-control",
-  ];
-  for (const [key, value] of Object.entries(proxyRes.headers)) {
-    if (!toOmit.includes(key) && value) {
-      res.setHeader(key, value);
-    }
-  }
-}
-
-/**
- * Converts the list of incremental SSE events into an object that resembles a
- * full, non-streamed response from the API so that subsequent middleware can
- * operate on it as if it were a normal response.
- * Events are expected to be in the format they were received from the API.
- */
-function convertEventsToFinalResponse(events: string[], req: Request) {
-  switch (req.outboundApi) {
-    case "openai": {
-      let merged: OpenAiChatCompletionResponse = {
-        id: "",
-        object: "",
-        created: 0,
-        model: "",
-        choices: [],
-      };
-      merged = events.reduce((acc, event, i) => {
-        if (!event.startsWith("data: ")) return acc;
-        if (event === "data: [DONE]") return acc;
-
-        const data = JSON.parse(event.slice("data: ".length));
-
-        // The first chat chunk only contains the role assignment and metadata
-        if (i === 0) {
-          return {
-            id: data.id,
-            object: data.object,
-            created: data.created,
-            model: data.model,
-            choices: [
-              {
-                message: { role: data.choices[0].delta.role, content: "" },
-                index: 0,
-                finish_reason: null,
-              },
-            ],
-          };
-        }
-
-        if (data.choices[0].delta.content) {
-          acc.choices[0].message.content += data.choices[0].delta.content;
-        }
-        acc.choices[0].finish_reason = data.choices[0].finish_reason;
-        return acc;
-      }, merged);
-      return merged;
-    }
-    case "openai-text": {
-      let merged: OpenAiTextCompletionResponse = {
-        id: "",
-        object: "",
-        created: 0,
-        model: "",
-        choices: [],
-        // TODO: merge logprobs
-      };
-      merged = events.reduce((acc, event) => {
-        if (!event.startsWith("data: ")) return acc;
-        if (event === "data: [DONE]") return acc;
-
-        const data = JSON.parse(event.slice("data: ".length));
-
-        return {
-          id: data.id,
-          object: data.object,
-          created: data.created,
-          model: data.model,
-          choices: [
-            {
-              text: acc.choices[0]?.text + data.choices[0].text,
-              index: 0,
-              finish_reason: data.choices[0].finish_reason,
-              logprobs: null,
-            },
-          ],
-        };
-      }, merged);
-      return merged;
-    }
-    case "anthropic": {
-      if (req.headers["anthropic-version"] === "2023-01-01") {
-        return convertAnthropicV1(events, req);
-      }
-
-      let merged: AnthropicCompletionResponse = {
-        completion: "",
-        stop_reason: "",
-        truncated: false,
-        stop: null,
-        model: req.body.model,
-        log_id: "",
-        exception: null,
-      }
-
-      merged = events.reduce((acc, event) => {
-        if (!event.startsWith("data: ")) return acc;
-        if (event === "data: [DONE]") return acc;
-
-        const data = JSON.parse(event.slice("data: ".length));
-
-        return {
-          completion: acc.completion + data.completion,
-          stop_reason: data.stop_reason,
-          truncated: data.truncated,
-          stop: data.stop,
-          log_id: data.log_id,
-          exception: data.exception,
-          model: acc.model,
-        };
-      }, merged);
-      return merged;
-    }
-    case "google-palm": {
-      throw new Error("PaLM streaming not yet supported.");
-    }
-    default:
-      assertNever(req.outboundApi);
-  }
-}
-
-/** Older Anthropic streaming format which sent full completion each time. */
-function convertAnthropicV1(
-  events: string[],
-  req: Request
-) {
-  const lastEvent = events[events.length - 2].toString();
-  const data = JSON.parse(
-    lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length)
-  );
-  const final: AnthropicCompletionResponse = { ...data, log_id: req.id };
-  return final;
-}
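For reference, the new handler replaces the hand-rolled Promise/event-listener plumbing with Node's `stream.pipeline`, which propagates an error from any stage and tears the whole chain down. A minimal sketch of the same shape, using only Node builtins (the two generic Transform stages are stand-ins for SSEStreamAdapter and SSEMessageTransformer, which are defined later in this diff):

import { pipeline, Readable, Transform } from "stream";
import { promisify } from "util";

const pipelineAsync = promisify(pipeline);

// Stage 1: split a byte stream into complete SSE messages
// (stand-in for SSEStreamAdapter).
const splitMessages = new Transform({
  transform(chunk, _enc, cb) {
    for (const msg of chunk.toString().split("\n\n").filter(Boolean)) {
      this.push(msg);
    }
    cb();
  },
});

// Stage 2: parse each message into an object (stand-in for
// SSEMessageTransformer, which emits objects in readableObjectMode).
const parseMessages = new Transform({
  readableObjectMode: true,
  transform(msg, _enc, cb) {
    this.push({ raw: msg.toString() });
    cb();
  },
});

async function demo() {
  const source = Readable.from(['data: {"a":1}\n\ndata: [DONE]\n\n']);
  parseMessages.on("data", (event) => console.log("event:", event));
  // If any stage errors, pipelineAsync rejects and all stages are destroyed,
  // which is what lets the new handler use a single try/catch.
  await pipelineAsync(source, splitMessages, parseMessages);
}

demo().catch(console.error);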
@@ -4,13 +4,16 @@ import * as http from "http";
 import util from "util";
 import zlib from "zlib";
 import { logger } from "../../../logger";
+import { enqueue, trackWaitTime } from "../../queue";
+import { HttpError } from "../../../shared/errors";
 import { keyPool } from "../../../shared/key-management";
 import { getOpenAIModelFamily } from "../../../shared/models";
-import { enqueue, trackWaitTime } from "../../queue";
+import { countTokens } from "../../../shared/tokenization";
 import {
   incrementPromptCount,
   incrementTokenCount,
 } from "../../../shared/users/user-store";
+import { assertNever } from "../../../shared/utils";
 import {
   getCompletionFromBody,
   isCompletionRequest,
@@ -18,8 +21,6 @@ import {
 } from "../common";
 import { handleStreamedResponse } from "./handle-streamed-response";
 import { logPrompt } from "./log-prompt";
-import { countTokens } from "../../../shared/tokenization";
-import { assertNever } from "../../../shared/utils";
-
 const DECODER_MAP = {
   gzip: util.promisify(zlib.gunzip),
@@ -83,7 +84,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
      ? handleStreamedResponse
      : decodeResponseBody;

-    let lastMiddlewareName = initialHandler.name;
+    let lastMiddleware = initialHandler.name;

    try {
      const body = await initialHandler(proxyRes, req, res);
@@ -112,37 +113,38 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
      }

      for (const middleware of middlewareStack) {
-        lastMiddlewareName = middleware.name;
+        lastMiddleware = middleware.name;
        await middleware(proxyRes, req, res, body);
      }

      trackWaitTime(req);
-    } catch (error: any) {
+    } catch (error) {
      // Hack: if the error is a retryable rate-limit error, the request has
      // been re-enqueued and we can just return without doing anything else.
      if (error instanceof RetryableError) {
        return;
      }

-      const errorData = {
-        error: error.stack,
-        thrownBy: lastMiddlewareName,
-        key: req.key?.hash,
-      };
-      const message = `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`;
-      if (res.headersSent) {
-        req.log.error(errorData, message);
-        // This should have already been handled by the error handler, but
-        // just in case...
-        if (!res.writableEnded) {
-          res.end();
-        }
+      // Already logged and responded to the client by handleUpstreamErrors
+      if (error instanceof HttpError) {
+        if (!res.writableEnded) res.end();
        return;
      }
-      logger.error(errorData, message);
-      res
-        .status(500)
-        .json({ error: "Internal server error", proxy_note: message });
+
+      const { stack, message } = error;
+      const info = { stack, lastMiddleware, key: req.key?.hash };
+      const description = `Error while executing proxy response middleware: ${lastMiddleware} (${message})`;
+
+      if (res.headersSent) {
+        req.log.error(info, description);
+        if (!res.writableEnded) res.end();
+        return;
+      } else {
+        req.log.error(info, description);
+        res
+          .status(500)
+          .json({ error: "Internal server error", proxy_note: description });
+      }
    }
  };
};
@@ -203,7 +205,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
        return resolve(body.toString());
      } catch (error: any) {
        const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
-        logger.warn({ error, key: req.key?.hash }, errorMessage);
+        logger.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
        writeErrorResponse(req, res, 500, { error: errorMessage });
        return reject(errorMessage);
      }
@@ -223,7 +225,7 @@ type ProxiedErrorPayload = {
 * an error to stop the middleware stack.
 * On 429 errors, if request queueing is enabled, the request will be silently
 * re-enqueued. Otherwise, the request will be rejected with an error payload.
- * @throws {Error} On HTTP error status code from upstream service
+ * @throws {HttpError} On HTTP error status code from upstream service
 */
const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  proxyRes,
@@ -258,7 +260,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
      proxy_note: `This is likely a temporary error with the upstream service.`,
    };
    writeErrorResponse(req, res, statusCode, errorObject);
-    throw new Error(parseError.message);
+    throw new HttpError(statusCode, parseError.message);
  }

  const errorType =
@@ -371,7 +373,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  }

  writeErrorResponse(req, res, statusCode, errorPayload);
-  throw new Error(errorPayload.error?.message);
+  throw new HttpError(statusCode, errorPayload.error?.message);
};

/**
@@ -0,0 +1,48 @@
+import { OpenAIChatCompletionStreamEvent } from "../index";
+
+export type AnthropicCompletionResponse = {
+  completion: string;
+  stop_reason: string;
+  truncated: boolean;
+  stop: any;
+  model: string;
+  log_id: string;
+  exception: null;
+};
+
+/**
+ * Given a list of OpenAI chat completion events, compiles them into a single
+ * finalized Anthropic completion response so that non-streaming middleware
+ * can operate on it as if it were a blocking response.
+ */
+export function mergeEventsForAnthropic(
+  events: OpenAIChatCompletionStreamEvent[]
+): AnthropicCompletionResponse {
+  let merged: AnthropicCompletionResponse = {
+    log_id: "",
+    exception: null,
+    model: "",
+    completion: "",
+    stop_reason: "",
+    truncated: false,
+    stop: null,
+  };
+  merged = events.reduce((acc, event, i) => {
+    // The first event will only contain role assignment and response metadata
+    if (i === 0) {
+      acc.log_id = event.id;
+      acc.model = event.model;
+      acc.completion = "";
+      acc.stop_reason = "";
+      return acc;
+    }
+
+    acc.stop_reason = event.choices[0].finish_reason ?? "";
+    if (event.choices[0].delta.content) {
+      acc.completion += event.choices[0].delta.content;
+    }
+
+    return acc;
+  }, merged);
+  return merged;
+}
@@ -0,0 +1,58 @@
+import { OpenAIChatCompletionStreamEvent } from "../index";
+
+export type OpenAiChatCompletionResponse = {
+  id: string;
+  object: string;
+  created: number;
+  model: string;
+  choices: {
+    message: { role: string; content: string };
+    finish_reason: string | null;
+    index: number;
+  }[];
+};
+
+/**
+ * Given a list of OpenAI chat completion events, compiles them into a single
+ * finalized OpenAI chat completion response so that non-streaming middleware
+ * can operate on it as if it were a blocking response.
+ */
+export function mergeEventsForOpenAIChat(
+  events: OpenAIChatCompletionStreamEvent[]
+): OpenAiChatCompletionResponse {
+  let merged: OpenAiChatCompletionResponse = {
+    id: "",
+    object: "",
+    created: 0,
+    model: "",
+    choices: [],
+  };
+  merged = events.reduce((acc, event, i) => {
+    // The first event will only contain role assignment and response metadata
+    if (i === 0) {
+      acc.id = event.id;
+      acc.object = event.object;
+      acc.created = event.created;
+      acc.model = event.model;
+      acc.choices = [
+        {
+          index: 0,
+          message: {
+            role: event.choices[0].delta.role ?? "assistant",
+            content: "",
+          },
+          finish_reason: null,
+        },
+      ];
+      return acc;
+    }
+
+    acc.choices[0].finish_reason = event.choices[0].finish_reason;
+    if (event.choices[0].delta.content) {
+      acc.choices[0].message.content += event.choices[0].delta.content;
+    }
+
+    return acc;
+  }, merged);
+  return merged;
+}
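To illustrate the aggregator contract, here is a sketch using the function above with hand-written chunk events (the id/model values are made up; the import path assumes the new streaming module's layout):

import {
  mergeEventsForOpenAIChat,
  OpenAIChatCompletionStreamEvent,
} from "./streaming";

const events: OpenAIChatCompletionStreamEvent[] = [
  // First chunk: role assignment only, no content.
  { id: "chatcmpl-x", object: "chat.completion.chunk", created: 1, model: "gpt-4",
    choices: [{ index: 0, delta: { role: "assistant" }, finish_reason: null }] },
  { id: "chatcmpl-x", object: "chat.completion.chunk", created: 1, model: "gpt-4",
    choices: [{ index: 0, delta: { content: "Hello" }, finish_reason: null }] },
  { id: "chatcmpl-x", object: "chat.completion.chunk", created: 1, model: "gpt-4",
    choices: [{ index: 0, delta: { content: "!" }, finish_reason: "stop" }] },
];

// Produces a blocking-style response body:
// { id: "chatcmpl-x", ..., choices: [{ message: { role: "assistant",
//   content: "Hello!" }, finish_reason: "stop", index: 0 }] }
console.log(mergeEventsForOpenAIChat(events));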
@@ -0,0 +1,57 @@
+import { OpenAIChatCompletionStreamEvent } from "../index";
+
+export type OpenAiTextCompletionResponse = {
+  id: string;
+  object: string;
+  created: number;
+  model: string;
+  choices: {
+    text: string;
+    finish_reason: string | null;
+    index: number;
+    logprobs: null;
+  }[];
+};
+
+/**
+ * Given a list of OpenAI chat completion events, compiles them into a single
+ * finalized OpenAI text completion response so that non-streaming middleware
+ * can operate on it as if it were a blocking response.
+ */
+export function mergeEventsForOpenAIText(
+  events: OpenAIChatCompletionStreamEvent[]
+): OpenAiTextCompletionResponse {
+  let merged: OpenAiTextCompletionResponse = {
+    id: "",
+    object: "",
+    created: 0,
+    model: "",
+    choices: [],
+  };
+  merged = events.reduce((acc, event, i) => {
+    // The first event will only contain role assignment and response metadata
+    if (i === 0) {
+      acc.id = event.id;
+      acc.object = event.object;
+      acc.created = event.created;
+      acc.model = event.model;
+      acc.choices = [
+        {
+          text: "",
+          index: 0,
+          finish_reason: null,
+          logprobs: null,
+        },
+      ];
+      return acc;
+    }
+
+    acc.choices[0].finish_reason = event.choices[0].finish_reason;
+    if (event.choices[0].delta.content) {
+      acc.choices[0].text += event.choices[0].delta.content;
+    }
+
+    return acc;
+  }, merged);
+  return merged;
+}
@@ -0,0 +1,41 @@
+import { APIFormat } from "../../../../shared/key-management";
+import { assertNever } from "../../../../shared/utils";
+import {
+  mergeEventsForAnthropic,
+  mergeEventsForOpenAIChat,
+  mergeEventsForOpenAIText,
+  OpenAIChatCompletionStreamEvent
+} from "./index";
+
+/**
+ * Collects SSE events containing incremental chat completion responses and
+ * compiles them into a single finalized response for downstream middleware.
+ */
+export class EventAggregator {
+  private readonly format: APIFormat;
+  private readonly events: OpenAIChatCompletionStreamEvent[];
+
+  constructor({ format }: { format: APIFormat }) {
+    this.events = [];
+    this.format = format;
+  }
+
+  addEvent(event: OpenAIChatCompletionStreamEvent) {
+    this.events.push(event);
+  }
+
+  getFinalResponse() {
+    switch (this.format) {
+      case "openai":
+        return mergeEventsForOpenAIChat(this.events);
+      case "openai-text":
+        return mergeEventsForOpenAIText(this.events);
+      case "anthropic":
+        return mergeEventsForAnthropic(this.events);
+      case "google-palm":
+        throw new Error("Google PaLM API does not support streaming responses");
+      default:
+        assertNever(this.format);
+    }
+  }
+}
@@ -0,0 +1,31 @@
+export type SSEResponseTransformArgs = {
+  data: string;
+  lastPosition: number;
+  index: number;
+  fallbackId: string;
+  fallbackModel: string;
+};
+
+export type OpenAIChatCompletionStreamEvent = {
+  id: string;
+  object: "chat.completion.chunk";
+  created: number;
+  model: string;
+  choices: {
+    index: number;
+    delta: { role?: string; content?: string };
+    finish_reason: string | null;
+  }[];
+}
+
+export type StreamingCompletionTransformer = (
+  params: SSEResponseTransformArgs
+) => { position: number; event?: OpenAIChatCompletionStreamEvent };
+
+export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
+export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
+export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
+export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
+export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
+export { mergeEventsForAnthropic } from "./aggregators/anthropic";
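The StreamingCompletionTransformer type is the extension point for new upstream formats: a pure function from one raw SSE message (plus cursor state) to at most one OpenAI-style chunk. A hypothetical transformer for an upstream that streams plain `data:` lines might look like this (a sketch only; "myTextLineToOpenAIChat" is not part of this commit):

import {
  OpenAIChatCompletionStreamEvent,
  SSEResponseTransformArgs,
  StreamingCompletionTransformer,
} from "./index";

const myTextLineToOpenAIChat: StreamingCompletionTransformer = (
  params: SSEResponseTransformArgs
) => {
  // Messages with no usable data (comments, heartbeats) are skipped by
  // returning no `event`; SSEMessageTransformer drops those.
  if (!params.data.startsWith("data: ")) {
    return { position: params.lastPosition };
  }

  const event: OpenAIChatCompletionStreamEvent = {
    id: params.fallbackId,
    object: "chat.completion.chunk",
    created: Date.now(),
    model: params.fallbackModel,
    choices: [
      { index: 0, delta: { content: params.data.slice(6) }, finish_reason: null },
    ],
  };
  // `position` lets stateful formats (e.g. Anthropic 2023-01-01) track how
  // much of a cumulative completion has already been emitted; unused here.
  return { position: -1, event };
};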
@@ -0,0 +1,29 @@
+export type ServerSentEvent = { id?: string; type?: string; data: string };
+
+/** Given a string of SSE data, parse it into a `ServerSentEvent` object. */
+export function parseEvent(event: string) {
+  const buffer: ServerSentEvent = { data: "" };
+  return event.split(/\r?\n/).reduce(parseLine, buffer);
+}
+
+function parseLine(event: ServerSentEvent, line: string) {
+  const separator = line.indexOf(":");
+  const field = separator === -1 ? line : line.slice(0, separator);
+  const value = separator === -1 ? "" : line.slice(separator + 1);
+
+  switch (field) {
+    case "id":
+      event.id = value.trim();
+      break;
+    case "event":
+      event.type = value.trim();
+      break;
+    case "data":
+      event.data += value.trimStart();
+      break;
+    default:
+      break;
+  }
+
+  return event;
+}
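For example (a sketch; behavior follows the reducer above), a multi-field SSE message parses like so:

import { parseEvent } from "./parse-sse";

const msg = 'event: completion\nid: 7\ndata: {"completion":"hi"}';
console.log(parseEvent(msg));
// => { data: '{"completion":"hi"}', type: "completion", id: "7" }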
@@ -0,0 +1,123 @@
+import { Transform, TransformOptions } from "stream";
+import { logger } from "../../../../logger";
+import { APIFormat } from "../../../../shared/key-management";
+import { assertNever } from "../../../../shared/utils";
+import {
+  anthropicV1ToOpenAI,
+  anthropicV2ToOpenAI,
+  OpenAIChatCompletionStreamEvent,
+  openAITextToOpenAIChat,
+  StreamingCompletionTransformer,
+} from "./index";
+import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
+
+const genlog = logger.child({ module: "sse-transformer" });
+
+type SSEMessageTransformerOptions = TransformOptions & {
+  requestedModel: string;
+  requestId: string;
+  inputFormat: APIFormat;
+  inputApiVersion?: string;
+  logger?: typeof logger;
+};
+
+/**
+ * Transforms SSE messages from one API format to OpenAI chat.completion.chunks.
+ * Emits the original string SSE message as an "originalMessage" event.
+ */
+export class SSEMessageTransformer extends Transform {
+  private lastPosition: number;
+  private msgCount: number;
+  private readonly transformFn: StreamingCompletionTransformer;
+  private readonly log;
+  private readonly fallbackId: string;
+  private readonly fallbackModel: string;
+
+  constructor(options: SSEMessageTransformerOptions) {
+    super({ ...options, readableObjectMode: true });
+    this.log = options.logger?.child({ module: "sse-transformer" }) ?? genlog;
+    this.lastPosition = 0;
+    this.msgCount = 0;
+    this.transformFn = getTransformer(
+      options.inputFormat,
+      options.inputApiVersion
+    );
+    this.fallbackId = options.requestId;
+    this.fallbackModel = options.requestedModel;
+    this.log.debug(
+      {
+        fn: this.transformFn.name,
+        format: options.inputFormat,
+        version: options.inputApiVersion,
+      },
+      "Selected SSE transformer"
+    );
+  }
+
+  _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
+    try {
+      const originalMessage = chunk.toString();
+      const { event: transformedMessage, position: newPosition } =
+        this.transformFn({
+          data: originalMessage,
+          lastPosition: this.lastPosition,
+          index: this.msgCount++,
+          fallbackId: this.fallbackId,
+          fallbackModel: this.fallbackModel,
+        });
+      this.lastPosition = newPosition;
+
+      this.emit("originalMessage", originalMessage);
+
+      // Some events may not be transformed, e.g. ping events
+      if (!transformedMessage) return callback();
+
+      if (this.msgCount === 1) {
+        this.push(createInitialMessage(transformedMessage));
+      }
+      this.push(transformedMessage);
+      callback();
+    } catch (err) {
+      this.log.error(err, "Error transforming SSE message");
+      callback(err);
+    }
+  }
+}
+
+function getTransformer(
+  responseApi: APIFormat,
+  version?: string
+): StreamingCompletionTransformer {
+  switch (responseApi) {
+    case "openai":
+      return passthroughToOpenAI;
+    case "openai-text":
+      return openAITextToOpenAIChat;
+    case "anthropic":
+      return version === "2023-01-01"
+        ? anthropicV1ToOpenAI
+        : anthropicV2ToOpenAI;
+    case "google-palm":
+      throw new Error("Google PaLM does not support streaming responses");
+    default:
+      assertNever(responseApi);
+  }
+}
+
+/**
+ * OpenAI streaming chat completions start with an event that contains only the
+ * metadata and role (always 'assistant') for the response. To simulate this
+ * for APIs where the first event contains actual content, we create a fake
+ * initial event with no content but correct metadata.
+ */
+function createInitialMessage(
+  event: OpenAIChatCompletionStreamEvent
+): OpenAIChatCompletionStreamEvent {
+  return {
+    ...event,
+    choices: event.choices.map((choice) => ({
+      ...choice,
+      delta: { role: "assistant", content: "" },
+    })),
+  };
+}
@@ -1,11 +1,11 @@
 import { Transform, TransformOptions } from "stream";
 // @ts-ignore
 import { Parser } from "lifion-aws-event-stream";
-import { logger } from "../../../logger";
+import { logger } from "../../../../logger";

 const log = logger.child({ module: "sse-stream-adapter" });

-type SSEStreamAdapterOptions = TransformOptions & { isAwsStream?: boolean };
+type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
 type AwsEventStreamMessage = {
   headers: { ":message-type": "event" | "exception" };
   payload: { message?: string /** base64 encoded */; bytes?: string };
@@ -15,24 +15,25 @@ type AwsEventStreamMessage = {
  * Receives either text chunks or AWS binary event stream chunks and emits
  * full SSE events.
  */
-export class ServerSentEventStreamAdapter extends Transform {
+export class SSEStreamAdapter extends Transform {
   private readonly isAwsStream;
   private parser = new Parser();
   private partialMessage = "";

   constructor(options?: SSEStreamAdapterOptions) {
     super(options);
-    this.isAwsStream = options?.isAwsStream || false;
+    this.isAwsStream =
+      options?.contentType === "application/vnd.amazon.eventstream";

     this.parser.on("data", (data: AwsEventStreamMessage) => {
       const message = this.processAwsEvent(data);
       if (message) {
-        this.push(Buffer.from(message, "utf8"));
+        this.push(Buffer.from(message + "\n\n"), "utf8");
       }
     });
   }

-  processAwsEvent(event: AwsEventStreamMessage): string | null {
+  protected processAwsEvent(event: AwsEventStreamMessage): string | null {
     const { payload, headers } = event;
     if (headers[":message-type"] === "exception" || !payload.bytes) {
       log.error(
@@ -42,7 +43,14 @@ export class ServerSentEventStreamAdapter extends Transform {
       const message = JSON.stringify(event);
       return getFakeErrorCompletion("proxy AWS error", message);
     } else {
-      return `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`;
+      const { bytes } = payload;
+      // technically this is a transformation but we don't really distinguish
+      // between aws claude and anthropic claude at the APIFormat level, so
+      // these will short circuit the message transformer
+      return [
+        "event: completion",
+        `data: ${Buffer.from(bytes, "base64").toString("utf8")}`,
+      ].join("\n");
     }
   }
@@ -55,11 +63,15 @@ export class ServerSentEventStreamAdapter extends Transform {
       // so we need to buffer and emit separate stream events for full
       // messages so we can parse/transform them properly.
       const str = chunk.toString("utf8");
+
       const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
       this.partialMessage = fullMessages.pop() || "";

       for (const message of fullMessages) {
-        this.push(message);
+        // Mixing line endings will break some clients and our request queue
+        // will have already sent \n for heartbeats, so we need to normalize
+        // to \n.
+        this.push(message.replace(/\r\n/g, "\n") + "\n\n");
       }
     }
     callback();
@@ -72,7 +84,7 @@ export class ServerSentEventStreamAdapter extends Transform {

 function getFakeErrorCompletion(type: string, message: string) {
   const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
-  const fakeEvent = {
+  const fakeEvent = JSON.stringify({
     log_id: "aws-proxy-sse-message",
     stop_reason: type,
     completion:
@@ -80,6 +92,6 @@ function getFakeErrorCompletion(type: string, message: string) {
     truncated: false,
     stop: null,
     model: "",
-  };
-  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
+  });
+  return ["event: completion", `data: ${fakeEvent}\n\n`].join("\n");
 }
@@ -0,0 +1,67 @@
+import { StreamingCompletionTransformer } from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "anthropic-v1-to-openai",
+});
+
+type AnthropicV1StreamEvent = {
+  log_id?: string;
+  model?: string;
+  completion: string;
+  stop_reason: string;
+};
+
+/**
+ * Transforms an incoming Anthropic SSE (2023-01-01 API) to an equivalent
+ * OpenAI chat.completion.chunk SSE.
+ */
+export const anthropicV1ToOpenAI: StreamingCompletionTransformer = (params) => {
+  const { data, lastPosition } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: lastPosition };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: lastPosition };
+  }
+
+  // Anthropic sends the full completion so far with each event whereas OpenAI
+  // only sends the delta. To make the SSE events compatible, we remove
+  // everything before `lastPosition` from the completion.
+  const newEvent = {
+    id: "ant-" + (completionEvent.log_id ?? params.fallbackId),
+    object: "chat.completion.chunk" as const,
+    created: Date.now(),
+    model: completionEvent.model ?? params.fallbackModel,
+    choices: [
+      {
+        index: 0,
+        delta: { content: completionEvent.completion?.slice(lastPosition) },
+        finish_reason: completionEvent.stop_reason,
+      },
+    ],
+  };
+
+  return { position: completionEvent.completion.length, event: newEvent };
+};
+
+function asCompletion(event: ServerSentEvent): AnthropicV1StreamEvent | null {
+  try {
+    const parsed = JSON.parse(event.data);
+    if (parsed.completion !== undefined && parsed.stop_reason !== undefined) {
+      return parsed;
+    } else {
+      // noinspection ExceptionCaughtLocallyJS
+      throw new Error("Missing required fields");
+    }
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid event");
+  }
+  return null;
+}
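A worked example of the V1 cursor logic (values illustrative): Anthropic's 2023-01-01 API re-sends the whole completion with each event, so the transformer slices off what was already emitted and threads the cursor through `position`.

import { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";

const first = anthropicV1ToOpenAI({
  data: 'data: {"completion":"Hello","stop_reason":null}',
  lastPosition: 0, index: 0, fallbackId: "req-1", fallbackModel: "claude-v1",
});
// first.event's delta.content === "Hello"; first.position === 5

const second = anthropicV1ToOpenAI({
  data: 'data: {"completion":"Hello, world","stop_reason":"stop_sequence"}',
  lastPosition: first.position, index: 1,
  fallbackId: "req-1", fallbackModel: "claude-v1",
});
// second.event's delta.content === ", world" (completion.slice(5))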
@@ -0,0 +1,66 @@
+import { StreamingCompletionTransformer } from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "anthropic-v2-to-openai",
+});
+
+type AnthropicV2StreamEvent = {
+  log_id?: string;
+  model?: string;
+  completion: string;
+  stop_reason: string;
+};
+
+/**
+ * Transforms an incoming Anthropic SSE (2023-06-01 API) to an equivalent
+ * OpenAI chat.completion.chunk SSE.
+ */
+export const anthropicV2ToOpenAI: StreamingCompletionTransformer = (params) => {
+  const { data } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: -1 };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: -1 };
+  }
+
+  const newEvent = {
+    id: "ant-" + (completionEvent.log_id ?? params.fallbackId),
+    object: "chat.completion.chunk" as const,
+    created: Date.now(),
+    model: completionEvent.model ?? params.fallbackModel,
+    choices: [
+      {
+        index: 0,
+        delta: { content: completionEvent.completion },
+        finish_reason: completionEvent.stop_reason,
+      },
+    ],
+  };
+
+  return { position: completionEvent.completion.length, event: newEvent };
+};
+
+function asCompletion(event: ServerSentEvent): AnthropicV2StreamEvent | null {
+  if (event.type === "ping") return null;
+
+  try {
+    const parsed = JSON.parse(event.data);
+    if (parsed.completion !== undefined && parsed.stop_reason !== undefined) {
+      return parsed;
+    } else {
+      // noinspection ExceptionCaughtLocallyJS
+      throw new Error("Missing required fields");
+    }
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid event");
+  }
+  return null;
+}
@@ -0,0 +1,68 @@
+import { SSEResponseTransformArgs } from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "openai-text-to-openai",
+});
+
+type OpenAITextCompletionStreamEvent = {
+  id: string;
+  object: "text_completion";
+  created: number;
+  choices: {
+    text: string;
+    index: number;
+    logprobs: null;
+    finish_reason: string | null;
+  }[];
+  model: string;
+};
+
+export const openAITextToOpenAIChat = (params: SSEResponseTransformArgs) => {
+  const { data } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: -1 };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: -1 };
+  }
+
+  const newEvent = {
+    id: completionEvent.id,
+    object: "chat.completion.chunk" as const,
+    created: completionEvent.created,
+    model: completionEvent.model,
+    choices: [
+      {
+        index: completionEvent.choices[0].index,
+        delta: { content: completionEvent.choices[0].text },
+        finish_reason: completionEvent.choices[0].finish_reason,
+      },
+    ],
+  };
+
+  return { position: -1, event: newEvent };
+};
+
+function asCompletion(
+  event: ServerSentEvent
+): OpenAITextCompletionStreamEvent | null {
+  try {
+    const parsed = JSON.parse(event.data);
+    if (Array.isArray(parsed.choices) && parsed.choices[0].text !== undefined) {
+      return parsed;
+    } else {
+      // noinspection ExceptionCaughtLocallyJS
+      throw new Error("Missing required fields");
+    }
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid data event");
+  }
+  return null;
+}
@@ -0,0 +1,38 @@
+import {
+  OpenAIChatCompletionStreamEvent,
+  SSEResponseTransformArgs,
+} from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "openai-to-openai",
+});
+
+export const passthroughToOpenAI = (params: SSEResponseTransformArgs) => {
+  const { data } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: -1 };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: -1 };
+  }
+
+  return { position: -1, event: completionEvent };
+};
+
+function asCompletion(
+  event: ServerSentEvent
+): OpenAIChatCompletionStreamEvent | null {
+  try {
+    return JSON.parse(event.data);
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid event");
+  }
+  return null;
+}
@@ -23,10 +23,11 @@ import {
   getOpenAIModelFamily,
   ModelFamily,
 } from "../shared/models";
+import { initializeSseStream } from "../shared/streaming";
+import { assertNever } from "../shared/utils";
 import { logger } from "../logger";
 import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
 import { buildFakeSseMessage } from "./middleware/common";
-import { assertNever } from "../shared/utils";

 const queue: Request[] = [];
 const log = logger.child({ module: "request-queue" });
@@ -352,14 +353,8 @@ function killQueuedRequest(req: Request) {
 }

 function initStreaming(req: Request) {
-  req.log.info(`Initiating streaming for new queued request.`);
   const res = req.res!;
-  res.statusCode = 200;
-  res.setHeader("Content-Type", "text/event-stream");
-  res.setHeader("Cache-Control", "no-cache");
-  res.setHeader("Connection", "keep-alive");
-  res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix
-  res.flushHeaders();
+  initializeSseStream(res);

   if (req.query.badSseParser) {
     // Some clients have a broken SSE parser that doesn't handle comments
@@ -368,7 +363,6 @@ function initStreaming(req: Request) {
     return;
   }

-  res.write("\n");
   res.write(": joining queue\n\n");
 }
@@ -7,6 +7,14 @@ import { googlePalm } from "./palm";
 import { aws } from "./aws";

 const proxyRouter = express.Router();
+proxyRouter.use((req, _res, next) => {
+  if (req.headers.expect) {
+    // node-http-proxy does not like it when clients send `expect: 100-continue`
+    // and will stall. none of the upstream APIs use this header anyway.
+    delete req.headers.expect;
+  }
+  next();
+});
 proxyRouter.use(
   express.json({ limit: "1536kb" }),
   express.urlencoded({ extended: true, limit: "1536kb" })
@@ -0,0 +1,34 @@
+import { Response } from "express";
+import { IncomingMessage } from "http";
+
+export function initializeSseStream(res: Response) {
+  res.statusCode = 200;
+  res.setHeader("Content-Type", "text/event-stream; charset=utf-8");
+  res.setHeader("Cache-Control", "no-cache");
+  res.setHeader("Connection", "keep-alive");
+  res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix
+  res.flushHeaders();
+}
+
+/**
+ * Copies headers received from upstream API to the SSE response, excluding
+ * ones we need to set ourselves for SSE to work.
+ */
+export function copySseResponseHeaders(
+  proxyRes: IncomingMessage,
+  res: Response
+) {
+  const toOmit = [
+    "content-length",
+    "content-encoding",
+    "transfer-encoding",
+    "content-type",
+    "connection",
+    "cache-control",
+  ];
+  for (const [key, value] of Object.entries(proxyRes.headers)) {
+    if (!toOmit.includes(key) && value) {
+      res.setHeader(key, value);
+    }
+  }
+}
@@ -277,7 +277,7 @@ function cleanupExpiredTokens() {
       deleted++;
     }
   }
-  log.debug({ disabled, deleted }, "Expired tokens cleaned up.");
+  log.trace({ disabled, deleted }, "Expired tokens cleaned up.");
 }

 function refreshAllQuotas() {