From 22d7f966c686fb4202dbe9a29afad5e74028e3d0 Mon Sep 17 00:00:00 2001
From: nai-degen
Date: Sun, 29 Sep 2024 12:44:18 -0500
Subject: [PATCH] fixes for gemini api streaming

---
 src/proxy/google-ai.ts                        |   6 +-
 .../request/mutators/add-google-ai-key.ts     |  12 +-
 .../request/proxy-middleware-factory.ts       |  23 ++--
 .../response/handle-blocking-response.ts      |  55 ++++----
 .../response/handle-streamed-response.ts      |   6 +-
 src/proxy/middleware/response/index.ts        | 117 ++++----------
 .../response/streaming/sse-stream-adapter.ts  |   6 -
 7 files changed, 79 insertions(+), 146 deletions(-)

diff --git a/src/proxy/google-ai.ts b/src/proxy/google-ai.ts
index 2a61a9b..9ccc1d2 100644
--- a/src/proxy/google-ai.ts
+++ b/src/proxy/google-ai.ts
@@ -149,7 +149,7 @@ function setStreamFlag(req: Request) {
 }
 
 /**
- * Replaces requests for non-Google AI models with gemini-pro-1.5-latest.
+ * Replaces requests for non-Google AI models with gemini-1.5-pro-latest.
  * Also strips models/ from the beginning of the model IDs.
  **/
 function maybeReassignModel(req: Request) {
@@ -169,8 +169,8 @@ function maybeReassignModel(req: Request) {
     return;
   }
 
-  req.log.info({ requested }, "Reassigning model to gemini-pro-1.5-latest");
-  req.body.model = "gemini-pro-1.5-latest";
+  req.log.info({ requested }, "Reassigning model to gemini-1.5-pro-latest");
+  req.body.model = "gemini-1.5-pro-latest";
 }
 
 export const googleAI = googleAIRouter;
diff --git a/src/proxy/middleware/request/mutators/add-google-ai-key.ts b/src/proxy/middleware/request/mutators/add-google-ai-key.ts
index 4643a79..15a25a5 100644
--- a/src/proxy/middleware/request/mutators/add-google-ai-key.ts
+++ b/src/proxy/middleware/request/mutators/add-google-ai-key.ts
@@ -1,17 +1,17 @@
 import { keyPool } from "../../../../shared/key-management";
-import { ProxyReqMutator} from "../index";
+import { ProxyReqMutator } from "../index";
 
 export const addGoogleAIKey: ProxyReqMutator = (manager) => {
   const req = manager.request;
   const inboundValid =
     req.inboundApi === "openai" || req.inboundApi === "google-ai";
   const outboundValid = req.outboundApi === "google-ai";
-
+
   const serviceValid = req.service === "google-ai";
   if (!inboundValid || !outboundValid || !serviceValid) {
     throw new Error("addGoogleAIKey called on invalid request");
   }
-
+
   const model = req.body.model;
   const key = keyPool.get(model, "google-ai");
   manager.setKey(key);
@@ -20,7 +20,7 @@ export const addGoogleAIKey: ProxyReqMutator = (manager) => {
   req.log.info(
     { key: key.hash, model, stream: req.isStreaming },
     "Assigned Google AI API key to request"
   );
-
+
   // https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
   // https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
   const payload = { ...req.body, stream: undefined, model: undefined };
@@ -33,8 +33,8 @@ export const addGoogleAIKey: ProxyReqMutator = (manager) => {
     protocol: "https:",
     hostname: "generativelanguage.googleapis.com",
     path: `/v1beta/models/${model}:${
-      req.isStreaming ? "streamGenerateContent" : "generateContent"
-    }?key=${key.key}`,
+      req.isStreaming ? "streamGenerateContent?alt=sse&" : "generateContent?"
+    }key=${key.key}`,
     headers: {
       ["host"]: `generativelanguage.googleapis.com`,
       ["content-type"]: "application/json",
diff --git a/src/proxy/middleware/request/proxy-middleware-factory.ts b/src/proxy/middleware/request/proxy-middleware-factory.ts
index 5e50c6f..07bce4f 100644
--- a/src/proxy/middleware/request/proxy-middleware-factory.ts
+++ b/src/proxy/middleware/request/proxy-middleware-factory.ts
@@ -100,23 +100,30 @@ export function createQueuedProxyMiddleware({
 type ProxiedResponse = http.IncomingMessage & Response & any;
 function pinoLoggerPlugin(proxyServer: ProxyServer) {
   proxyServer.on("error", (err, req, res, target) => {
-    const originalUrl = req.originalUrl;
-    const targetUrl = target?.toString();
     req.log.error(
-      { originalUrl, targetUrl, err },
+      { originalUrl: req.originalUrl, targetUrl: String(target), err },
       "Error occurred while proxying request to target"
     );
   });
   proxyServer.on("proxyReq", (proxyReq, req) => {
-    const from = req.originalUrl;
+    const { protocol, host, path } = proxyReq;
     req.log.info(
-      { from, to: `${proxyReq.protocol}//${proxyReq.host}${proxyReq.path}` },
+      {
+        from: req.originalUrl,
+        to: `${protocol}//${host}${path}`,
+      },
       "Sending request to upstream API..."
     );
   });
   proxyServer.on("proxyRes", (proxyRes: ProxiedResponse, req, _res) => {
-    const target = `${proxyRes.req.protocol}//${proxyRes.req.host}${proxyRes.req.path}`;
-    const statusCode = proxyRes.statusCode;
-    req.log.info({ target, statusCode }, "Got response from upstream API.");
+    const { protocol, host, path } = proxyRes.req;
+    req.log.info(
+      {
+        target: `${protocol}//${host}${path}`,
+        status: proxyRes.statusCode,
+        contentType: proxyRes.headers["content-type"],
+      },
+      "Got response from upstream API."
+    );
   });
 }
diff --git a/src/proxy/middleware/response/handle-blocking-response.ts b/src/proxy/middleware/response/handle-blocking-response.ts
index b1b420d..6253a49 100644
--- a/src/proxy/middleware/response/handle-blocking-response.ts
+++ b/src/proxy/middleware/response/handle-blocking-response.ts
@@ -1,3 +1,4 @@
+import { Request, Response } from "express";
 import util from "util";
 import zlib from "zlib";
 import { sendProxyError } from "../common";
@@ -7,13 +8,13 @@ const DECODER_MAP = {
   gzip: util.promisify(zlib.gunzip),
   deflate: util.promisify(zlib.inflate),
   br: util.promisify(zlib.brotliDecompress),
+  text: (data: Buffer) => data,
 };
 
+type SupportedContentEncoding = keyof typeof DECODER_MAP;
 const isSupportedContentEncoding = (
-  contentEncoding: string
-): contentEncoding is keyof typeof DECODER_MAP => {
-  return contentEncoding in DECODER_MAP;
-};
+  encoding: string
+): encoding is SupportedContentEncoding => encoding in DECODER_MAP;
 
 /**
  * Handles the response from the upstream service and decodes the body if
@@ -35,41 +36,39 @@ export const handleBlockingResponse: RawResponseBodyHandler = async (
     throw err;
   }
 
-  return new Promise((resolve, reject) => {
+  return new Promise((resolve, reject) => {
     let chunks: Buffer[] = [];
     proxyRes.on("data", (chunk) => chunks.push(chunk));
     proxyRes.on("end", async () => {
-      let body = Buffer.concat(chunks);
+      let body: string | Buffer = Buffer.concat(chunks);
+      const rejectWithMessage = function (msg: string, err: Error) {
+        const error = `${msg} (${err.message})`;
+        req.log.warn({ stack: err.stack }, error);
+        sendProxyError(req, res, 500, "Internal Server Error", { error });
+        return reject(error);
+      };
 
-      const contentEncoding = proxyRes.headers["content-encoding"];
-      if (contentEncoding) {
-        if (isSupportedContentEncoding(contentEncoding)) {
-          const decoder = DECODER_MAP[contentEncoding];
-          // @ts-ignore - started failing after upgrading TypeScript, don't care
-          // as it was never a problem.
-          body = await decoder(body);
-        } else {
-          const error = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
-          req.log.warn({ contentEncoding, key: req.key?.hash }, error);
-          sendProxyError(req, res, 500, "Internal Server Error", {
-            error,
-            contentEncoding,
-          });
-          return reject(error);
+      const contentEncoding = proxyRes.headers["content-encoding"] ?? "text";
+      if (isSupportedContentEncoding(contentEncoding)) {
+        try {
+          body = (await DECODER_MAP[contentEncoding](body)).toString();
+        } catch (e) {
+          return rejectWithMessage(`Could not decode response body`, e);
         }
+      } else {
+        return rejectWithMessage(
+          "API responded with unsupported content encoding",
+          new Error(`Unsupported content-encoding: ${contentEncoding}`)
+        );
       }
 
       try {
         if (proxyRes.headers["content-type"]?.includes("application/json")) {
-          const json = JSON.parse(body.toString());
-          return resolve(json);
+          return resolve(JSON.parse(body));
         }
-        return resolve(body.toString());
+        return resolve(body);
       } catch (e) {
-        const msg = `Proxy received response with invalid JSON: ${e.message}`;
-        req.log.warn({ error: e.stack, key: req.key?.hash }, msg);
-        sendProxyError(req, res, 500, "Internal Server Error", { error: msg });
-        return reject(msg);
+        return rejectWithMessage("API responded with invalid JSON", e);
       }
     });
   });
diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts
index a15eb51..5394b98 100644
--- a/src/proxy/middleware/response/handle-streamed-response.ts
+++ b/src/proxy/middleware/response/handle-streamed-response.ts
@@ -174,11 +174,11 @@ function getDecoder(options: {
   logger: typeof logger;
   contentType?: string;
 }) {
-  const { api, contentType, input, logger } = options;
+  const { contentType, input, logger } = options;
   if (contentType?.includes("application/vnd.amazon.eventstream")) {
     return getAwsEventStreamDecoder({ input, logger });
-  } else if (api === "google-ai") {
-    return StreamArray.withParser();
+  } else if (contentType?.includes("application/json")) {
+    throw new Error("JSON streaming not supported, request SSE instead");
   } else {
     // Passthrough stream, but ensures split chunks across multi-byte characters
     // are handled correctly.
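
NOTE: the request and response changes above work together. Appending
`alt=sse` to the `streamGenerateContent` path makes the Gemini API emit
standard server-sent events instead of a single JSON array of response
objects, which is why `getDecoder` can drop the `StreamArray` JSON parser
and rely on plain SSE handling. A minimal sketch of the two response shapes
(illustrative only; `GEMINI_KEY` is a placeholder, not code from this patch):

  const base = "https://generativelanguage.googleapis.com/v1beta/models";

  // Without alt=sse, the streaming endpoint returns one large JSON array
  // that must be parsed incrementally as it arrives: [{...},{...},...]
  const jsonArrayUrl =
    `${base}/gemini-1.5-pro-latest:streamGenerateContent?key=${GEMINI_KEY}`;

  // With alt=sse, each chunk is a self-delimiting SSE message:
  //   data: {"candidates":[...]}
  const sseUrl =
    `${base}/gemini-1.5-pro-latest:streamGenerateContent?alt=sse&key=${GEMINI_KEY}`;
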
diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts
index ed6a902..a6e227c 100644
--- a/src/proxy/middleware/response/index.ts
+++ b/src/proxy/middleware/response/index.ts
@@ -135,15 +135,15 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
       }
 
       const { stack, message } = error;
-      const info = { stack, lastMiddleware, key: req.key?.hash };
+      const details = { stack, message, lastMiddleware, key: req.key?.hash };
       const description = `Error while executing proxy response middleware: ${lastMiddleware} (${message})`;
       if (res.headersSent) {
-        req.log.error(info, description);
+        req.log.error(details, description);
         if (!res.writableEnded) res.end();
         return;
       } else {
-        req.log.error(info, description);
+        req.log.error(details, description);
         res
           .status(500)
           .json({ error: "Internal server error", proxy_note: description });
       }
@@ -174,57 +174,52 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
 ) => {
   const statusCode = proxyRes.statusCode || 500;
   const statusMessage = proxyRes.statusMessage || "Internal Server Error";
-  let errorPayload: ProxiedErrorPayload;
-
+  const service = req.key!.service;
+  // Not an error, continue to next response handler
   if (statusCode < 400) return;
 
+  // Parse the error response body
+  let errorPayload: ProxiedErrorPayload;
   try {
     assertJsonResponse(body);
     errorPayload = body;
   } catch (parseError) {
-    // Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
-    const hash = req.key?.hash;
-    req.log.warn({ statusCode, statusMessage, key: hash }, parseError.message);
+    const strBody = String(body).slice(0, 128);
+    req.log.error({ statusCode, strBody }, "Error body is not JSON");
 
-    const errorObject = {
+    const details = {
       error: parseError.message,
       status: statusCode,
       statusMessage,
-      proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service.`,
+      proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service. Response body: ${strBody}`,
     };
 
-    sendProxyError(req, res, statusCode, statusMessage, errorObject);
+    sendProxyError(req, res, statusCode, statusMessage, details);
     throw new HttpError(statusCode, parseError.message);
   }
 
-  const service = req.key!.service;
+  // Extract the error type from the response body depending on the service
   if (service === "gcp") {
     if (Array.isArray(errorPayload)) {
      errorPayload = errorPayload[0];
     }
   }
-
   const errorType =
     errorPayload.error?.code ||
     errorPayload.error?.type ||
     getAwsErrorType(proxyRes.headers["x-amzn-errortype"]);
 
   req.log.warn(
-    { statusCode, type: errorType, errorPayload, key: req.key?.hash },
-    `Received error response from upstream. (${proxyRes.statusMessage})`
+    { statusCode, statusMessage, errorType, errorPayload, key: req.key?.hash },
+    `API returned an error.`
   );
 
-  // TODO: split upstream error handling into separate modules for each service,
-  // this is out of control.
-
+  // Try to convert response body to a ProxiedErrorPayload with message/type
  if (service === "aws") {
-    // Try to standardize the error format for AWS
     errorPayload.error = { message: errorPayload.message, type: errorType };
     delete errorPayload.message;
   } else if (service === "gcp") {
-    // Try to standardize the error format for GCP
     if (errorPayload.error?.code) {
-      // GCP Error
       errorPayload.error = {
         message: errorPayload.error.message,
         type: errorPayload.error.status || errorPayload.error.code,
@@ -232,6 +227,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     }
   }
 
+  // Figure out what to do with the error
+  // TODO: separate error handling for each service
   if (statusCode === 400) {
     switch (service) {
       case "openai":
@@ -271,10 +268,6 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
         errorType === "permission_error" &&
         errorPayload.error?.message?.toLowerCase().includes("multimodal")
       ) {
-        req.log.warn(
-          { key: req.key?.hash },
-          "This Anthropic key does not support multimodal prompts."
-        );
         keyPool.update(req.key!, { allowsMultimodality: false });
         await reenqueueRequest(req);
         throw new RetryableError(
@@ -342,7 +335,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     // Most likely model not found
     switch (service) {
       case "openai":
-        if (errorPayload.error?.code === "model_not_found") {
+        if (errorType === "model_not_found") {
           const requestedModel = req.body.model;
           const modelFamily = getOpenAIModelFamily(requestedModel);
           errorPayload.proxy_note = `The key assigned to your prompt does not support the requested model (${requestedModel}, family: ${modelFamily}).`;
@@ -353,22 +346,12 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
         }
         break;
       case "anthropic":
-        errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
-        break;
       case "google-ai":
-        errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
-        break;
       case "mistral-ai":
-        errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
-        break;
       case "aws":
-        errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
-        break;
       case "gcp":
-        errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`;
-        break;
       case "azure":
-        errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
+        errorPayload.proxy_note = `The key assigned to your prompt does not support the requested model.`;
         break;
       default:
         assertNever(service);
@@ -377,7 +360,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     switch (service) {
       case "aws":
         if (
-          errorPayload.error?.type === "ServiceUnavailableException" &&
+          errorType === "ServiceUnavailableException" &&
           errorPayload.error?.message?.match(/too many connections/i)
         ) {
           errorPayload.proxy_note = `The requested AWS Bedrock model is overloaded. Try again in a few minutes, or try another model.`;
@@ -391,7 +374,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     errorPayload.proxy_note = `Unrecognized error from upstream service.`;
   }
 
-  // Some OAI errors contain the organization ID, which we don't want to reveal.
+  // Redact the OpenAI org id from the error message
   if (errorPayload.error?.message) {
     errorPayload.error.message = errorPayload.error.message.replace(
       /org-.{24}/gm,
@@ -399,9 +382,10 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     );
   }
 
+  // Send the error to the client
   sendProxyError(req, res, statusCode, statusMessage, errorPayload);
-  // This is bubbled up to onProxyRes's handler for logging but will not trigger
-  // a write to the response as `sendProxyError` has just done that.
+
+  // Re-throw the error to bubble up to onProxyRes's handler for logging
   throw new HttpError(statusCode, errorPayload.error?.message);
 };
 
@@ -534,56 +518,6 @@ async function handleOpenAIRateLimitError(
       // Per-minute request or token rate limit is exceeded, which we can retry
       await reenqueueRequest(req);
       throw new RetryableError("Rate-limited request re-enqueued.");
-    // WIP/nonfunctional
-    // case "tokens_usage_based":
-    //   // Weird new rate limit type that seems limited to preview models.
-    //   // Distinct from `tokens` type. Can be per-minute or per-day.
-    //
-    //   // I've seen reports of this error for 500k tokens/day and 10k tokens/min.
-    //   // 10k tokens per minute is problematic, because this is much less than
-    //   // GPT4-Turbo's max context size for a single prompt and is effectively a
-    //   // cap on the max context size for just that key+model, which the app is
-    //   // not able to deal with.
-    //
-    //   // Similarly if there is a 500k tokens per day limit and 450k tokens have
-    //   // been used today, the max context for that key becomes 50k tokens until
-    //   // the next day and becomes progressively smaller as more tokens are used.
-    //
-    //   // To work around these keys we will first retry the request a few times.
-    //   // After that we will reject the request, and if it's a per-day limit we
-    //   // will also disable the key.
-    //
-    //   // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per day: Limit 500000, Used 460000, Requested 50000"
-    //   // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per min: Limit 10000, Requested 40000"
-    //
-    //   const regex =
-    //     /Rate limit reached for .+ in organization .+ on \w+ per (day|min): Limit (\d+)(?:, Used (\d+))?, Requested (\d+)/;
-    //   const [, period, limit, used, requested] =
-    //     errorPayload.error?.message?.match(regex) || [];
-    //
-    //   req.log.warn(
-    //     { key: req.key?.hash, period, limit, used, requested },
-    //     "Received `tokens_usage_based` rate limit error from OpenAI."
-    //   );
-    //
-    //   if (!period || !limit || !requested) {
-    //     errorPayload.proxy_note = `Unrecognized rate limit error from OpenAI. (${errorPayload.error?.message})`;
-    //     break;
-    //   }
-    //
-    //   if (req.retryCount < 2) {
-    //     await reenqueueRequest(req);
-    //     throw new RetryableError("Rate-limited request re-enqueued.");
-    //   }
-    //
-    //   if (period === "min") {
-    //     errorPayload.proxy_note = `Assigned key can't be used for prompts longer than ${limit} tokens, and no other keys are available right now. Reduce the length of your prompt or try again in a few minutes.`;
-    //   } else {
-    //     errorPayload.proxy_note = `Assigned key has reached its per-day request limit for this model. Try another model.`;
-    //   }
-    //
-    //   keyPool.markRateLimited(req.key!);
-    //   break;
     default:
       errorPayload.proxy_note = `This is likely a temporary error with the API. Try again in a few seconds.`;
       break;
@@ -734,7 +668,6 @@ const trackKeyRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
   keyPool.updateRateLimits(req.key!, proxyRes.headers);
 };
 
-
 const omittedHeaders = new Set([
   // Omit content-encoding because we will always decode the response body
   "content-encoding",
diff --git a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts
index f74bb9c..ce27355 100644
--- a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts
+++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts
@@ -20,7 +20,6 @@ type SSEStreamAdapterOptions = TransformOptions & {
  */
 export class SSEStreamAdapter extends Transform {
   private readonly isAwsStream;
-  private readonly isGoogleStream;
   private api: APIFormat;
   private partialMessage = "";
   private textDecoder = new TextDecoder("utf8");
@@ -30,7 +29,6 @@ export class SSEStreamAdapter extends Transform {
     super({ ...options, objectMode: true });
     this.isAwsStream =
       options?.contentType === "application/vnd.amazon.eventstream";
-    this.isGoogleStream = options?.api === "google-ai";
     this.api = options.api;
     this.log = options.logger.child({ module: "sse-stream-adapter" });
   }
@@ -144,10 +142,6 @@
       // `data` is a Message object
       const message = this.processAwsMessage(data);
       if (message) this.push(message + "\n\n");
-    } else if (this.isGoogleStream) {
-      // `data` is an element from the Google AI JSON stream
-      const message = this.processGoogleObject(data);
-      if (message) this.push(message + "\n\n");
     } else {
      // `data` is a string, but possibly only a partial message
      const fullMessages = (this.partialMessage + data).split(
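
NOTE: alongside the streaming fix, the blocking (non-streaming) path in
handle-blocking-response.ts gains a uniform decode step: registering an
identity `text` entry in DECODER_MAP and defaulting a missing
content-encoding header to "text" lets encoded and plain bodies flow through
a single lookup, with no special no-encoding branch. A condensed,
self-contained sketch of that pattern (`decodeBody` is an illustrative
wrapper, not a function in this patch):

  import util from "util";
  import zlib from "zlib";

  const DECODER_MAP = {
    gzip: util.promisify(zlib.gunzip),
    deflate: util.promisify(zlib.inflate),
    br: util.promisify(zlib.brotliDecompress),
    text: (data: Buffer) => data, // identity decoder for unencoded bodies
  };

  async function decodeBody(body: Buffer, contentEncoding?: string) {
    // A missing header means the body is plain; route it through the same map.
    const encoding = contentEncoding ?? "text";
    if (!(encoding in DECODER_MAP)) {
      throw new Error(`Unsupported content-encoding: ${encoding}`);
    }
    const decoder = DECODER_MAP[encoding as keyof typeof DECODER_MAP];
    // `await` is a no-op for the synchronous `text` decoder, so both the
    // promisified zlib decoders and the identity entry can be called uniformly.
    return (await decoder(body)).toString();
  }
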