Refactor handleStreamingResponse to make it less shit (khanon/oai-reverse-proxy!46)
parent: 6a3d753f0d
commit: ecf897e685
@@ -42,7 +42,7 @@ export function writeErrorResponse(
   // the stream. Otherwise just send a normal error response.
   if (
     res.headersSent ||
-    res.getHeader("content-type") === "text/event-stream"
+    String(res.getHeader("content-type")).startsWith("text/event-stream")
   ) {
     const errorContent =
       statusCode === 403
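Worth noting why the check changed: upstreams can append parameters to the content-type header, which the old strict equality misses. A minimal sketch of the difference (the header value is hypothetical, not from this diff):

// Why startsWith() replaces ===: the header may carry a charset parameter.
const contentType = "text/event-stream; charset=utf-8"; // hypothetical value
console.log(contentType === "text/event-stream"); // false, SSE branch skipped
console.log(String(contentType).startsWith("text/event-stream")); // true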
@@ -166,12 +166,7 @@ function openaiToAnthropic(req: Request) {
     throw result.error;
   }
 
-  // Anthropic has started versioning their API, indicated by an HTTP header
-  // `anthropic-version`. The new June 2023 version is not backwards compatible
-  // with our OpenAI-to-Anthropic transformations so we need to explicitly
-  // request the older version for now. 2023-01-01 will be removed in September.
-  // https://docs.anthropic.com/claude/reference/versioning
-  req.headers["anthropic-version"] = "2023-01-01";
+  req.headers["anthropic-version"] = "2023-06-01";
 
   const { messages, ...rest } = result.data;
   const prompt = openAIMessagesToClaudePrompt(messages);
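For context, a hedged sketch of the outbound request this header rides on; the endpoint and body shape follow Anthropic's public completion API of the era and are not taken from this diff:

// Assumed request shape; the anthropic-version header is all this hunk changes.
async function callAnthropic(apiKey: string) {
  return fetch("https://api.anthropic.com/v1/complete", {
    method: "POST",
    headers: {
      "content-type": "application/json",
      "x-api-key": apiKey,
      "anthropic-version": "2023-06-01", // previously pinned to 2023-01-01
    },
    body: JSON.stringify({
      model: "claude-v1", // hypothetical model name
      prompt: "\n\nHuman: Hi\n\nAssistant:",
      max_tokens_to_sample: 128,
      stream: true,
    }),
  });
}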
@@ -1,44 +1,16 @@
 import { Request, Response } from "express";
-import * as http from "http";
+import { pipeline } from "stream";
+import { promisify } from "util";
 import { buildFakeSseMessage } from "../common";
-import { RawResponseBodyHandler, decodeResponseBody } from ".";
-import { assertNever } from "../../../shared/utils";
-import { ServerSentEventStreamAdapter } from "./sse-stream-adapter";
+import { decodeResponseBody, RawResponseBodyHandler } from ".";
+import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
+import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
+import { EventAggregator } from "./streaming/event-aggregator";
+import {
+  copySseResponseHeaders,
+  initializeSseStream,
+} from "../../../shared/streaming";
 
-type OpenAiChatCompletionResponse = {
-  id: string;
-  object: string;
-  created: number;
-  model: string;
-  choices: {
-    message: { role: string; content: string };
-    finish_reason: string | null;
-    index: number;
-  }[];
-};
-
-type OpenAiTextCompletionResponse = {
-  id: string;
-  object: string;
-  created: number;
-  model: string;
-  choices: {
-    text: string;
-    finish_reason: string | null;
-    index: number;
-    logprobs: null;
-  }[];
-};
-
-type AnthropicCompletionResponse = {
-  completion: string;
-  stop_reason: string;
-  truncated: boolean;
-  stop: any;
-  model: string;
-  log_id: string;
-  exception: null;
-};
+const pipelineAsync = promisify(pipeline);
 
 /**
  * Consume the SSE stream and forward events to the client. Once the stream is
@@ -49,370 +21,67 @@ type AnthropicCompletionResponse = {
  * in the event a streamed request results in a non-200 response, we need to
  * fall back to the non-streaming response handler so that the error handler
  * can inspect the error response.
- *
- * Currently most frontends don't support Anthropic streaming, so users can opt
- * to send requests for Claude models via an endpoint that accepts OpenAI-
- * compatible requests and translates the received Anthropic SSE events into
- * OpenAI ones, essentially pretending to be an OpenAI streaming API.
  */
 export const handleStreamedResponse: RawResponseBodyHandler = async (
   proxyRes,
   req,
   res
 ) => {
-  // If these differ, the user is using the OpenAI-compatibile endpoint, so
-  // we need to translate the SSE events into OpenAI completion events for their
-  // frontend.
+  const { hash } = req.key!;
   if (!req.isStreaming) {
-    const err = new Error(
-      "handleStreamedResponse called for non-streaming request."
-    );
-    req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
-    throw err;
+    throw new Error("handleStreamedResponse called for non-streaming request.");
   }
 
-  const key = req.key!;
-  if (proxyRes.statusCode !== 200) {
-    // Ensure we use the non-streaming middleware stack since we won't be
-    // getting any events.
-    req.isStreaming = false;
+  if (proxyRes.statusCode! > 201) {
+    req.isStreaming = false; // Forces non-streaming response handler to execute
     req.log.warn(
-      { statusCode: proxyRes.statusCode, key: key.hash },
+      { statusCode: proxyRes.statusCode, key: hash },
       `Streaming request returned error status code. Falling back to non-streaming response handler.`
     );
    return decodeResponseBody(proxyRes, req, res);
  }
 
  req.log.debug(
-    { headers: proxyRes.headers, key: key.hash },
-    `Received SSE headers.`
+    { headers: proxyRes.headers, key: hash },
+    `Starting to proxy SSE stream.`
  );
 
-  return new Promise((resolve, reject) => {
-    req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`);
+  // Users waiting in the queue already have a SSE connection open for the
+  // heartbeat, so we can't always send the stream headers.
+  if (!res.headersSent) {
+    copySseResponseHeaders(proxyRes, res);
+    initializeSseStream(res);
+  }
 
-    // Queued streaming requests will already have a connection open and headers
-    // sent due to the heartbeat handler. In that case we can just start
-    // streaming the response without sending headers.
-    if (!res.headersSent) {
-      res.setHeader("Content-Type", "text/event-stream");
-      res.setHeader("Cache-Control", "no-cache");
-      res.setHeader("Connection", "keep-alive");
-      res.setHeader("X-Accel-Buffering", "no");
-      copyHeaders(proxyRes, res);
-      res.flushHeaders();
-    }
+  const prefersNativeEvents = req.inboundApi === req.outboundApi;
+  const contentType = proxyRes.headers["content-type"];
 
-    const adapter = new ServerSentEventStreamAdapter({
-      isAwsStream:
-        proxyRes.headers["content-type"] ===
-        "application/vnd.amazon.eventstream",
-    });
+  const adapter = new SSEStreamAdapter({ contentType });
+  const aggregator = new EventAggregator({ format: req.outboundApi });
+  const transformer = new SSEMessageTransformer({
+    inputFormat: req.outboundApi, // outbound from the request's perspective
+    inputApiVersion: String(req.headers["anthropic-version"]),
+    logger: req.log,
+    requestId: String(req.id),
+    requestedModel: req.body.model,
+  })
+    .on("originalMessage", (msg: string) => {
+      if (prefersNativeEvents) res.write(msg);
+    })
+    .on("data", (msg) => {
+      if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
+      aggregator.addEvent(msg);
+    });
 
-    const events: string[] = [];
-    let lastPosition = 0;
-    let eventCount = 0;
-
-    proxyRes.pipe(adapter);
-
-    adapter.on("data", (chunk: any) => {
-      try {
-        const { event, position } = transformEvent({
-          data: chunk.toString(),
-          requestApi: req.inboundApi,
-          responseApi: req.outboundApi,
-          lastPosition,
-          index: eventCount++,
-        });
-        events.push(event);
-        lastPosition = position;
-        res.write(event + "\n\n");
-      } catch (err) {
-        adapter.emit("error", err);
-      }
-    });
-
-    adapter.on("end", () => {
-      try {
-        req.log.info({ key: key.hash }, `Finished proxying SSE stream.`);
-        const finalBody = convertEventsToFinalResponse(events, req);
-        res.end();
-        resolve(finalBody);
-      } catch (err) {
-        adapter.emit("error", err);
-      }
-    });
-
-    adapter.on("error", (err) => {
-      req.log.error({ error: err, key: key.hash }, `Mid-stream error.`);
-      const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
-      res.write(`data: ${JSON.stringify(errorEvent)}\n\ndata: [DONE]\n\n`);
-      res.end();
-      reject(err);
-    });
-  });
+  try {
+    await pipelineAsync(proxyRes, adapter, transformer);
+    req.log.debug({ key: hash }, `Finished proxying SSE stream.`);
+    res.end();
+    return aggregator.getFinalResponse();
+  } catch (err) {
+    const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
+    res.write(`${errorEvent}data: [DONE]\n\n`);
+    res.end();
+    throw err;
+  }
 };
-
-type SSETransformationArgs = {
-  data: string;
-  requestApi: string;
-  responseApi: string;
-  lastPosition: number;
-  index: number;
-};
-
-/**
- * Transforms SSE events from the given response API into events compatible with
- * the API requested by the client.
- */
-function transformEvent(params: SSETransformationArgs) {
-  const { data, requestApi, responseApi } = params;
-  if (requestApi === responseApi) {
-    return { position: -1, event: data };
-  }
-
-  const trans = `${requestApi}->${responseApi}`;
-  switch (trans) {
-    case "openai->openai-text":
-      return transformOpenAITextEventToOpenAIChat(params);
-    case "openai->anthropic":
-      // TODO: handle new anthropic streaming format
-      return transformV1AnthropicEventToOpenAI(params);
-    default:
-      throw new Error(`Unsupported streaming API transformation. ${trans}`);
-  }
-}
-
-function transformOpenAITextEventToOpenAIChat(params: SSETransformationArgs) {
-  const { data, index } = params;
-
-  if (!data.startsWith("data:")) return { position: -1, event: data };
-  if (data.startsWith("data: [DONE]")) return { position: -1, event: data };
-
-  const event = JSON.parse(data.slice("data: ".length));
-
-  // The very first event must be a role assignment with no content.
-  const createEvent = () => ({
-    id: event.id,
-    object: "chat.completion.chunk",
-    created: event.created,
-    model: event.model,
-    choices: [
-      {
-        message: { role: "", content: "" } as {
-          role?: string;
-          content: string;
-        },
-        index: 0,
-        finish_reason: null,
-      },
-    ],
-  });
-
-  let buffer = "";
-
-  if (index === 0) {
-    const initialEvent = createEvent();
-    initialEvent.choices[0].message.role = "assistant";
-    buffer = `data: ${JSON.stringify(initialEvent)}\n\n`;
-  }
-
-  const newEvent = {
-    ...event,
-    choices: [
-      {
-        ...event.choices[0],
-        delta: { content: event.choices[0].text },
-        text: undefined,
-      },
-    ],
-  };
-
-  buffer += `data: ${JSON.stringify(newEvent)}`;
-
-  return { position: -1, event: buffer };
-}
-
-function transformV1AnthropicEventToOpenAI(params: SSETransformationArgs) {
-  const { data, lastPosition } = params;
-  // Anthropic sends the full completion so far with each event whereas OpenAI
-  // only sends the delta. To make the SSE events compatible, we remove
-  // everything before `lastPosition` from the completion.
-  if (!data.startsWith("data:")) {
-    return { position: lastPosition, event: data };
-  }
-
-  if (data.startsWith("data: [DONE]")) {
-    return { position: lastPosition, event: data };
-  }
-
-  const event = JSON.parse(data.slice("data: ".length));
-  const newEvent = {
-    id: "ant-" + event.log_id,
-    object: "chat.completion.chunk",
-    created: Date.now(),
-    model: event.model,
-    choices: [
-      {
-        index: 0,
-        delta: { content: event.completion?.slice(lastPosition) },
-        finish_reason: event.stop_reason,
-      },
-    ],
-  };
-  return {
-    position: event.completion.length,
-    event: `data: ${JSON.stringify(newEvent)}`,
-  };
-}
-
-/** Copy headers, excluding ones we're already setting for the SSE response. */
-function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
-  const toOmit = [
-    "content-length",
-    "content-encoding",
-    "transfer-encoding",
-    "content-type",
-    "connection",
-    "cache-control",
-  ];
-  for (const [key, value] of Object.entries(proxyRes.headers)) {
-    if (!toOmit.includes(key) && value) {
-      res.setHeader(key, value);
-    }
-  }
-}
-
-/**
- * Converts the list of incremental SSE events into an object that resembles a
- * full, non-streamed response from the API so that subsequent middleware can
- * operate on it as if it were a normal response.
- * Events are expected to be in the format they were received from the API.
- */
-function convertEventsToFinalResponse(events: string[], req: Request) {
-  switch (req.outboundApi) {
-    case "openai": {
-      let merged: OpenAiChatCompletionResponse = {
-        id: "",
-        object: "",
-        created: 0,
-        model: "",
-        choices: [],
-      };
-      merged = events.reduce((acc, event, i) => {
-        if (!event.startsWith("data: ")) return acc;
-        if (event === "data: [DONE]") return acc;
-
-        const data = JSON.parse(event.slice("data: ".length));
-
-        // The first chat chunk only contains the role assignment and metadata
-        if (i === 0) {
-          return {
-            id: data.id,
-            object: data.object,
-            created: data.created,
-            model: data.model,
-            choices: [
-              {
-                message: { role: data.choices[0].delta.role, content: "" },
-                index: 0,
-                finish_reason: null,
-              },
-            ],
-          };
-        }
-
-        if (data.choices[0].delta.content) {
-          acc.choices[0].message.content += data.choices[0].delta.content;
-        }
-        acc.choices[0].finish_reason = data.choices[0].finish_reason;
-        return acc;
-      }, merged);
-      return merged;
-    }
-    case "openai-text": {
-      let merged: OpenAiTextCompletionResponse = {
-        id: "",
-        object: "",
-        created: 0,
-        model: "",
-        choices: [],
-        // TODO: merge logprobs
-      };
-      merged = events.reduce((acc, event) => {
-        if (!event.startsWith("data: ")) return acc;
-        if (event === "data: [DONE]") return acc;
-
-        const data = JSON.parse(event.slice("data: ".length));
-
-        return {
-          id: data.id,
-          object: data.object,
-          created: data.created,
-          model: data.model,
-          choices: [
-            {
-              text: acc.choices[0]?.text + data.choices[0].text,
-              index: 0,
-              finish_reason: data.choices[0].finish_reason,
-              logprobs: null,
-            },
-          ],
-        };
-      }, merged);
-      return merged;
-    }
-    case "anthropic": {
-      if (req.headers["anthropic-version"] === "2023-01-01") {
-        return convertAnthropicV1(events, req);
-      }
-
-      let merged: AnthropicCompletionResponse = {
-        completion: "",
-        stop_reason: "",
-        truncated: false,
-        stop: null,
-        model: req.body.model,
-        log_id: "",
-        exception: null,
-      };
-      merged = events.reduce((acc, event) => {
-        if (!event.startsWith("data: ")) return acc;
-        if (event === "data: [DONE]") return acc;
-
-        const data = JSON.parse(event.slice("data: ".length));
-
-        return {
-          completion: acc.completion + data.completion,
-          stop_reason: data.stop_reason,
-          truncated: data.truncated,
-          stop: data.stop,
-          log_id: data.log_id,
-          exception: data.exception,
-          model: acc.model,
-        };
-      }, merged);
-      return merged;
-    }
-    case "google-palm": {
-      throw new Error("PaLM streaming not yet supported.");
-    }
-    default:
-      assertNever(req.outboundApi);
-  }
-}
-
-/** Older Anthropic streaming format which sent full completion each time. */
-function convertAnthropicV1(
-  events: string[],
-  req: Request
-) {
-  const lastEvent = events[events.length - 2].toString();
-  const data = JSON.parse(
-    lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length)
-  );
-  const final: AnthropicCompletionResponse = { ...data, log_id: req.id };
-  return final;
-}
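The core of the refactor is visible above: three ad-hoc event listeners wrapped in a manual Promise collapse into one awaited pipeline. A self-contained sketch of the same pattern with toy stand-in transforms (not the real adapter/transformer classes):

import { Readable, Transform, pipeline } from "stream";
import { promisify } from "util";

const pipelineAsync = promisify(pipeline);

// Toy stand-ins for SSEStreamAdapter / SSEMessageTransformer. An error thrown
// in either stage rejects the awaited pipeline, so the caller needs only one
// try/catch instead of separate "data"/"end"/"error" listeners.
const adapter = new Transform({
  transform(chunk, _enc, cb) {
    cb(null, chunk.toString().trim()); // pretend: reassemble full SSE messages
  },
});
const transformer = new Transform({
  readableObjectMode: true,
  transform(chunk, _enc, cb) {
    cb(null, { data: chunk.toString() }); // pretend: parse into an event object
  },
});
transformer.on("data", (event) => console.log(event)); // consume so it drains

async function demo() {
  await pipelineAsync(Readable.from(["data: hello\n\n"]), adapter, transformer);
  console.log("stream finished"); // rejections land in the catch below
}
demo().catch(console.error);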
@@ -4,13 +4,16 @@ import * as http from "http";
 import util from "util";
 import zlib from "zlib";
 import { logger } from "../../../logger";
-import { enqueue, trackWaitTime } from "../../queue";
+import { HttpError } from "../../../shared/errors";
 import { keyPool } from "../../../shared/key-management";
 import { getOpenAIModelFamily } from "../../../shared/models";
+import { enqueue, trackWaitTime } from "../../queue";
+import { countTokens } from "../../../shared/tokenization";
 import {
   incrementPromptCount,
   incrementTokenCount,
 } from "../../../shared/users/user-store";
+import { assertNever } from "../../../shared/utils";
 import {
   getCompletionFromBody,
   isCompletionRequest,
@@ -18,8 +21,6 @@ import {
 } from "../common";
 import { handleStreamedResponse } from "./handle-streamed-response";
 import { logPrompt } from "./log-prompt";
-import { countTokens } from "../../../shared/tokenization";
-import { assertNever } from "../../../shared/utils";
 
 const DECODER_MAP = {
   gzip: util.promisify(zlib.gunzip),
@@ -83,7 +84,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
     ? handleStreamedResponse
     : decodeResponseBody;
 
-  let lastMiddlewareName = initialHandler.name;
+  let lastMiddleware = initialHandler.name;
 
   try {
     const body = await initialHandler(proxyRes, req, res);
@@ -112,37 +113,38 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
     }
 
     for (const middleware of middlewareStack) {
-      lastMiddlewareName = middleware.name;
+      lastMiddleware = middleware.name;
       await middleware(proxyRes, req, res, body);
     }
 
     trackWaitTime(req);
-  } catch (error: any) {
+  } catch (error) {
     // Hack: if the error is a retryable rate-limit error, the request has
     // been re-enqueued and we can just return without doing anything else.
     if (error instanceof RetryableError) {
      return;
    }
 
-    const errorData = {
-      error: error.stack,
-      thrownBy: lastMiddlewareName,
-      key: req.key?.hash,
-    };
-    const message = `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`;
-    if (res.headersSent) {
-      req.log.error(errorData, message);
-      // This should have already been handled by the error handler, but
-      // just in case...
-      if (!res.writableEnded) {
-        res.end();
-      }
-    }
-    logger.error(errorData, message);
-    res
-      .status(500)
-      .json({ error: "Internal server error", proxy_note: message });
+    // Already logged and responded to the client by handleUpstreamErrors
+    if (error instanceof HttpError) {
+      if (!res.writableEnded) res.end();
+      return;
+    }
+
+    const { stack, message } = error;
+    const info = { stack, lastMiddleware, key: req.key?.hash };
+    const description = `Error while executing proxy response middleware: ${lastMiddleware} (${message})`;
+
+    if (res.headersSent) {
+      req.log.error(info, description);
+      if (!res.writableEnded) res.end();
+      return;
+    } else {
+      req.log.error(info, description);
+      res
+        .status(500)
+        .json({ error: "Internal server error", proxy_note: description });
+    }
   }
   };
 };
@@ -203,7 +205,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
       return resolve(body.toString());
     } catch (error: any) {
       const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
-      logger.warn({ error, key: req.key?.hash }, errorMessage);
+      logger.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
       writeErrorResponse(req, res, 500, { error: errorMessage });
       return reject(errorMessage);
     }
@@ -223,7 +225,7 @@ type ProxiedErrorPayload = {
 * an error to stop the middleware stack.
 * On 429 errors, if request queueing is enabled, the request will be silently
 * re-enqueued. Otherwise, the request will be rejected with an error payload.
- * @throws {Error} On HTTP error status code from upstream service
+ * @throws {HttpError} On HTTP error status code from upstream service
 */
 const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   proxyRes,
@@ -258,7 +260,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       proxy_note: `This is likely a temporary error with the upstream service.`,
     };
     writeErrorResponse(req, res, statusCode, errorObject);
-    throw new Error(parseError.message);
+    throw new HttpError(statusCode, parseError.message);
   }
 
   const errorType =
@@ -371,7 +373,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   }
 
   writeErrorResponse(req, res, statusCode, errorPayload);
-  throw new Error(errorPayload.error?.message);
+  throw new HttpError(statusCode, errorPayload.error?.message);
 };
 
 /**
New file: streaming/aggregators/anthropic.ts
@@ -0,0 +1,48 @@
+import { OpenAIChatCompletionStreamEvent } from "../index";
+
+export type AnthropicCompletionResponse = {
+  completion: string;
+  stop_reason: string;
+  truncated: boolean;
+  stop: any;
+  model: string;
+  log_id: string;
+  exception: null;
+};
+
+/**
+ * Given a list of OpenAI chat completion events, compiles them into a single
+ * finalized Anthropic completion response so that non-streaming middleware
+ * can operate on it as if it were a blocking response.
+ */
+export function mergeEventsForAnthropic(
+  events: OpenAIChatCompletionStreamEvent[]
+): AnthropicCompletionResponse {
+  let merged: AnthropicCompletionResponse = {
+    log_id: "",
+    exception: null,
+    model: "",
+    completion: "",
+    stop_reason: "",
+    truncated: false,
+    stop: null,
+  };
+  merged = events.reduce((acc, event, i) => {
+    // The first event will only contain role assignment and response metadata
+    if (i === 0) {
+      acc.log_id = event.id;
+      acc.model = event.model;
+      acc.completion = "";
+      acc.stop_reason = "";
+      return acc;
+    }
+
+    acc.stop_reason = event.choices[0].finish_reason ?? "";
+    if (event.choices[0].delta.content) {
+      acc.completion += event.choices[0].delta.content;
+    }
+
+    return acc;
+  }, merged);
+  return merged;
+}
New file: streaming/aggregators/openai-chat.ts
@@ -0,0 +1,58 @@
+import { OpenAIChatCompletionStreamEvent } from "../index";
+
+export type OpenAiChatCompletionResponse = {
+  id: string;
+  object: string;
+  created: number;
+  model: string;
+  choices: {
+    message: { role: string; content: string };
+    finish_reason: string | null;
+    index: number;
+  }[];
+};
+
+/**
+ * Given a list of OpenAI chat completion events, compiles them into a single
+ * finalized OpenAI chat completion response so that non-streaming middleware
+ * can operate on it as if it were a blocking response.
+ */
+export function mergeEventsForOpenAIChat(
+  events: OpenAIChatCompletionStreamEvent[]
+): OpenAiChatCompletionResponse {
+  let merged: OpenAiChatCompletionResponse = {
+    id: "",
+    object: "",
+    created: 0,
+    model: "",
+    choices: [],
+  };
+  merged = events.reduce((acc, event, i) => {
+    // The first event will only contain role assignment and response metadata
+    if (i === 0) {
+      acc.id = event.id;
+      acc.object = event.object;
+      acc.created = event.created;
+      acc.model = event.model;
+      acc.choices = [
+        {
+          index: 0,
+          message: {
+            role: event.choices[0].delta.role ?? "assistant",
+            content: "",
+          },
+          finish_reason: null,
+        },
+      ];
+      return acc;
+    }
+
+    acc.choices[0].finish_reason = event.choices[0].finish_reason;
+    if (event.choices[0].delta.content) {
+      acc.choices[0].message.content += event.choices[0].delta.content;
+    }
+
+    return acc;
+  }, merged);
+  return merged;
+}
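A quick sketch of this aggregator's contract, with two hand-written chunks (ids and values invented for illustration):

import { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";

// First chunk carries only role/metadata; later chunks carry content deltas.
const events = [
  {
    id: "chatcmpl-123", object: "chat.completion.chunk" as const,
    created: 1700000000, model: "gpt-4",
    choices: [{ index: 0, delta: { role: "assistant" }, finish_reason: null }],
  },
  {
    id: "chatcmpl-123", object: "chat.completion.chunk" as const,
    created: 1700000000, model: "gpt-4",
    choices: [{ index: 0, delta: { content: "Hello!" }, finish_reason: "stop" }],
  },
];

const final = mergeEventsForOpenAIChat(events);
// => { id: "chatcmpl-123", ..., choices: [{ message: { role: "assistant",
//      content: "Hello!" }, finish_reason: "stop", index: 0 }] }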
New file: streaming/aggregators/openai-text.ts
@@ -0,0 +1,57 @@
+import { OpenAIChatCompletionStreamEvent } from "../index";
+
+export type OpenAiTextCompletionResponse = {
+  id: string;
+  object: string;
+  created: number;
+  model: string;
+  choices: {
+    text: string;
+    finish_reason: string | null;
+    index: number;
+    logprobs: null;
+  }[];
+};
+
+/**
+ * Given a list of OpenAI chat completion events, compiles them into a single
+ * finalized OpenAI text completion response so that non-streaming middleware
+ * can operate on it as if it were a blocking response.
+ */
+export function mergeEventsForOpenAIText(
+  events: OpenAIChatCompletionStreamEvent[]
+): OpenAiTextCompletionResponse {
+  let merged: OpenAiTextCompletionResponse = {
+    id: "",
+    object: "",
+    created: 0,
+    model: "",
+    choices: [],
+  };
+  merged = events.reduce((acc, event, i) => {
+    // The first event will only contain role assignment and response metadata
+    if (i === 0) {
+      acc.id = event.id;
+      acc.object = event.object;
+      acc.created = event.created;
+      acc.model = event.model;
+      acc.choices = [
+        {
+          text: "",
+          index: 0,
+          finish_reason: null,
+          logprobs: null,
+        },
+      ];
+      return acc;
+    }
+
+    acc.choices[0].finish_reason = event.choices[0].finish_reason;
+    if (event.choices[0].delta.content) {
+      acc.choices[0].text += event.choices[0].delta.content;
+    }
+
+    return acc;
+  }, merged);
+  return merged;
+}
New file: streaming/event-aggregator.ts
@@ -0,0 +1,41 @@
+import { APIFormat } from "../../../../shared/key-management";
+import { assertNever } from "../../../../shared/utils";
+import {
+  mergeEventsForAnthropic,
+  mergeEventsForOpenAIChat,
+  mergeEventsForOpenAIText,
+  OpenAIChatCompletionStreamEvent
+} from "./index";
+
+/**
+ * Collects SSE events containing incremental chat completion responses and
+ * compiles them into a single finalized response for downstream middleware.
+ */
+export class EventAggregator {
+  private readonly format: APIFormat;
+  private readonly events: OpenAIChatCompletionStreamEvent[];
+
+  constructor({ format }: { format: APIFormat }) {
+    this.events = [];
+    this.format = format;
+  }
+
+  addEvent(event: OpenAIChatCompletionStreamEvent) {
+    this.events.push(event);
+  }
+
+  getFinalResponse() {
+    switch (this.format) {
+      case "openai":
+        return mergeEventsForOpenAIChat(this.events);
+      case "openai-text":
+        return mergeEventsForOpenAIText(this.events);
+      case "anthropic":
+        return mergeEventsForAnthropic(this.events);
+      case "google-palm":
+        throw new Error("Google PaLM API does not support streaming responses");
+      default:
+        assertNever(this.format);
+    }
+  }
+}
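Sketch of how the aggregator slots in: the transformer has already normalized everything to OpenAI chunk format, so one addEvent() shape covers every upstream (event values invented):

import { EventAggregator } from "./streaming/event-aggregator";

// The output format follows the request's API, not the chunk format.
const aggregator = new EventAggregator({ format: "anthropic" });

aggregator.addEvent({
  id: "chatcmpl-1", object: "chat.completion.chunk", created: 0, model: "claude-v1",
  choices: [{ index: 0, delta: { role: "assistant" }, finish_reason: null }],
});
aggregator.addEvent({
  id: "chatcmpl-1", object: "chat.completion.chunk", created: 0, model: "claude-v1",
  choices: [{ index: 0, delta: { content: "Hi!" }, finish_reason: "stop_sequence" }],
});

const body = aggregator.getFinalResponse();
// => { completion: "Hi!", stop_reason: "stop_sequence", model: "claude-v1", ... }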
New file: streaming/index.ts
@@ -0,0 +1,31 @@
+export type SSEResponseTransformArgs = {
+  data: string;
+  lastPosition: number;
+  index: number;
+  fallbackId: string;
+  fallbackModel: string;
+};
+
+export type OpenAIChatCompletionStreamEvent = {
+  id: string;
+  object: "chat.completion.chunk";
+  created: number;
+  model: string;
+  choices: {
+    index: number;
+    delta: { role?: string; content?: string };
+    finish_reason: string | null;
+  }[];
+}
+
+export type StreamingCompletionTransformer = (
+  params: SSEResponseTransformArgs
+) => { position: number; event?: OpenAIChatCompletionStreamEvent };
+
+export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
+export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
+export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
+export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
+export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
+export { mergeEventsForAnthropic } from "./aggregators/anthropic";
New file: streaming/parse-sse.ts
@@ -0,0 +1,29 @@
+export type ServerSentEvent = { id?: string; type?: string; data: string };
+
+/** Given a string of SSE data, parse it into a `ServerSentEvent` object. */
+export function parseEvent(event: string) {
+  const buffer: ServerSentEvent = { data: "" };
+  return event.split(/\r?\n/).reduce(parseLine, buffer);
+}
+
+function parseLine(event: ServerSentEvent, line: string) {
+  const separator = line.indexOf(":");
+  const field = separator === -1 ? line : line.slice(0, separator);
+  const value = separator === -1 ? "" : line.slice(separator + 1);
+
+  switch (field) {
+    case "id":
+      event.id = value.trim();
+      break;
+    case "event":
+      event.type = value.trim();
+      break;
+    case "data":
+      event.data += value.trimStart();
+      break;
+    default:
+      break;
+  }
+
+  return event;
+}
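Example of the parser's behavior on a raw event string. Note that repeated data: lines are concatenated directly rather than joined with the newline the SSE spec prescribes, which is sufficient for the single-line JSON payloads handled here:

import { parseEvent } from "./streaming/parse-sse";

// Hypothetical raw SSE message split across two data: lines.
const raw = 'event: completion\ndata: {"completion":"Hi",\ndata: "stop_reason":null}';
const parsed = parseEvent(raw);
// => { type: "completion", data: '{"completion":"Hi","stop_reason":null}' }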
New file: streaming/sse-message-transformer.ts
@@ -0,0 +1,123 @@
+import { Transform, TransformOptions } from "stream";
+import { logger } from "../../../../logger";
+import { APIFormat } from "../../../../shared/key-management";
+import { assertNever } from "../../../../shared/utils";
+import {
+  anthropicV1ToOpenAI,
+  anthropicV2ToOpenAI,
+  OpenAIChatCompletionStreamEvent,
+  openAITextToOpenAIChat,
+  StreamingCompletionTransformer,
+} from "./index";
+import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
+
+const genlog = logger.child({ module: "sse-transformer" });
+
+type SSEMessageTransformerOptions = TransformOptions & {
+  requestedModel: string;
+  requestId: string;
+  inputFormat: APIFormat;
+  inputApiVersion?: string;
+  logger?: typeof logger;
+};
+
+/**
+ * Transforms SSE messages from one API format to OpenAI chat.completion.chunks.
+ * Emits the original string SSE message as an "originalMessage" event.
+ */
+export class SSEMessageTransformer extends Transform {
+  private lastPosition: number;
+  private msgCount: number;
+  private readonly transformFn: StreamingCompletionTransformer;
+  private readonly log;
+  private readonly fallbackId: string;
+  private readonly fallbackModel: string;
+
+  constructor(options: SSEMessageTransformerOptions) {
+    super({ ...options, readableObjectMode: true });
+    this.log = options.logger?.child({ module: "sse-transformer" }) ?? genlog;
+    this.lastPosition = 0;
+    this.msgCount = 0;
+    this.transformFn = getTransformer(
+      options.inputFormat,
+      options.inputApiVersion
+    );
+    this.fallbackId = options.requestId;
+    this.fallbackModel = options.requestedModel;
+    this.log.debug(
+      {
+        fn: this.transformFn.name,
+        format: options.inputFormat,
+        version: options.inputApiVersion,
+      },
+      "Selected SSE transformer"
+    );
+  }
+
+  _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
+    try {
+      const originalMessage = chunk.toString();
+      const { event: transformedMessage, position: newPosition } =
+        this.transformFn({
+          data: originalMessage,
+          lastPosition: this.lastPosition,
+          index: this.msgCount++,
+          fallbackId: this.fallbackId,
+          fallbackModel: this.fallbackModel,
+        });
+      this.lastPosition = newPosition;
+
+      this.emit("originalMessage", originalMessage);
+
+      // Some events may not be transformed, e.g. ping events
+      if (!transformedMessage) return callback();
+
+      if (this.msgCount === 1) {
+        this.push(createInitialMessage(transformedMessage));
+      }
+      this.push(transformedMessage);
+      callback();
+    } catch (err) {
+      this.log.error(err, "Error transforming SSE message");
+      callback(err);
+    }
+  }
+}
+
+function getTransformer(
+  responseApi: APIFormat,
+  version?: string
+): StreamingCompletionTransformer {
+  switch (responseApi) {
+    case "openai":
+      return passthroughToOpenAI;
+    case "openai-text":
+      return openAITextToOpenAIChat;
+    case "anthropic":
+      return version === "2023-01-01"
+        ? anthropicV1ToOpenAI
+        : anthropicV2ToOpenAI;
+    case "google-palm":
+      throw new Error("Google PaLM does not support streaming responses");
+    default:
+      assertNever(responseApi);
+  }
+}
+
+/**
+ * OpenAI streaming chat completions start with an event that contains only the
+ * metadata and role (always 'assistant') for the response. To simulate this
+ * for APIs where the first event contains actual content, we create a fake
+ * initial event with no content but correct metadata.
+ */
+function createInitialMessage(
+  event: OpenAIChatCompletionStreamEvent
+): OpenAIChatCompletionStreamEvent {
+  return {
+    ...event,
+    choices: event.choices.map((choice) => ({
+      ...choice,
+      delta: { role: "assistant", content: "" },
+    })),
+  };
+}
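A sketch of the transformer's first-message behavior, driving it by hand rather than through the pipeline (event payload invented):

import { SSEMessageTransformer } from "./streaming/sse-message-transformer";

// Anthropic streams carry content from the very first event, so the
// transformer prepends a synthetic role-only chunk, mirroring how OpenAI
// streams always begin.
const transformer = new SSEMessageTransformer({
  inputFormat: "anthropic",
  inputApiVersion: "2023-06-01",
  requestId: "req-1",          // fallback id when events lack log_id
  requestedModel: "claude-v1", // fallback model
});

transformer.on("data", (chunk) => console.log(chunk.choices[0].delta));
transformer.write('data: {"completion":"Hello","stop_reason":null}');
// logs: { role: "assistant", content: "" }   <- synthetic initial chunk
// logs: { content: "Hello" }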
sse-stream-adapter.ts (moved into streaming/):
@@ -1,11 +1,11 @@
 import { Transform, TransformOptions } from "stream";
 // @ts-ignore
 import { Parser } from "lifion-aws-event-stream";
-import { logger } from "../../../logger";
+import { logger } from "../../../../logger";
 
 const log = logger.child({ module: "sse-stream-adapter" });
 
-type SSEStreamAdapterOptions = TransformOptions & { isAwsStream?: boolean };
+type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
 type AwsEventStreamMessage = {
   headers: { ":message-type": "event" | "exception" };
   payload: { message?: string /** base64 encoded */; bytes?: string };
@@ -15,24 +15,25 @@ type AwsEventStreamMessage = {
 * Receives either text chunks or AWS binary event stream chunks and emits
 * full SSE events.
 */
-export class ServerSentEventStreamAdapter extends Transform {
+export class SSEStreamAdapter extends Transform {
   private readonly isAwsStream;
   private parser = new Parser();
   private partialMessage = "";
 
   constructor(options?: SSEStreamAdapterOptions) {
     super(options);
-    this.isAwsStream = options?.isAwsStream || false;
+    this.isAwsStream =
+      options?.contentType === "application/vnd.amazon.eventstream";
 
     this.parser.on("data", (data: AwsEventStreamMessage) => {
       const message = this.processAwsEvent(data);
       if (message) {
-        this.push(Buffer.from(message, "utf8"));
+        this.push(Buffer.from(message + "\n\n"), "utf8");
       }
     });
   }
 
-  processAwsEvent(event: AwsEventStreamMessage): string | null {
+  protected processAwsEvent(event: AwsEventStreamMessage): string | null {
     const { payload, headers } = event;
     if (headers[":message-type"] === "exception" || !payload.bytes) {
       log.error(
@@ -42,7 +43,14 @@ export class ServerSentEventStreamAdapter extends Transform {
       const message = JSON.stringify(event);
       return getFakeErrorCompletion("proxy AWS error", message);
     } else {
-      return `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`;
+      const { bytes } = payload;
+      // technically this is a transformation but we don't really distinguish
+      // between aws claude and anthropic claude at the APIFormat level, so
+      // these will short circuit the message transformer
+      return [
+        "event: completion",
+        `data: ${Buffer.from(bytes, "base64").toString("utf8")}`,
+      ].join("\n");
     }
   }
@@ -55,11 +63,15 @@ export class ServerSentEventStreamAdapter extends Transform {
       // so we need to buffer and emit separate stream events for full
       // messages so we can parse/transform them properly.
       const str = chunk.toString("utf8");
 
       const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
       this.partialMessage = fullMessages.pop() || "";
 
       for (const message of fullMessages) {
-        this.push(message);
+        // Mixing line endings will break some clients and our request queue
+        // will have already sent \n for heartbeats, so we need to normalize
+        // to \n.
+        this.push(message.replace(/\r\n/g, "\n") + "\n\n");
       }
     }
     callback();
@@ -72,7 +84,7 @@ export class ServerSentEventStreamAdapter extends Transform {
 
 function getFakeErrorCompletion(type: string, message: string) {
   const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
-  const fakeEvent = {
+  const fakeEvent = JSON.stringify({
     log_id: "aws-proxy-sse-message",
     stop_reason: type,
     completion:
@@ -80,6 +92,6 @@ function getFakeErrorCompletion(type: string, message: string) {
     truncated: false,
     stop: null,
     model: "",
-  };
-  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
+  });
+  return ["event: completion", `data: ${fakeEvent}\n\n`].join("\n");
 }
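Sketch of the unwrapping that processAwsEvent() performs on a single frame, shape per the AwsEventStreamMessage type above (payload invented):

// AWS event-stream frames carry the Anthropic JSON base64-encoded in
// payload.bytes; the adapter re-emits it as a plain SSE completion event.
const frame = {
  headers: { ":message-type": "event" as const },
  payload: { bytes: Buffer.from('{"completion":"Hi"}').toString("base64") },
};

const sse = [
  "event: completion",
  `data: ${Buffer.from(frame.payload.bytes!, "base64").toString("utf8")}`,
].join("\n");
console.log(sse); // event: completion\ndata: {"completion":"Hi"}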
New file: streaming/transformers/anthropic-v1-to-openai.ts
@@ -0,0 +1,67 @@
+import { StreamingCompletionTransformer } from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "anthropic-v1-to-openai",
+});
+
+type AnthropicV1StreamEvent = {
+  log_id?: string;
+  model?: string;
+  completion: string;
+  stop_reason: string;
+};
+
+/**
+ * Transforms an incoming Anthropic SSE (2023-01-01 API) to an equivalent
+ * OpenAI chat.completion.chunk SSE.
+ */
+export const anthropicV1ToOpenAI: StreamingCompletionTransformer = (params) => {
+  const { data, lastPosition } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: lastPosition };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: lastPosition };
+  }
+
+  // Anthropic sends the full completion so far with each event whereas OpenAI
+  // only sends the delta. To make the SSE events compatible, we remove
+  // everything before `lastPosition` from the completion.
+  const newEvent = {
+    id: "ant-" + (completionEvent.log_id ?? params.fallbackId),
+    object: "chat.completion.chunk" as const,
+    created: Date.now(),
+    model: completionEvent.model ?? params.fallbackModel,
+    choices: [
+      {
+        index: 0,
+        delta: { content: completionEvent.completion?.slice(lastPosition) },
+        finish_reason: completionEvent.stop_reason,
+      },
+    ],
+  };
+
+  return { position: completionEvent.completion.length, event: newEvent };
+};
+
+function asCompletion(event: ServerSentEvent): AnthropicV1StreamEvent | null {
+  try {
+    const parsed = JSON.parse(event.data);
+    if (parsed.completion !== undefined && parsed.stop_reason !== undefined) {
+      return parsed;
+    } else {
+      // noinspection ExceptionCaughtLocallyJS
+      throw new Error("Missing required fields");
+    }
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid event");
+  }
+  return null;
+}
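The position bookkeeping is the subtle part; a tiny standalone illustration of how slicing at lastPosition recovers deltas from full-completion events:

// The v1 API resends the whole completion each event; the transformer's
// returned `position` is how deltas are recovered on the next call.
let lastPosition = 0;
for (const completion of ["Hel", "Hello", "Hello world"]) {
  const delta = completion.slice(lastPosition);
  lastPosition = completion.length;
  console.log(delta); // "Hel", then "lo", then " world"
}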
New file: streaming/transformers/anthropic-v2-to-openai.ts
@@ -0,0 +1,66 @@
+import { StreamingCompletionTransformer } from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "anthropic-v2-to-openai",
+});
+
+type AnthropicV2StreamEvent = {
+  log_id?: string;
+  model?: string;
+  completion: string;
+  stop_reason: string;
+};
+
+/**
+ * Transforms an incoming Anthropic SSE (2023-06-01 API) to an equivalent
+ * OpenAI chat.completion.chunk SSE.
+ */
+export const anthropicV2ToOpenAI: StreamingCompletionTransformer = (params) => {
+  const { data } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: -1 };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: -1 };
+  }
+
+  const newEvent = {
+    id: "ant-" + (completionEvent.log_id ?? params.fallbackId),
+    object: "chat.completion.chunk" as const,
+    created: Date.now(),
+    model: completionEvent.model ?? params.fallbackModel,
+    choices: [
+      {
+        index: 0,
+        delta: { content: completionEvent.completion },
+        finish_reason: completionEvent.stop_reason,
+      },
+    ],
+  };
+
+  return { position: completionEvent.completion.length, event: newEvent };
+};
+
+function asCompletion(event: ServerSentEvent): AnthropicV2StreamEvent | null {
+  if (event.type === "ping") return null;
+
+  try {
+    const parsed = JSON.parse(event.data);
+    if (parsed.completion !== undefined && parsed.stop_reason !== undefined) {
+      return parsed;
+    } else {
+      // noinspection ExceptionCaughtLocallyJS
+      throw new Error("Missing required fields");
+    }
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid event");
+  }
+  return null;
+}
New file: streaming/transformers/openai-text-to-openai.ts
@@ -0,0 +1,68 @@
+import { SSEResponseTransformArgs } from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "openai-text-to-openai",
+});
+
+type OpenAITextCompletionStreamEvent = {
+  id: string;
+  object: "text_completion";
+  created: number;
+  choices: {
+    text: string;
+    index: number;
+    logprobs: null;
+    finish_reason: string | null;
+  }[];
+  model: string;
+};
+
+export const openAITextToOpenAIChat = (params: SSEResponseTransformArgs) => {
+  const { data } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: -1 };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: -1 };
+  }
+
+  const newEvent = {
+    id: completionEvent.id,
+    object: "chat.completion.chunk" as const,
+    created: completionEvent.created,
+    model: completionEvent.model,
+    choices: [
+      {
+        index: completionEvent.choices[0].index,
+        delta: { content: completionEvent.choices[0].text },
+        finish_reason: completionEvent.choices[0].finish_reason,
+      },
+    ],
+  };
+
+  return { position: -1, event: newEvent };
+};
+
+function asCompletion(
+  event: ServerSentEvent
+): OpenAITextCompletionStreamEvent | null {
+  try {
+    const parsed = JSON.parse(event.data);
+    if (Array.isArray(parsed.choices) && parsed.choices[0].text !== undefined) {
+      return parsed;
+    } else {
+      // noinspection ExceptionCaughtLocallyJS
+      throw new Error("Missing required fields");
+    }
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid data event");
+  }
+  return null;
+}
New file: streaming/transformers/passthrough-to-openai.ts
@@ -0,0 +1,38 @@
+import {
+  OpenAIChatCompletionStreamEvent,
+  SSEResponseTransformArgs,
+} from "../index";
+import { parseEvent, ServerSentEvent } from "../parse-sse";
+import { logger } from "../../../../../logger";
+
+const log = logger.child({
+  module: "sse-transformer",
+  transformer: "openai-to-openai",
+});
+
+export const passthroughToOpenAI = (params: SSEResponseTransformArgs) => {
+  const { data } = params;
+
+  const rawEvent = parseEvent(data);
+  if (!rawEvent.data || rawEvent.data === "[DONE]") {
+    return { position: -1 };
+  }
+
+  const completionEvent = asCompletion(rawEvent);
+  if (!completionEvent) {
+    return { position: -1 };
+  }
+
+  return { position: -1, event: completionEvent };
+};
+
+function asCompletion(
+  event: ServerSentEvent
+): OpenAIChatCompletionStreamEvent | null {
+  try {
+    return JSON.parse(event.data);
+  } catch (error) {
+    log.warn({ error: error.stack, event }, "Received invalid event");
+  }
+  return null;
+}
@@ -23,10 +23,11 @@ import {
   getOpenAIModelFamily,
   ModelFamily,
 } from "../shared/models";
+import { initializeSseStream } from "../shared/streaming";
+import { assertNever } from "../shared/utils";
 import { logger } from "../logger";
 import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
 import { buildFakeSseMessage } from "./middleware/common";
-import { assertNever } from "../shared/utils";
 
 const queue: Request[] = [];
 const log = logger.child({ module: "request-queue" });
@@ -352,14 +353,8 @@ function killQueuedRequest(req: Request) {
 }
 
 function initStreaming(req: Request) {
   req.log.info(`Initiating streaming for new queued request.`);
   const res = req.res!;
-  res.statusCode = 200;
-  res.setHeader("Content-Type", "text/event-stream");
-  res.setHeader("Cache-Control", "no-cache");
-  res.setHeader("Connection", "keep-alive");
-  res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix
-  res.flushHeaders();
+  initializeSseStream(res);
 
   if (req.query.badSseParser) {
     // Some clients have a broken SSE parser that doesn't handle comments
@@ -368,7 +363,6 @@ function initStreaming(req: Request) {
     return;
   }
 
-  res.write("\n");
   res.write(": joining queue\n\n");
 }
@@ -7,6 +7,14 @@ import { googlePalm } from "./palm";
 import { aws } from "./aws";
 
 const proxyRouter = express.Router();
+proxyRouter.use((req, _res, next) => {
+  if (req.headers.expect) {
+    // node-http-proxy does not like it when clients send `expect: 100-continue`
+    // and will stall. none of the upstream APIs use this header anyway.
+    delete req.headers.expect;
+  }
+  next();
+});
 proxyRouter.use(
   express.json({ limit: "1536kb" }),
   express.urlencoded({ extended: true, limit: "1536kb" })
New file: shared/streaming.ts
@@ -0,0 +1,34 @@
+import { Response } from "express";
+import { IncomingMessage } from "http";
+
+export function initializeSseStream(res: Response) {
+  res.statusCode = 200;
+  res.setHeader("Content-Type", "text/event-stream; charset=utf-8");
+  res.setHeader("Cache-Control", "no-cache");
+  res.setHeader("Connection", "keep-alive");
+  res.setHeader("X-Accel-Buffering", "no"); // nginx-specific fix
+  res.flushHeaders();
+}
+
+/**
+ * Copies headers received from upstream API to the SSE response, excluding
+ * ones we need to set ourselves for SSE to work.
+ */
+export function copySseResponseHeaders(
+  proxyRes: IncomingMessage,
+  res: Response
+) {
+  const toOmit = [
+    "content-length",
+    "content-encoding",
+    "transfer-encoding",
+    "content-type",
+    "connection",
+    "cache-control",
+  ];
+  for (const [key, value] of Object.entries(proxyRes.headers)) {
+    if (!toOmit.includes(key) && value) {
+      res.setHeader(key, value);
+    }
+  }
+}
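A minimal usage sketch, assuming an Express app; the route and interval are invented, but this is the call pattern the queue heartbeat and proxy handler now share:

import express from "express";
import { initializeSseStream } from "./shared/streaming";

// One definition of "start an SSE response" means headers can't drift
// between the queue path and the proxy path.
const app = express();
app.get("/events", (req, res) => {
  initializeSseStream(res);
  res.write(": connected\n\n"); // SSE comment, ignored by clients
  const timer = setInterval(() => res.write("data: ping\n\n"), 1000);
  req.on("close", () => clearInterval(timer));
});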
@@ -277,7 +277,7 @@ function cleanupExpiredTokens() {
       deleted++;
     }
   }
-  log.debug({ disabled, deleted }, "Expired tokens cleaned up.");
+  log.trace({ disabled, deleted }, "Expired tokens cleaned up.");
 }
 
 function refreshAllQuotas() {