diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts index 32e4819..ab4743e 100644 --- a/src/proxy/anthropic.ts +++ b/src/proxy/anthropic.ts @@ -7,12 +7,9 @@ import { ipLimiter } from "./rate-limit"; import { handleProxyError } from "./middleware/common"; import { addKey, - applyQuotaLimits, addAnthropicPreamble, - blockZoomerOrigins, createPreprocessorMiddleware, finalizeBody, - stripHeaders, createOnProxyReqHandler, } from "./middleware/request"; import { @@ -137,14 +134,7 @@ const anthropicProxy = createQueueMiddleware({ logger, on: { proxyReq: createOnProxyReqHandler({ - pipeline: [ - applyQuotaLimits, - addKey, - addAnthropicPreamble, - blockZoomerOrigins, - stripHeaders, - finalizeBody, - ], + pipeline: [addKey, addAnthropicPreamble, finalizeBody], }), proxyRes: createOnProxyResHandler([anthropicResponseHandler]), error: handleProxyError, diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts index e30d02f..0b1d4d1 100644 --- a/src/proxy/aws.ts +++ b/src/proxy/aws.ts @@ -7,19 +7,18 @@ import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; import { handleProxyError } from "./middleware/common"; import { - applyQuotaLimits, createPreprocessorMiddleware, - stripHeaders, signAwsRequest, finalizeSignedRequest, createOnProxyReqHandler, - blockZoomerOrigins, } from "./middleware/request"; import { ProxyResHandlerWithBody, createOnProxyResHandler, } from "./middleware/response"; +const LATEST_AWS_V2_MINOR_VERSION = "1"; + let modelsCache: any = null; let modelsCacheTime = 0; @@ -133,14 +132,7 @@ const awsProxy = createQueueMiddleware({ selfHandleResponse: true, logger, on: { - proxyReq: createOnProxyReqHandler({ - pipeline: [ - applyQuotaLimits, - blockZoomerOrigins, - stripHeaders, - finalizeSignedRequest, - ], - }), + proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), proxyRes: createOnProxyResHandler([awsResponseHandler]), error: handleProxyError, }, @@ -178,20 +170,15 @@ awsRouter.post( * - frontends sending OpenAI model names because they expect the proxy to * translate them */ -const LATEST_AWS_V2_MINOR_VERSION = '1'; - function maybeReassignModel(req: Request) { const model = req.body.model; - // If the string already includes "anthropic.claude", return it unmodified + // If client already specified an AWS Claude model ID, use it if (model.includes("anthropic.claude")) { return; } - // Define a regular expression pattern to match the Claude version strings const pattern = /^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?$/i; - - // Execute the pattern on the model string const match = model.match(pattern); // If there's no match, return the latest v2 model @@ -200,34 +187,30 @@ function maybeReassignModel(req: Request) { return; } - // Extract parts of the version string - const [, , instant, v, major, , minor] = match; + const [, , instant, , major, , minor] = match; - // If 'instant' is part of the version, return the fixed instant model string if (instant) { - req.body.model = 'anthropic.claude-instant-v1'; + req.body.model = "anthropic.claude-instant-v1"; return; } - // If the major version is '1', return the fixed v1 model string - if (major === '1') { - req.body.model = 'anthropic.claude-v1'; + // There's only one v1 model + if (major === "1") { + req.body.model = "anthropic.claude-v1"; return; } - // If the major version is '2' - if (major === '2') { - // If the minor version is explicitly '0', return "anthropic.claude-v2" which is claude-2.0 - if (minor === '0') { - req.body.model = 'anthropic.claude-v2'; + // Try to map Anthropic API v2 models to AWS v2 models + if (major === "2") { + if (minor === "0") { + req.body.model = "anthropic.claude-v2"; return; } - // Otherwise, return the v2 model string with the latest minor version req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; return; } - // If none of the above conditions are met, return the latest v2 model by default + // Fallback to latest v2 model req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; return; } diff --git a/src/proxy/azure.ts b/src/proxy/azure.ts index 45b9e95..80daace 100644 --- a/src/proxy/azure.ts +++ b/src/proxy/azure.ts @@ -13,19 +13,15 @@ import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; import { handleProxyError } from "./middleware/common"; import { - applyQuotaLimits, - blockZoomerOrigins, + addAzureKey, createOnProxyReqHandler, createPreprocessorMiddleware, finalizeSignedRequest, - limitCompletions, - stripHeaders, } from "./middleware/request"; import { createOnProxyResHandler, ProxyResHandlerWithBody, } from "./middleware/response"; -import { addAzureKey } from "./middleware/request/add-azure-key"; let modelsCache: any = null; let modelsCacheTime = 0; @@ -109,15 +105,7 @@ const azureOpenAIProxy = createQueueMiddleware({ selfHandleResponse: true, logger, on: { - proxyReq: createOnProxyReqHandler({ - pipeline: [ - applyQuotaLimits, - limitCompletions, - blockZoomerOrigins, - stripHeaders, - finalizeSignedRequest, - ], - }), + proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]), error: handleProxyError, }, diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index 52091c0..4669249 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -4,7 +4,7 @@ import { ZodError } from "zod"; import { generateErrorMessage } from "zod-error"; import { buildFakeSse } from "../../shared/streaming"; import { assertNever } from "../../shared/utils"; -import { QuotaExceededError } from "./request/apply-quota-limits"; +import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits"; const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions"; const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions"; diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts index 8b31ecf..161ad03 100644 --- a/src/proxy/middleware/request/index.ts +++ b/src/proxy/middleware/request/index.ts @@ -2,29 +2,30 @@ import type { Request } from "express"; import type { ClientRequest } from "http"; import type { ProxyReqCallback } from "http-proxy"; -export { createOnProxyReqHandler } from "./rewrite"; +export { createOnProxyReqHandler } from "./onproxyreq-factory"; export { createPreprocessorMiddleware, createEmbeddingsPreprocessorMiddleware, -} from "./preprocess"; +} from "./preprocessor-factory"; // Express middleware (runs before http-proxy-middleware, can be async) -export { applyQuotaLimits } from "./apply-quota-limits"; -export { validateContextSize } from "./validate-context-size"; -export { countPromptTokens } from "./count-prompt-tokens"; -export { languageFilter } from "./language-filter"; -export { setApiFormat } from "./set-api-format"; -export { signAwsRequest } from "./sign-aws-request"; -export { transformOutboundPayload } from "./transform-outbound-payload"; +export { addAzureKey } from "./preprocessors/add-azure-key"; +export { applyQuotaLimits } from "./preprocessors/apply-quota-limits"; +export { validateContextSize } from "./preprocessors/validate-context-size"; +export { countPromptTokens } from "./preprocessors/count-prompt-tokens"; +export { languageFilter } from "./preprocessors/language-filter"; +export { setApiFormat } from "./preprocessors/set-api-format"; +export { signAwsRequest } from "./preprocessors/sign-aws-request"; +export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload"; -// HPM middleware (runs on onProxyReq, cannot be async) -export { addKey, addKeyForEmbeddingsRequest } from "./add-key"; -export { addAnthropicPreamble } from "./add-anthropic-preamble"; -export { blockZoomerOrigins } from "./block-zoomer-origins"; -export { finalizeBody } from "./finalize-body"; -export { finalizeSignedRequest } from "./finalize-signed-request"; -export { limitCompletions } from "./limit-completions"; -export { stripHeaders } from "./strip-headers"; +// http-proxy-middleware callbacks (runs on onProxyReq, cannot be async) +export { addKey, addKeyForEmbeddingsRequest } from "./onproxyreq/add-key"; +export { addAnthropicPreamble } from "./onproxyreq/add-anthropic-preamble"; +export { blockZoomerOrigins } from "./onproxyreq/block-zoomer-origins"; +export { checkModelFamily } from "./onproxyreq/check-model-family"; +export { finalizeBody } from "./onproxyreq/finalize-body"; +export { finalizeSignedRequest } from "./onproxyreq/finalize-signed-request"; +export { stripHeaders } from "./onproxyreq/strip-headers"; /** * Middleware that runs prior to the request being handled by http-proxy- @@ -43,7 +44,7 @@ export { stripHeaders } from "./strip-headers"; export type RequestPreprocessor = (req: Request) => void | Promise; /** - * Middleware that runs immediately before the request is sent to the API in + * Callbacks that run immediately before the request is sent to the API in * response to http-proxy-middleware's `proxyReq` event. * * Async functions cannot be used here as HPM's event emitter is not async and @@ -53,7 +54,7 @@ export type RequestPreprocessor = (req: Request) => void | Promise; * first attempt is rate limited and the request is automatically retried by the * request queue middleware. */ -export type ProxyRequestMiddleware = ProxyReqCallback; +export type HPMRequestCallback = ProxyReqCallback; export const forceModel = (model: string) => (req: Request) => void (req.body.model = model); diff --git a/src/proxy/middleware/request/limit-completions.ts b/src/proxy/middleware/request/limit-completions.ts deleted file mode 100644 index 44f583b..0000000 --- a/src/proxy/middleware/request/limit-completions.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { isTextGenerationRequest } from "../common"; -import { ProxyRequestMiddleware } from "."; - -/** - * Don't allow multiple text completions to be requested to prevent abuse. - * OpenAI-only, Anthropic provides no such parameter. - **/ -export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => { - if (isTextGenerationRequest(req) && req.outboundApi === "openai") { - const originalN = req.body?.n || 1; - req.body.n = 1; - if (originalN !== req.body.n) { - req.log.warn(`Limiting completion choices from ${originalN} to 1`); - } - } -}; diff --git a/src/proxy/middleware/request/onproxyreq-factory.ts b/src/proxy/middleware/request/onproxyreq-factory.ts new file mode 100644 index 0000000..df44d04 --- /dev/null +++ b/src/proxy/middleware/request/onproxyreq-factory.ts @@ -0,0 +1,43 @@ +import { + applyQuotaLimits, + blockZoomerOrigins, + checkModelFamily, + HPMRequestCallback, + stripHeaders, +} from "./index"; + +type ProxyReqHandlerFactoryOptions = { pipeline: HPMRequestCallback[] }; + +/** + * Returns an http-proxy-middleware request handler that runs the given set of + * onProxyReq callback functions in sequence. + * + * These will run each time a request is proxied, including on automatic retries + * by the queue after encountering a rate limit. + */ +export const createOnProxyReqHandler = ({ + pipeline, +}: ProxyReqHandlerFactoryOptions): HPMRequestCallback => { + const callbackPipeline = [ + checkModelFamily, + applyQuotaLimits, + blockZoomerOrigins, + stripHeaders, + ...pipeline, + ]; + return (proxyReq, req, res, options) => { + // The streaming flag must be set before any other onProxyReq handler runs, + // as it may influence the behavior of subsequent handlers. + // Image generation requests can't be streamed. + req.isStreaming = req.body.stream === true || req.body.stream === "true"; + req.body.stream = req.isStreaming; + + try { + for (const fn of callbackPipeline) { + fn(proxyReq, req, res, options); + } + } catch (error) { + proxyReq.destroy(error); + } + }; +}; diff --git a/src/proxy/middleware/request/add-anthropic-preamble.ts b/src/proxy/middleware/request/onproxyreq/add-anthropic-preamble.ts similarity index 78% rename from src/proxy/middleware/request/add-anthropic-preamble.ts rename to src/proxy/middleware/request/onproxyreq/add-anthropic-preamble.ts index cdab4f2..2d63e79 100644 --- a/src/proxy/middleware/request/add-anthropic-preamble.ts +++ b/src/proxy/middleware/request/onproxyreq/add-anthropic-preamble.ts @@ -1,13 +1,13 @@ -import { AnthropicKey, Key } from "../../../shared/key-management"; -import { isTextGenerationRequest } from "../common"; -import { ProxyRequestMiddleware } from "."; +import { AnthropicKey, Key } from "../../../../shared/key-management"; +import { isTextGenerationRequest } from "../../common"; +import { HPMRequestCallback } from "../index"; /** * Some keys require the prompt to start with `\n\nHuman:`. There is no way to * know this without trying to send the request and seeing if it fails. If a * key is marked as requiring a preamble, it will be added here. */ -export const addAnthropicPreamble: ProxyRequestMiddleware = ( +export const addAnthropicPreamble: HPMRequestCallback = ( _proxyReq, req ) => { diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/onproxyreq/add-key.ts similarity index 90% rename from src/proxy/middleware/request/add-key.ts rename to src/proxy/middleware/request/onproxyreq/add-key.ts index 49a7e88..3acae5a 100644 --- a/src/proxy/middleware/request/add-key.ts +++ b/src/proxy/middleware/request/onproxyreq/add-key.ts @@ -1,10 +1,10 @@ -import { Key, OpenAIKey, keyPool } from "../../../shared/key-management"; -import { isEmbeddingsRequest } from "../common"; -import { ProxyRequestMiddleware } from "."; -import { assertNever } from "../../../shared/utils"; +import { Key, OpenAIKey, keyPool } from "../../../../shared/key-management"; +import { isEmbeddingsRequest } from "../../common"; +import { HPMRequestCallback } from "../index"; +import { assertNever } from "../../../../shared/utils"; /** Add a key that can service this request to the request object. */ -export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { +export const addKey: HPMRequestCallback = (proxyReq, req) => { let assignedKey: Key; if (!req.inboundApi || !req.outboundApi) { @@ -97,7 +97,7 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { * Special case for embeddings requests which don't go through the normal * request pipeline. */ -export const addKeyForEmbeddingsRequest: ProxyRequestMiddleware = ( +export const addKeyForEmbeddingsRequest: HPMRequestCallback = ( proxyReq, req ) => { diff --git a/src/proxy/middleware/request/block-zoomer-origins.ts b/src/proxy/middleware/request/onproxyreq/block-zoomer-origins.ts similarity index 88% rename from src/proxy/middleware/request/block-zoomer-origins.ts rename to src/proxy/middleware/request/onproxyreq/block-zoomer-origins.ts index 9efa404..0c8360d 100644 --- a/src/proxy/middleware/request/block-zoomer-origins.ts +++ b/src/proxy/middleware/request/onproxyreq/block-zoomer-origins.ts @@ -1,4 +1,4 @@ -import { ProxyRequestMiddleware } from "."; +import { HPMRequestCallback } from "../index"; const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(","); @@ -13,7 +13,7 @@ class ForbiddenError extends Error { * Blocks requests from Janitor AI users with a fake, scary error message so I * stop getting emails asking for tech support. */ -export const blockZoomerOrigins: ProxyRequestMiddleware = (_proxyReq, req) => { +export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => { const origin = req.headers.origin || req.headers.referer; if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) { // Venus-derivatives send a test prompt to check if the proxy is working. diff --git a/src/proxy/middleware/request/onproxyreq/check-model-family.ts b/src/proxy/middleware/request/onproxyreq/check-model-family.ts new file mode 100644 index 0000000..1460bee --- /dev/null +++ b/src/proxy/middleware/request/onproxyreq/check-model-family.ts @@ -0,0 +1,13 @@ +import { HPMRequestCallback } from "../index"; +import { config } from "../../../../config"; +import { getModelFamilyForRequest } from "../../../../shared/models"; + +/** + * Ensures the selected model family is enabled by the proxy configuration. + **/ +export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => { + const family = getModelFamilyForRequest(req); + if (!config.allowedModelFamilies.includes(family)) { + throw new Error(`Model family ${family} is not permitted on this proxy`); + } +}; diff --git a/src/proxy/middleware/request/finalize-body.ts b/src/proxy/middleware/request/onproxyreq/finalize-body.ts similarity index 83% rename from src/proxy/middleware/request/finalize-body.ts rename to src/proxy/middleware/request/onproxyreq/finalize-body.ts index ac90e96..21d56d1 100644 --- a/src/proxy/middleware/request/finalize-body.ts +++ b/src/proxy/middleware/request/onproxyreq/finalize-body.ts @@ -1,8 +1,8 @@ import { fixRequestBody } from "http-proxy-middleware"; -import type { ProxyRequestMiddleware } from "."; +import type { HPMRequestCallback } from "../index"; /** Finalize the rewritten request body. Must be the last rewriter. */ -export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => { +export const finalizeBody: HPMRequestCallback = (proxyReq, req) => { if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) { // For image generation requests, remove stream flag. if (req.outboundApi === "openai-image") { diff --git a/src/proxy/middleware/request/finalize-signed-request.ts b/src/proxy/middleware/request/onproxyreq/finalize-signed-request.ts similarity index 87% rename from src/proxy/middleware/request/finalize-signed-request.ts rename to src/proxy/middleware/request/onproxyreq/finalize-signed-request.ts index d8c6622..52e8482 100644 --- a/src/proxy/middleware/request/finalize-signed-request.ts +++ b/src/proxy/middleware/request/onproxyreq/finalize-signed-request.ts @@ -1,11 +1,11 @@ -import type { ProxyRequestMiddleware } from "."; +import type { HPMRequestCallback } from "../index"; /** * For AWS/Azure requests, the body is signed earlier in the request pipeline, * before the proxy middleware. This function just assigns the path and headers * to the proxy request. */ -export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => { +export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => { if (!req.signedRequest) { throw new Error("Expected req.signedRequest to be set"); } diff --git a/src/proxy/middleware/request/strip-headers.ts b/src/proxy/middleware/request/onproxyreq/strip-headers.ts similarity index 77% rename from src/proxy/middleware/request/strip-headers.ts rename to src/proxy/middleware/request/onproxyreq/strip-headers.ts index 793aae0..4c39224 100644 --- a/src/proxy/middleware/request/strip-headers.ts +++ b/src/proxy/middleware/request/onproxyreq/strip-headers.ts @@ -1,10 +1,10 @@ -import { ProxyRequestMiddleware } from "."; +import { HPMRequestCallback } from "../index"; /** * Removes origin and referer headers before sending the request to the API for * privacy reasons. **/ -export const stripHeaders: ProxyRequestMiddleware = (proxyReq) => { +export const stripHeaders: HPMRequestCallback = (proxyReq) => { proxyReq.setHeader("origin", ""); proxyReq.setHeader("referer", ""); diff --git a/src/proxy/middleware/request/preprocess.ts b/src/proxy/middleware/request/preprocessor-factory.ts similarity index 89% rename from src/proxy/middleware/request/preprocess.ts rename to src/proxy/middleware/request/preprocessor-factory.ts index be515a8..b9caddb 100644 --- a/src/proxy/middleware/request/preprocess.ts +++ b/src/proxy/middleware/request/preprocessor-factory.ts @@ -29,6 +29,14 @@ type RequestPreprocessorOptions = { /** * Returns a middleware function that processes the request body into the given * API format, and then sequentially runs the given additional preprocessors. + * + * These run first in the request lifecycle, a single time per request before it + * is added to the request queue. They aren't run again if the request is + * re-attempted after a rate limit. + * + * To run a preprocessor on every re-attempt, pass it to createQueueMiddleware. + * It will run after these preprocessors, but before the request is sent to + * http-proxy-middleware. */ export const createPreprocessorMiddleware = ( apiFormat: Parameters[0], diff --git a/src/proxy/middleware/request/add-azure-key.ts b/src/proxy/middleware/request/preprocessors/add-azure-key.ts similarity index 92% rename from src/proxy/middleware/request/add-azure-key.ts rename to src/proxy/middleware/request/preprocessors/add-azure-key.ts index 5c34d34..2a2e8f2 100644 --- a/src/proxy/middleware/request/add-azure-key.ts +++ b/src/proxy/middleware/request/preprocessors/add-azure-key.ts @@ -1,5 +1,5 @@ -import { AzureOpenAIKey, keyPool } from "../../../shared/key-management"; -import { RequestPreprocessor } from "."; +import { AzureOpenAIKey, keyPool } from "../../../../shared/key-management"; +import { RequestPreprocessor } from "../index"; export const addAzureKey: RequestPreprocessor = (req) => { const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai"; diff --git a/src/proxy/middleware/request/apply-quota-limits.ts b/src/proxy/middleware/request/preprocessors/apply-quota-limits.ts similarity index 81% rename from src/proxy/middleware/request/apply-quota-limits.ts rename to src/proxy/middleware/request/preprocessors/apply-quota-limits.ts index e7a637b..a0f163e 100644 --- a/src/proxy/middleware/request/apply-quota-limits.ts +++ b/src/proxy/middleware/request/preprocessors/apply-quota-limits.ts @@ -1,6 +1,6 @@ -import { hasAvailableQuota } from "../../../shared/users/user-store"; -import { isImageGenerationRequest, isTextGenerationRequest } from "../common"; -import { ProxyRequestMiddleware } from "."; +import { hasAvailableQuota } from "../../../../shared/users/user-store"; +import { isImageGenerationRequest, isTextGenerationRequest } from "../../common"; +import { HPMRequestCallback } from "../index"; export class QuotaExceededError extends Error { public quotaInfo: any; @@ -11,7 +11,7 @@ export class QuotaExceededError extends Error { } } -export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => { +export const applyQuotaLimits: HPMRequestCallback = (_proxyReq, req) => { const subjectToQuota = isTextGenerationRequest(req) || isImageGenerationRequest(req); if (!subjectToQuota || !req.user) return; diff --git a/src/proxy/middleware/request/count-prompt-tokens.ts b/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts similarity index 90% rename from src/proxy/middleware/request/count-prompt-tokens.ts rename to src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts index e9072f1..bb798ce 100644 --- a/src/proxy/middleware/request/count-prompt-tokens.ts +++ b/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts @@ -1,6 +1,6 @@ -import { RequestPreprocessor } from "./index"; -import { countTokens } from "../../../shared/tokenization"; -import { assertNever } from "../../../shared/utils"; +import { RequestPreprocessor } from "../index"; +import { countTokens } from "../../../../shared/tokenization"; +import { assertNever } from "../../../../shared/utils"; import type { OpenAIChatMessage } from "./transform-outbound-payload"; /** diff --git a/src/proxy/middleware/request/language-filter.ts b/src/proxy/middleware/request/preprocessors/language-filter.ts similarity index 90% rename from src/proxy/middleware/request/language-filter.ts rename to src/proxy/middleware/request/preprocessors/language-filter.ts index 64f67dc..f206fa9 100644 --- a/src/proxy/middleware/request/language-filter.ts +++ b/src/proxy/middleware/request/preprocessors/language-filter.ts @@ -1,8 +1,8 @@ import { Request } from "express"; -import { config } from "../../../config"; -import { assertNever } from "../../../shared/utils"; -import { RequestPreprocessor } from "."; -import { UserInputError } from "../../../shared/errors"; +import { config } from "../../../../config"; +import { assertNever } from "../../../../shared/utils"; +import { RequestPreprocessor } from "../index"; +import { UserInputError } from "../../../../shared/errors"; import { OpenAIChatMessage } from "./transform-outbound-payload"; const rejectedClients = new Map(); diff --git a/src/proxy/middleware/request/set-api-format.ts b/src/proxy/middleware/request/preprocessors/set-api-format.ts similarity index 73% rename from src/proxy/middleware/request/set-api-format.ts rename to src/proxy/middleware/request/preprocessors/set-api-format.ts index a0bd591..5b2d2c7 100644 --- a/src/proxy/middleware/request/set-api-format.ts +++ b/src/proxy/middleware/request/preprocessors/set-api-format.ts @@ -1,6 +1,6 @@ import { Request } from "express"; -import { APIFormat, LLMService } from "../../../shared/key-management"; -import { RequestPreprocessor } from "."; +import { APIFormat, LLMService } from "../../../../shared/key-management"; +import { RequestPreprocessor } from "../index"; export const setApiFormat = (api: { inApi: Request["inboundApi"]; diff --git a/src/proxy/middleware/request/sign-aws-request.ts b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts similarity index 96% rename from src/proxy/middleware/request/sign-aws-request.ts rename to src/proxy/middleware/request/preprocessors/sign-aws-request.ts index 17847dd..c09cb5b 100644 --- a/src/proxy/middleware/request/sign-aws-request.ts +++ b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts @@ -2,8 +2,8 @@ import express from "express"; import { Sha256 } from "@aws-crypto/sha256-js"; import { SignatureV4 } from "@smithy/signature-v4"; import { HttpRequest } from "@smithy/protocol-http"; -import { keyPool } from "../../../shared/key-management"; -import { RequestPreprocessor } from "."; +import { keyPool } from "../../../../shared/key-management"; +import { RequestPreprocessor } from "../index"; import { AnthropicV1CompleteSchema } from "./transform-outbound-payload"; const AMZ_HOST = diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts similarity index 98% rename from src/proxy/middleware/request/transform-outbound-payload.ts rename to src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts index 23d042d..bec58ac 100644 --- a/src/proxy/middleware/request/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts @@ -1,9 +1,9 @@ import { Request } from "express"; import { z } from "zod"; -import { config } from "../../../config"; -import { isTextGenerationRequest, isImageGenerationRequest } from "../common"; -import { RequestPreprocessor } from "."; -import { APIFormat } from "../../../shared/key-management"; +import { config } from "../../../../config"; +import { isTextGenerationRequest, isImageGenerationRequest } from "../../common"; +import { RequestPreprocessor } from "../index"; +import { APIFormat } from "../../../../shared/key-management"; const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic; const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI; diff --git a/src/proxy/middleware/request/validate-context-size.ts b/src/proxy/middleware/request/preprocessors/validate-context-size.ts similarity index 95% rename from src/proxy/middleware/request/validate-context-size.ts rename to src/proxy/middleware/request/preprocessors/validate-context-size.ts index 3d3cb4b..5489d4a 100644 --- a/src/proxy/middleware/request/validate-context-size.ts +++ b/src/proxy/middleware/request/preprocessors/validate-context-size.ts @@ -1,8 +1,8 @@ import { Request } from "express"; import { z } from "zod"; -import { config } from "../../../config"; -import { assertNever } from "../../../shared/utils"; -import { RequestPreprocessor } from "."; +import { config } from "../../../../config"; +import { assertNever } from "../../../../shared/utils"; +import { RequestPreprocessor } from "../index"; const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic; const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI; diff --git a/src/proxy/middleware/request/rewrite.ts b/src/proxy/middleware/request/rewrite.ts deleted file mode 100644 index 8cc078d..0000000 --- a/src/proxy/middleware/request/rewrite.ts +++ /dev/null @@ -1,42 +0,0 @@ -import { Request } from "express"; -import { ClientRequest } from "http"; -import httpProxy from "http-proxy"; -import { ProxyRequestMiddleware } from "./index"; - -type ProxyReqCallback = httpProxy.ProxyReqCallback; -type RewriterOptions = { - beforeRewrite?: ProxyReqCallback[]; - pipeline: ProxyRequestMiddleware[]; -}; - -export const createOnProxyReqHandler = ({ - beforeRewrite = [], - pipeline, -}: RewriterOptions): ProxyReqCallback => { - return (proxyReq, req, res, options) => { - // The streaming flag must be set before any other middleware runs, because - // it may influence which other middleware a particular API pipeline wants - // to run. - // Image generation requests can't be streamed. - req.isStreaming = req.body.stream === true || req.body.stream === "true"; - req.body.stream = req.isStreaming; - - try { - for (const validator of beforeRewrite) { - validator(proxyReq, req, res, options); - } - } catch (error) { - req.log.error(error, "Error while executing proxy request validator"); - proxyReq.destroy(error); - } - - try { - for (const rewriter of pipeline) { - rewriter(proxyReq, req, res, options); - } - } catch (error) { - req.log.error(error, "Error while executing proxy request rewriter"); - proxyReq.destroy(error); - } - }; -}; diff --git a/src/proxy/middleware/response/log-prompt.ts b/src/proxy/middleware/response/log-prompt.ts index 17b02e6..ba86f72 100644 --- a/src/proxy/middleware/response/log-prompt.ts +++ b/src/proxy/middleware/response/log-prompt.ts @@ -9,7 +9,7 @@ import { } from "../common"; import { ProxyResHandlerWithBody } from "."; import { assertNever } from "../../../shared/utils"; -import { OpenAIChatMessage } from "../request/transform-outbound-payload"; +import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload"; /** If prompt logging is enabled, enqueues the prompt for logging. */ export const logPrompt: ProxyResHandlerWithBody = async ( diff --git a/src/proxy/openai-image.ts b/src/proxy/openai-image.ts index 2c5a63c..fd87b2d 100644 --- a/src/proxy/openai-image.ts +++ b/src/proxy/openai-image.ts @@ -7,11 +7,8 @@ import { ipLimiter } from "./rate-limit"; import { handleProxyError } from "./middleware/common"; import { addKey, - applyQuotaLimits, - blockZoomerOrigins, createPreprocessorMiddleware, finalizeBody, - stripHeaders, createOnProxyReqHandler, } from "./middleware/request"; import { @@ -113,15 +110,7 @@ const openaiImagesProxy = createQueueMiddleware({ "^/v1/chat/completions": "/v1/images/generations", }, on: { - proxyReq: createOnProxyReqHandler({ - pipeline: [ - applyQuotaLimits, - addKey, - blockZoomerOrigins, - stripHeaders, - finalizeBody, - ], - }), + proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }), proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]), error: handleProxyError, }, diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 6617874..499422b 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -2,7 +2,11 @@ import { RequestHandler, Router } from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; import { keyPool } from "../shared/key-management"; -import { getOpenAIModelFamily, ModelFamily, OpenAIModelFamily } from "../shared/models"; +import { + getOpenAIModelFamily, + ModelFamily, + OpenAIModelFamily, +} from "../shared/models"; import { logger } from "../logger"; import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; @@ -10,18 +14,17 @@ import { handleProxyError } from "./middleware/common"; import { addKey, addKeyForEmbeddingsRequest, - applyQuotaLimits, - blockZoomerOrigins, createEmbeddingsPreprocessorMiddleware, createOnProxyReqHandler, createPreprocessorMiddleware, finalizeBody, forceModel, - limitCompletions, RequestPreprocessor, - stripHeaders, } from "./middleware/request"; -import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response"; +import { + createOnProxyResHandler, + ProxyResHandlerWithBody, +} from "./middleware/response"; // https://platform.openai.com/docs/models/overview export const KNOWN_OPENAI_MODELS = [ @@ -159,14 +162,7 @@ const openaiProxy = createQueueMiddleware({ logger, on: { proxyReq: createOnProxyReqHandler({ - pipeline: [ - applyQuotaLimits, - addKey, - limitCompletions, - blockZoomerOrigins, - stripHeaders, - finalizeBody, - ], + pipeline: [addKey, finalizeBody], }), proxyRes: createOnProxyResHandler([openaiResponseHandler]), error: handleProxyError, @@ -181,7 +177,7 @@ const openaiEmbeddingsProxy = createProxyMiddleware({ logger, on: { proxyReq: createOnProxyReqHandler({ - pipeline: [addKeyForEmbeddingsRequest, stripHeaders, finalizeBody], + pipeline: [addKeyForEmbeddingsRequest, finalizeBody], }), error: handleProxyError, }, diff --git a/src/proxy/palm.ts b/src/proxy/palm.ts index 0137fd3..979411a 100644 --- a/src/proxy/palm.ts +++ b/src/proxy/palm.ts @@ -9,13 +9,10 @@ import { ipLimiter } from "./rate-limit"; import { handleProxyError } from "./middleware/common"; import { addKey, - applyQuotaLimits, - blockZoomerOrigins, createOnProxyReqHandler, createPreprocessorMiddleware, finalizeBody, forceModel, - stripHeaders, } from "./middleware/request"; import { createOnProxyResHandler, @@ -149,14 +146,7 @@ const googlePalmProxy = createQueueMiddleware({ logger, on: { proxyReq: createOnProxyReqHandler({ - beforeRewrite: [reassignPathForPalmModel], - pipeline: [ - applyQuotaLimits, - addKey, - blockZoomerOrigins, - stripHeaders, - finalizeBody, - ], + pipeline: [reassignPathForPalmModel, addKey, finalizeBody], }), proxyRes: createOnProxyResHandler([palmResponseHandler]), error: handleProxyError, diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index abe128c..81d5d58 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -14,17 +14,8 @@ import crypto from "crypto"; import type { Handler, Request } from "express"; import { keyPool } from "../shared/key-management"; -import { - getAwsBedrockModelFamily, - getAzureOpenAIModelFamily, - getClaudeModelFamily, - getGooglePalmModelFamily, - getOpenAIModelFamily, - MODEL_FAMILIES, - ModelFamily, -} from "../shared/models"; +import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models"; import { buildFakeSse, initializeSseStream } from "../shared/streaming"; -import { assertNever } from "../shared/utils"; import { logger } from "../logger"; import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit"; import { RequestPreprocessor } from "./middleware/request"; @@ -132,34 +123,9 @@ export function enqueue(req: Request) { } } -function getPartitionForRequest(req: Request): ModelFamily { - // There is a single request queue, but it is partitioned by model family. - // Model families are typically separated on cost/rate limit boundaries so - // they should be treated as separate queues. - const model = req.body.model ?? "gpt-3.5-turbo"; - - // Weird special case for AWS/Azure because they serve multiple models from - // different vendors, even if currently only one is supported. - if (req.service === "aws") return getAwsBedrockModelFamily(model); - if (req.service === "azure") return getAzureOpenAIModelFamily(model); - - switch (req.outboundApi) { - case "anthropic": - return getClaudeModelFamily(model); - case "openai": - case "openai-text": - case "openai-image": - return getOpenAIModelFamily(model); - case "google-palm": - return getGooglePalmModelFamily(model); - default: - assertNever(req.outboundApi); - } -} - function getQueueForPartition(partition: ModelFamily): Request[] { return queue - .filter((req) => getPartitionForRequest(req) === partition) + .filter((req) => getModelFamilyForRequest(req) === partition) .sort((a, b) => { // Certain requests are exempted from IP-based rate limiting because they // come from a shared IP address. To prevent these requests from starving @@ -222,7 +188,7 @@ function processQueue() { reqs.filter(Boolean).forEach((req) => { if (req?.proceed) { - const modelFamily = getPartitionForRequest(req!); + const modelFamily = getModelFamilyForRequest(req!); req.log.info({ retries: req.retryCount, partition: modelFamily, @@ -279,7 +245,7 @@ let waitTimes: { /** Adds a successful request to the list of wait times. */ export function trackWaitTime(req: Request) { waitTimes.push({ - partition: getPartitionForRequest(req), + partition: getModelFamilyForRequest(req), start: req.startTime!, end: req.queueOutTime ?? Date.now(), isDeprioritized: isFromSharedIp(req), @@ -324,7 +290,7 @@ function calculateWaitTime(partition: ModelFamily) { const currentWaits = queue .filter((req) => { - const isSamePartition = getPartitionForRequest(req) === partition; + const isSamePartition = getModelFamilyForRequest(req) === partition; const isNormalPriority = !isFromSharedIp(req); return isSamePartition && isNormalPriority; }) diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 214ab1e..c61f767 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -170,12 +170,6 @@ export class OpenAIKeyProvider implements KeyProvider { throw new Error(`No keys available for model family '${neededFamily}'.`); } - if (!config.allowedModelFamilies.includes(neededFamily)) { - throw new Error( - `Proxy operator has disabled model family '${neededFamily}'.` - ); - } - // Select a key, from highest priority to lowest priority: // 1. Keys which are not rate limited // a. We ignore rate limits from >30 seconds ago diff --git a/src/shared/models.ts b/src/shared/models.ts index c91ab03..5444529 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -1,6 +1,8 @@ // Don't import anything here, this is imported by config.ts import pino from "pino"; +import type { Request } from "express"; +import { assertNever } from "./utils"; export type OpenAIModelFamily = | "turbo" @@ -103,3 +105,38 @@ export function assertIsKnownModelFamily( throw new Error(`Unknown model family: ${modelFamily}`); } } + +export function getModelFamilyForRequest(req: Request): ModelFamily { + if (req.modelFamily) return req.modelFamily; + // There is a single request queue, but it is partitioned by model family. + // Model families are typically separated on cost/rate limit boundaries so + // they should be treated as separate queues. + const model = req.body.model ?? "gpt-3.5-turbo"; + let modelFamily: ModelFamily; + + // Weird special case for AWS/Azure because they serve multiple models from + // different vendors, even if currently only one is supported. + if (req.service === "aws") { + modelFamily = getAwsBedrockModelFamily(model); + } else if (req.service === "azure") { + modelFamily = getAzureOpenAIModelFamily(model); + } else { + switch (req.outboundApi) { + case "anthropic": + modelFamily = getClaudeModelFamily(model); + break; + case "openai": + case "openai-text": + case "openai-image": + modelFamily = getOpenAIModelFamily(model); + break; + case "google-palm": + modelFamily = getGooglePalmModelFamily(model); + break; + default: + assertNever(req.outboundApi); + } + } + + return (req.modelFamily = modelFamily); +} diff --git a/src/shared/tokenization/openai.ts b/src/shared/tokenization/openai.ts index eeb5575..a162683 100644 --- a/src/shared/tokenization/openai.ts +++ b/src/shared/tokenization/openai.ts @@ -2,7 +2,7 @@ import { Tiktoken } from "tiktoken/lite"; import cl100k_base from "tiktoken/encoders/cl100k_base.json"; import { logger } from "../../logger"; import { libSharp } from "../file-storage"; -import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload"; +import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload"; const log = logger.child({ module: "tokenizer", service: "openai" }); const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170; diff --git a/src/shared/tokenization/tokenizer.ts b/src/shared/tokenization/tokenizer.ts index 1c09930..075a2d1 100644 --- a/src/shared/tokenization/tokenizer.ts +++ b/src/shared/tokenization/tokenizer.ts @@ -1,5 +1,5 @@ import { Request } from "express"; -import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload"; +import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload"; import { assertNever } from "../utils"; import { init as initClaude, diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts index 546b55e..6876ad7 100644 --- a/src/types/custom.d.ts +++ b/src/types/custom.d.ts @@ -2,6 +2,7 @@ import type { HttpRequest } from "@smithy/types"; import { Express } from "express-serve-static-core"; import { APIFormat, Key, LLMService } from "../shared/key-management"; import { User } from "../shared/users/schema"; +import { ModelFamily } from "../shared/models"; declare global { namespace Express { @@ -27,6 +28,7 @@ declare global { outputTokens?: number; tokenizerInfo: Record; signedRequest: HttpRequest; + modelFamily?: ModelFamily; } } }