Fixes for Gemini API streaming

nai-degen 2024-09-29 12:44:18 -05:00
parent cfb6353c65
commit 22d7f966c6
7 changed files with 79 additions and 146 deletions

View File

@@ -149,7 +149,7 @@ function setStreamFlag(req: Request) {
}
/**
* Replaces requests for non-Google AI models with gemini-pro-1.5-latest.
* Replaces requests for non-Google AI models with gemini-1.5-pro-latest.
* Also strips models/ from the beginning of the model IDs.
**/
function maybeReassignModel(req: Request) {
@@ -169,8 +169,8 @@ function maybeReassignModel(req: Request) {
return;
}
req.log.info({ requested }, "Reassigning model to gemini-pro-1.5-latest");
req.body.model = "gemini-pro-1.5-latest";
req.log.info({ requested }, "Reassigning model to gemini-1.5-pro-latest");
req.body.model = "gemini-1.5-pro-latest";
}
export const googleAI = googleAIRouter;
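
The hunks above show only fragments of maybeReassignModel. Below is a hypothetical reconstruction of the whole helper, assuming nothing beyond what the docstring and the visible lines state; the request type and the gemini- prefix check are illustrative, not copied from the file.

```typescript
// Hypothetical sketch, not the repository's actual implementation.
// Assumed request shape: the project augments Express's Request with a
// parsed body and a pino logger, so a minimal stand-in type is used here.
type ProxyRequest = {
  body: { model: string };
  log: { info: (obj: object, msg: string) => void };
};

function maybeReassignModel(req: ProxyRequest) {
  // Strip the "models/" prefix Google uses in its REST paths.
  const requested = String(req.body.model).replace(/^models\//, "");
  if (requested.startsWith("gemini-")) {
    req.body.model = requested;
    return;
  }
  req.log.info({ requested }, "Reassigning model to gemini-1.5-pro-latest");
  req.body.model = "gemini-1.5-pro-latest";
}
```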

View File

@@ -1,17 +1,17 @@
import { keyPool } from "../../../../shared/key-management";
import { ProxyReqMutator} from "../index";
import { ProxyReqMutator } from "../index";
export const addGoogleAIKey: ProxyReqMutator = (manager) => {
const req = manager.request;
const inboundValid =
req.inboundApi === "openai" || req.inboundApi === "google-ai";
const outboundValid = req.outboundApi === "google-ai";
const serviceValid = req.service === "google-ai";
if (!inboundValid || !outboundValid || !serviceValid) {
throw new Error("addGoogleAIKey called on invalid request");
}
const model = req.body.model;
const key = keyPool.get(model, "google-ai");
manager.setKey(key);
@@ -20,7 +20,7 @@ export const addGoogleAIKey: ProxyReqMutator = (manager) => {
{ key: key.hash, model, stream: req.isStreaming },
"Assigned Google AI API key to request"
);
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
const payload = { ...req.body, stream: undefined, model: undefined };
@@ -33,8 +33,8 @@ export const addGoogleAIKey: ProxyReqMutator = (manager) => {
protocol: "https:",
hostname: "generativelanguage.googleapis.com",
path: `/v1beta/models/${model}:${
req.isStreaming ? "streamGenerateContent" : "generateContent"
}?key=${key.key}`,
req.isStreaming ? "streamGenerateContent?alt=sse&" : "generateContent?"
}key=${key.key}`,
headers: {
["host"]: `generativelanguage.googleapis.com`,
["content-type"]: "application/json",

View File

@@ -100,23 +100,30 @@ export function createQueuedProxyMiddleware({
type ProxiedResponse = http.IncomingMessage & Response & any;
function pinoLoggerPlugin(proxyServer: ProxyServer<Request>) {
proxyServer.on("error", (err, req, res, target) => {
const originalUrl = req.originalUrl;
const targetUrl = target?.toString();
req.log.error(
{ originalUrl, targetUrl, err },
{ originalUrl: req.originalUrl, targetUrl: String(target), err },
"Error occurred while proxying request to target"
);
});
proxyServer.on("proxyReq", (proxyReq, req) => {
const from = req.originalUrl;
const { protocol, host, path } = proxyReq;
req.log.info(
{ from, to: `${proxyReq.protocol}//${proxyReq.host}${proxyReq.path}` },
{
from: req.originalUrl,
to: `${protocol}//${host}${path}`,
},
"Sending request to upstream API..."
);
});
proxyServer.on("proxyRes", (proxyRes: ProxiedResponse, req, _res) => {
const target = `${proxyRes.req.protocol}//${proxyRes.req.host}${proxyRes.req.path}`;
const statusCode = proxyRes.statusCode;
req.log.info({ target, statusCode }, "Got response from upstream API.");
const { protocol, host, path } = proxyRes.req;
req.log.info(
{
target: `${protocol}//${host}${path}`,
status: proxyRes.statusCode,
contentType: proxyRes.headers["content-type"],
},
"Got response from upstream API."
);
});
}

View File

@@ -1,3 +1,4 @@
import { Request, Response } from "express";
import util from "util";
import zlib from "zlib";
import { sendProxyError } from "../common";
@@ -7,13 +8,13 @@ const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
deflate: util.promisify(zlib.inflate),
br: util.promisify(zlib.brotliDecompress),
text: (data: Buffer) => data,
};
type SupportedContentEncoding = keyof typeof DECODER_MAP;
const isSupportedContentEncoding = (
contentEncoding: string
): contentEncoding is keyof typeof DECODER_MAP => {
return contentEncoding in DECODER_MAP;
};
encoding: string
): encoding is SupportedContentEncoding => encoding in DECODER_MAP;
/**
* Handles the response from the upstream service and decodes the body if
@@ -35,41 +36,39 @@ export const handleBlockingResponse: RawResponseBodyHandler = async (
throw err;
}
return new Promise<string>((resolve, reject) => {
return new Promise((resolve, reject) => {
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
proxyRes.on("end", async () => {
let body = Buffer.concat(chunks);
let body: string | Buffer = Buffer.concat(chunks);
const rejectWithMessage = function (msg: string, err: Error) {
const error = `${msg} (${err.message})`;
req.log.warn({ stack: err.stack }, error);
sendProxyError(req, res, 500, "Internal Server Error", { error });
return reject(error);
};
const contentEncoding = proxyRes.headers["content-encoding"];
if (contentEncoding) {
if (isSupportedContentEncoding(contentEncoding)) {
const decoder = DECODER_MAP[contentEncoding];
// @ts-ignore - started failing after upgrading TypeScript, don't care
// as it was never a problem.
body = await decoder(body);
} else {
const error = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
req.log.warn({ contentEncoding, key: req.key?.hash }, error);
sendProxyError(req, res, 500, "Internal Server Error", {
error,
contentEncoding,
});
return reject(error);
const contentEncoding = proxyRes.headers["content-encoding"] ?? "text";
if (isSupportedContentEncoding(contentEncoding)) {
try {
body = (await DECODER_MAP[contentEncoding](body)).toString();
} catch (e) {
return rejectWithMessage(`Could not decode response body`, e);
}
} else {
return rejectWithMessage(
"API responded with unsupported content encoding",
new Error(`Unsupported content-encoding: ${contentEncoding}`)
);
}
try {
if (proxyRes.headers["content-type"]?.includes("application/json")) {
const json = JSON.parse(body.toString());
return resolve(json);
return resolve(JSON.parse(body));
}
return resolve(body.toString());
return resolve(body);
} catch (e) {
const msg = `Proxy received response with invalid JSON: ${e.message}`;
req.log.warn({ error: e.stack, key: req.key?.hash }, msg);
sendProxyError(req, res, 500, "Internal Server Error", { error: msg });
return reject(msg);
return rejectWithMessage("API responded with invalid JSON", e);
}
});
});
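
The rewrite above collapses decompression into a lookup keyed by content-encoding, with "text" acting as the identity decoder when the header is absent. A self-contained sketch of the same pattern, typed explicitly so it does not need the @ts-ignore the old code carried (names and the call shape are illustrative):

```typescript
import util from "util";
import zlib from "zlib";

// Map each supported content-encoding to a decoder; "text" is the identity
// decoder used when the response has no content-encoding header.
type Decoder = (data: Buffer) => Buffer | Promise<Buffer>;
const DECODERS: Record<"gzip" | "deflate" | "br" | "text", Decoder> = {
  gzip: util.promisify(zlib.gunzip),
  deflate: util.promisify(zlib.inflate),
  br: util.promisify(zlib.brotliDecompress),
  text: (data) => data,
};

function isSupported(enc: string): enc is keyof typeof DECODERS {
  return enc in DECODERS;
}

async function decodeBody(raw: Buffer, contentEncoding = "text"): Promise<string> {
  if (!isSupported(contentEncoding)) {
    throw new Error(`Unsupported content-encoding: ${contentEncoding}`);
  }
  return (await DECODERS[contentEncoding](raw)).toString();
}
```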

View File

@@ -174,11 +174,11 @@ function getDecoder(options: {
logger: typeof logger;
contentType?: string;
}) {
const { api, contentType, input, logger } = options;
const { contentType, input, logger } = options;
if (contentType?.includes("application/vnd.amazon.eventstream")) {
return getAwsEventStreamDecoder({ input, logger });
} else if (api === "google-ai") {
return StreamArray.withParser();
} else if (contentType?.includes("application/json")) {
throw new Error("JSON streaming not supported, request SSE instead");
} else {
// Passthrough stream, but ensures split chunks across multi-byte characters
// are handled correctly.
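
With Gemini responses now delivered as SSE (via alt=sse), the StreamArray JSON-array branch is gone and plain application/json streams are rejected outright; everything else falls through to the UTF-8-safe passthrough the trailing comment describes. A minimal sketch of that kind of passthrough, not the project's actual class:

```typescript
import { Transform, TransformCallback } from "stream";

// Illustrative only: a passthrough that never splits a multi-byte UTF-8
// character across chunk boundaries. TextDecoder with { stream: true } holds
// trailing partial bytes until the next chunk arrives.
class Utf8SafePassthrough extends Transform {
  private decoder = new TextDecoder("utf8");

  _transform(chunk: Buffer, _enc: BufferEncoding, cb: TransformCallback) {
    const text = this.decoder.decode(chunk, { stream: true });
    if (text) this.push(text);
    cb();
  }

  _flush(cb: TransformCallback) {
    const rest = this.decoder.decode(); // emit any buffered trailing bytes
    if (rest) this.push(rest);
    cb();
  }
}
```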

View File

@@ -135,15 +135,15 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
}
const { stack, message } = error;
const info = { stack, lastMiddleware, key: req.key?.hash };
const details = { stack, message, lastMiddleware, key: req.key?.hash };
const description = `Error while executing proxy response middleware: ${lastMiddleware} (${message})`;
if (res.headersSent) {
req.log.error(info, description);
req.log.error(details, description);
if (!res.writableEnded) res.end();
return;
} else {
req.log.error(info, description);
req.log.error(details, description);
res
.status(500)
.json({ error: "Internal server error", proxy_note: description });
@@ -174,57 +174,52 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
) => {
const statusCode = proxyRes.statusCode || 500;
const statusMessage = proxyRes.statusMessage || "Internal Server Error";
let errorPayload: ProxiedErrorPayload;
const service = req.key!.service;
// Not an error, continue to next response handler
if (statusCode < 400) return;
// Parse the error response body
let errorPayload: ProxiedErrorPayload;
try {
assertJsonResponse(body);
errorPayload = body;
} catch (parseError) {
// Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
const hash = req.key?.hash;
req.log.warn({ statusCode, statusMessage, key: hash }, parseError.message);
const strBody = String(body).slice(0, 128);
req.log.error({ statusCode, strBody }, "Error body is not JSON");
const errorObject = {
const details = {
error: parseError.message,
status: statusCode,
statusMessage,
proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service.`,
proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service. Response body: ${strBody}`,
};
sendProxyError(req, res, statusCode, statusMessage, errorObject);
sendProxyError(req, res, statusCode, statusMessage, details);
throw new HttpError(statusCode, parseError.message);
}
const service = req.key!.service;
// Extract the error type from the response body depending on the service
if (service === "gcp") {
if (Array.isArray(errorPayload)) {
errorPayload = errorPayload[0];
}
}
const errorType =
errorPayload.error?.code ||
errorPayload.error?.type ||
getAwsErrorType(proxyRes.headers["x-amzn-errortype"]);
req.log.warn(
{ statusCode, type: errorType, errorPayload, key: req.key?.hash },
`Received error response from upstream. (${proxyRes.statusMessage})`
{ statusCode, statusMessage, errorType, errorPayload, key: req.key?.hash },
`API returned an error.`
);
// TODO: split upstream error handling into separate modules for each service,
// this is out of control.
// Try to convert response body to a ProxiedErrorPayload with message/type
if (service === "aws") {
// Try to standardize the error format for AWS
errorPayload.error = { message: errorPayload.message, type: errorType };
delete errorPayload.message;
} else if (service === "gcp") {
// Try to standardize the error format for GCP
if (errorPayload.error?.code) {
// GCP Error
errorPayload.error = {
message: errorPayload.error.message,
type: errorPayload.error.status || errorPayload.error.code,
@@ -232,6 +227,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
}
}
// Figure out what to do with the error
// TODO: separate error handling for each service
if (statusCode === 400) {
switch (service) {
case "openai":
@@ -271,10 +268,6 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
errorType === "permission_error" &&
errorPayload.error?.message?.toLowerCase().includes("multimodal")
) {
req.log.warn(
{ key: req.key?.hash },
"This Anthropic key does not support multimodal prompts."
);
keyPool.update(req.key!, { allowsMultimodality: false });
await reenqueueRequest(req);
throw new RetryableError(
@@ -342,7 +335,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
// Most likely model not found
switch (service) {
case "openai":
if (errorPayload.error?.code === "model_not_found") {
if (errorType === "model_not_found") {
const requestedModel = req.body.model;
const modelFamily = getOpenAIModelFamily(requestedModel);
errorPayload.proxy_note = `The key assigned to your prompt does not support the requested model (${requestedModel}, family: ${modelFamily}).`;
@@ -353,22 +346,12 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
}
break;
case "anthropic":
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
break;
case "google-ai":
errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
break;
case "mistral-ai":
errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
break;
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
break;
case "gcp":
errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`;
break;
case "azure":
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
errorPayload.proxy_note = `The key assigned to your prompt does not support the requested model.`;
break;
default:
assertNever(service);
@@ -377,7 +360,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
switch (service) {
case "aws":
if (
errorPayload.error?.type === "ServiceUnavailableException" &&
errorType === "ServiceUnavailableException" &&
errorPayload.error?.message?.match(/too many connections/i)
) {
errorPayload.proxy_note = `The requested AWS Bedrock model is overloaded. Try again in a few minutes, or try another model.`;
@@ -391,7 +374,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
errorPayload.proxy_note = `Unrecognized error from upstream service.`;
}
// Some OAI errors contain the organization ID, which we don't want to reveal.
// Redact the OpenAI org id from the error message
if (errorPayload.error?.message) {
errorPayload.error.message = errorPayload.error.message.replace(
/org-.{24}/gm,
@@ -399,9 +382,10 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
);
}
// Send the error to the client
sendProxyError(req, res, statusCode, statusMessage, errorPayload);
// This is bubbled up to onProxyRes's handler for logging but will not trigger
// a write to the response as `sendProxyError` has just done that.
// Re-throw the error to bubble up to onProxyRes's handler for logging
throw new HttpError(statusCode, errorPayload.error?.message);
};
@@ -534,56 +518,6 @@ async function handleOpenAIRateLimitError(
// Per-minute request or token rate limit is exceeded, which we can retry
await reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
// WIP/nonfunctional
// case "tokens_usage_based":
// // Weird new rate limit type that seems limited to preview models.
// // Distinct from `tokens` type. Can be per-minute or per-day.
//
// // I've seen reports of this error for 500k tokens/day and 10k tokens/min.
// // 10k tokens per minute is problematic, because this is much less than
// // GPT4-Turbo's max context size for a single prompt and is effectively a
// // cap on the max context size for just that key+model, which the app is
// // not able to deal with.
//
// // Similarly if there is a 500k tokens per day limit and 450k tokens have
// // been used today, the max context for that key becomes 50k tokens until
// // the next day and becomes progressively smaller as more tokens are used.
//
// // To work around these keys we will first retry the request a few times.
// // After that we will reject the request, and if it's a per-day limit we
// // will also disable the key.
//
// // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per day: Limit 500000, Used 460000, Requested 50000"
// // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per min: Limit 10000, Requested 40000"
//
// const regex =
// /Rate limit reached for .+ in organization .+ on \w+ per (day|min): Limit (\d+)(?:, Used (\d+))?, Requested (\d+)/;
// const [, period, limit, used, requested] =
// errorPayload.error?.message?.match(regex) || [];
//
// req.log.warn(
// { key: req.key?.hash, period, limit, used, requested },
// "Received `tokens_usage_based` rate limit error from OpenAI."
// );
//
// if (!period || !limit || !requested) {
// errorPayload.proxy_note = `Unrecognized rate limit error from OpenAI. (${errorPayload.error?.message})`;
// break;
// }
//
// if (req.retryCount < 2) {
// await reenqueueRequest(req);
// throw new RetryableError("Rate-limited request re-enqueued.");
// }
//
// if (period === "min") {
// errorPayload.proxy_note = `Assigned key can't be used for prompts longer than ${limit} tokens, and no other keys are available right now. Reduce the length of your prompt or try again in a few minutes.`;
// } else {
// errorPayload.proxy_note = `Assigned key has reached its per-day request limit for this model. Try another model.`;
// }
//
// keyPool.markRateLimited(req.key!);
// break;
default:
errorPayload.proxy_note = `This is likely a temporary error with the API. Try again in a few seconds.`;
break;
@@ -734,7 +668,6 @@ const trackKeyRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
keyPool.updateRateLimits(req.key!, proxyRes.headers);
};
const omittedHeaders = new Set<string>([
// Omit content-encoding because we will always decode the response body
"content-encoding",

View File

@@ -20,7 +20,6 @@ type SSEStreamAdapterOptions = TransformOptions & {
*/
export class SSEStreamAdapter extends Transform {
private readonly isAwsStream;
private readonly isGoogleStream;
private api: APIFormat;
private partialMessage = "";
private textDecoder = new TextDecoder("utf8");
@@ -30,7 +29,6 @@ export class SSEStreamAdapter extends Transform {
super({ ...options, objectMode: true });
this.isAwsStream =
options?.contentType === "application/vnd.amazon.eventstream";
this.isGoogleStream = options?.api === "google-ai";
this.api = options.api;
this.log = options.logger.child({ module: "sse-stream-adapter" });
}
@@ -144,10 +142,6 @@ export class SSEStreamAdapter extends Transform {
// `data` is a Message object
const message = this.processAwsMessage(data);
if (message) this.push(message + "\n\n");
} else if (this.isGoogleStream) {
// `data` is an element from the Google AI JSON stream
const message = this.processGoogleObject(data);
if (message) this.push(message + "\n\n");
} else {
// `data` is a string, but possibly only a partial message
const fullMessages = (this.partialMessage + data).split(
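
With the Google-specific JSON-stream branch removed, Gemini chunks flow through the adapter's ordinary SSE path, which buffers partial messages until a blank-line delimiter completes them. A small self-contained sketch of that buffering; the class name and delimiter regex are illustrative, not the adapter's code:

```typescript
// Illustrative sketch of partial-message buffering for an SSE text stream.
// Messages are separated by a blank line; anything after the last complete
// delimiter is held until the next chunk arrives.
class SseMessageBuffer {
  private partial = "";

  feed(chunk: string): string[] {
    const pieces = (this.partial + chunk).split(/\r?\n\r?\n/);
    // The final piece is either "" (the chunk ended on a delimiter) or an
    // incomplete message to carry over to the next call.
    this.partial = pieces.pop() ?? "";
    return pieces.filter((msg) => msg.length > 0);
  }
}

// Example: feeding 'data: {"a":1}\n\ndata: {"b"' yields ['data: {"a":1}'] and
// keeps 'data: {"b"' buffered until the rest arrives.
```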