properly enforce allowedModelFamilies; refactor HPM proxyReq handlers

nai-degen 2023-12-05 21:41:04 -06:00
parent 12276a1f59
commit 94d4efe9bb
34 changed files with 204 additions and 262 deletions

View File

@ -7,12 +7,9 @@ import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
applyQuotaLimits,
addAnthropicPreamble,
blockZoomerOrigins,
createPreprocessorMiddleware,
finalizeBody,
stripHeaders,
createOnProxyReqHandler,
} from "./middleware/request";
import {
@ -137,14 +134,7 @@ const anthropicProxy = createQueueMiddleware({
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
addKey,
addAnthropicPreamble,
blockZoomerOrigins,
stripHeaders,
finalizeBody,
],
pipeline: [addKey, addAnthropicPreamble, finalizeBody],
}),
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
error: handleProxyError,

View File

@ -7,19 +7,18 @@ import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
applyQuotaLimits,
createPreprocessorMiddleware,
stripHeaders,
signAwsRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
blockZoomerOrigins,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
const LATEST_AWS_V2_MINOR_VERSION = "1";
let modelsCache: any = null;
let modelsCacheTime = 0;
@ -133,14 +132,7 @@ const awsProxy = createQueueMiddleware({
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
blockZoomerOrigins,
stripHeaders,
finalizeSignedRequest,
],
}),
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
@ -178,20 +170,15 @@ awsRouter.post(
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*/
const LATEST_AWS_V2_MINOR_VERSION = '1';
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If the string already includes "anthropic.claude", return it unmodified
// If client already specified an AWS Claude model ID, use it
if (model.includes("anthropic.claude")) {
return;
}
// Define a regular expression pattern to match the Claude version strings
const pattern = /^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?$/i;
// Execute the pattern on the model string
const match = model.match(pattern);
// If there's no match, return the latest v2 model
@ -200,34 +187,30 @@ function maybeReassignModel(req: Request) {
return;
}
// Extract parts of the version string
const [, , instant, v, major, , minor] = match;
const [, , instant, , major, , minor] = match;
// If 'instant' is part of the version, return the fixed instant model string
if (instant) {
req.body.model = 'anthropic.claude-instant-v1';
req.body.model = "anthropic.claude-instant-v1";
return;
}
// If the major version is '1', return the fixed v1 model string
if (major === '1') {
req.body.model = 'anthropic.claude-v1';
// There's only one v1 model
if (major === "1") {
req.body.model = "anthropic.claude-v1";
return;
}
// If the major version is '2'
if (major === '2') {
// If the minor version is explicitly '0', return "anthropic.claude-v2" which is claude-2.0
if (minor === '0') {
req.body.model = 'anthropic.claude-v2';
// Try to map Anthropic API v2 models to AWS v2 models
if (major === "2") {
if (minor === "0") {
req.body.model = "anthropic.claude-v2";
return;
}
// Otherwise, return the v2 model string with the latest minor version
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}
// If none of the above conditions are met, return the latest v2 model by default
// Fallback to latest v2 model
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}
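
For reference, a minimal sketch of the mappings the function above should produce, exercised through a hypothetical helper in the same module (not part of this commit; it assumes the regex and LATEST_AWS_V2_MINOR_VERSION shown above):

// Hypothetical helper, for illustration only: maybeReassignModel mutates
// req.body.model in place, so wrap it to read the result back.
function remapped(model: string): string {
  const req = { body: { model } } as unknown as Request;
  maybeReassignModel(req);
  return req.body.model;
}

// Expected results, assuming LATEST_AWS_V2_MINOR_VERSION === "1":
// remapped("claude-instant-v1")     === "anthropic.claude-instant-v1"
// remapped("claude-v1")             === "anthropic.claude-v1"
// remapped("claude-2.0")            === "anthropic.claude-v2"
// remapped("claude-2.1")            === "anthropic.claude-v2:1"
// remapped("anthropic.claude-v2:1") === "anthropic.claude-v2:1"  (passed through)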

View File

@ -13,19 +13,15 @@ import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
applyQuotaLimits,
blockZoomerOrigins,
addAzureKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
limitCompletions,
stripHeaders,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { addAzureKey } from "./middleware/request/add-azure-key";
let modelsCache: any = null;
let modelsCacheTime = 0;
@ -109,15 +105,7 @@ const azureOpenAIProxy = createQueueMiddleware({
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
limitCompletions,
blockZoomerOrigins,
stripHeaders,
finalizeSignedRequest,
],
}),
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
error: handleProxyError,
},

View File

@ -4,7 +4,7 @@ import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error";
import { buildFakeSse } from "../../shared/streaming";
import { assertNever } from "../../shared/utils";
import { QuotaExceededError } from "./request/apply-quota-limits";
import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";

View File

@ -2,29 +2,30 @@ import type { Request } from "express";
import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";
export { createOnProxyReqHandler } from "./rewrite";
export { createOnProxyReqHandler } from "./onproxyreq-factory";
export {
createPreprocessorMiddleware,
createEmbeddingsPreprocessorMiddleware,
} from "./preprocess";
} from "./preprocessor-factory";
// Express middleware (runs before http-proxy-middleware, can be async)
export { applyQuotaLimits } from "./apply-quota-limits";
export { validateContextSize } from "./validate-context-size";
export { countPromptTokens } from "./count-prompt-tokens";
export { languageFilter } from "./language-filter";
export { setApiFormat } from "./set-api-format";
export { signAwsRequest } from "./sign-aws-request";
export { transformOutboundPayload } from "./transform-outbound-payload";
export { addAzureKey } from "./preprocessors/add-azure-key";
export { applyQuotaLimits } from "./preprocessors/apply-quota-limits";
export { validateContextSize } from "./preprocessors/validate-context-size";
export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
export { languageFilter } from "./preprocessors/language-filter";
export { setApiFormat } from "./preprocessors/set-api-format";
export { signAwsRequest } from "./preprocessors/sign-aws-request";
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
// HPM middleware (runs on onProxyReq, cannot be async)
export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
export { addAnthropicPreamble } from "./add-anthropic-preamble";
export { blockZoomerOrigins } from "./block-zoomer-origins";
export { finalizeBody } from "./finalize-body";
export { finalizeSignedRequest } from "./finalize-signed-request";
export { limitCompletions } from "./limit-completions";
export { stripHeaders } from "./strip-headers";
// http-proxy-middleware callbacks (runs on onProxyReq, cannot be async)
export { addKey, addKeyForEmbeddingsRequest } from "./onproxyreq/add-key";
export { addAnthropicPreamble } from "./onproxyreq/add-anthropic-preamble";
export { blockZoomerOrigins } from "./onproxyreq/block-zoomer-origins";
export { checkModelFamily } from "./onproxyreq/check-model-family";
export { finalizeBody } from "./onproxyreq/finalize-body";
export { finalizeSignedRequest } from "./onproxyreq/finalize-signed-request";
export { stripHeaders } from "./onproxyreq/strip-headers";
/**
* Middleware that runs prior to the request being handled by http-proxy-
@ -43,7 +44,7 @@ export { stripHeaders } from "./strip-headers";
export type RequestPreprocessor = (req: Request) => void | Promise<void>;
/**
* Middleware that runs immediately before the request is sent to the API in
* Callbacks that run immediately before the request is sent to the API in
* response to http-proxy-middleware's `proxyReq` event.
*
* Async functions cannot be used here as HPM's event emitter is not async and
@ -53,7 +54,7 @@ export type RequestPreprocessor = (req: Request) => void | Promise<void>;
* first attempt is rate limited and the request is automatically retried by the
* request queue middleware.
*/
export type ProxyRequestMiddleware = ProxyReqCallback<ClientRequest, Request>;
export type HPMRequestCallback = ProxyReqCallback<ClientRequest, Request>;
export const forceModel = (model: string) => (req: Request) =>
void (req.body.model = model);
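
To make the split concrete, a minimal sketch of one middleware of each kind using the two exported types (the names, import path, and header below are hypothetical, for illustration only):

import type { RequestPreprocessor, HPMRequestCallback } from "./middleware/request";

// Runs once per request, before it enters the queue; may be async.
const logModel: RequestPreprocessor = async (req) => {
  req.log.debug({ model: req.body.model }, "preprocessing request");
};

// Runs on every proxy attempt, including queue retries; must stay synchronous.
const tagRequest: HPMRequestCallback = (proxyReq, _req) => {
  proxyReq.setHeader("x-example-proxy", "1"); // hypothetical header
};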

View File

@ -1,16 +0,0 @@
import { isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/**
* Don't allow multiple text completions to be requested to prevent abuse.
* OpenAI-only, Anthropic provides no such parameter.
**/
export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => {
if (isTextGenerationRequest(req) && req.outboundApi === "openai") {
const originalN = req.body?.n || 1;
req.body.n = 1;
if (originalN !== req.body.n) {
req.log.warn(`Limiting completion choices from ${originalN} to 1`);
}
}
};

View File

@ -0,0 +1,43 @@
import {
applyQuotaLimits,
blockZoomerOrigins,
checkModelFamily,
HPMRequestCallback,
stripHeaders,
} from "./index";
type ProxyReqHandlerFactoryOptions = { pipeline: HPMRequestCallback[] };
/**
* Returns an http-proxy-middleware request handler that runs the given set of
* onProxyReq callback functions in sequence.
*
* These will run each time a request is proxied, including on automatic retries
* by the queue after encountering a rate limit.
*/
export const createOnProxyReqHandler = ({
pipeline,
}: ProxyReqHandlerFactoryOptions): HPMRequestCallback => {
const callbackPipeline = [
checkModelFamily,
applyQuotaLimits,
blockZoomerOrigins,
stripHeaders,
...pipeline,
];
return (proxyReq, req, res, options) => {
// The streaming flag must be set before any other onProxyReq handler runs,
// as it may influence the behavior of subsequent handlers.
// Image generation requests can't be streamed.
req.isStreaming = req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;
try {
for (const fn of callbackPipeline) {
fn(proxyReq, req, res, options);
}
} catch (error) {
proxyReq.destroy(error);
}
};
};
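
For context, the route modules updated in this commit consume the factory roughly like this (a condensed sketch patterned on the openai.ts hunk, with a placeholder target; not a verbatim excerpt of any one file):

import { createProxyMiddleware } from "http-proxy-middleware";
import { handleProxyError } from "./middleware/common";
import { addKey, createOnProxyReqHandler, finalizeBody } from "./middleware/request";

const exampleProxy = createProxyMiddleware({
  target: "https://api.openai.com", // placeholder; each route sets its own target
  changeOrigin: true,
  selfHandleResponse: true,
  on: {
    // Only service-specific callbacks are listed here; checkModelFamily,
    // applyQuotaLimits, blockZoomerOrigins, and stripHeaders are prepended
    // automatically by createOnProxyReqHandler.
    proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }),
    error: handleProxyError,
  },
});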

View File

@ -1,13 +1,13 @@
import { AnthropicKey, Key } from "../../../shared/key-management";
import { isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
import { AnthropicKey, Key } from "../../../../shared/key-management";
import { isTextGenerationRequest } from "../../common";
import { HPMRequestCallback } from "../index";
/**
* Some keys require the prompt to start with `\n\nHuman:`. There is no way to
* know this without trying to send the request and seeing if it fails. If a
* key is marked as requiring a preamble, it will be added here.
*/
export const addAnthropicPreamble: ProxyRequestMiddleware = (
export const addAnthropicPreamble: HPMRequestCallback = (
_proxyReq,
req
) => {

View File

@ -1,10 +1,10 @@
import { Key, OpenAIKey, keyPool } from "../../../shared/key-management";
import { isEmbeddingsRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
import { assertNever } from "../../../shared/utils";
import { Key, OpenAIKey, keyPool } from "../../../../shared/key-management";
import { isEmbeddingsRequest } from "../../common";
import { HPMRequestCallback } from "../index";
import { assertNever } from "../../../../shared/utils";
/** Add a key that can service this request to the request object. */
export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
export const addKey: HPMRequestCallback = (proxyReq, req) => {
let assignedKey: Key;
if (!req.inboundApi || !req.outboundApi) {
@ -97,7 +97,7 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
* Special case for embeddings requests which don't go through the normal
* request pipeline.
*/
export const addKeyForEmbeddingsRequest: ProxyRequestMiddleware = (
export const addKeyForEmbeddingsRequest: HPMRequestCallback = (
proxyReq,
req
) => {

View File

@ -1,4 +1,4 @@
import { ProxyRequestMiddleware } from ".";
import { HPMRequestCallback } from "../index";
const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");
@ -13,7 +13,7 @@ class ForbiddenError extends Error {
* Blocks requests from Janitor AI users with a fake, scary error message so I
* stop getting emails asking for tech support.
*/
export const blockZoomerOrigins: ProxyRequestMiddleware = (_proxyReq, req) => {
export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => {
const origin = req.headers.origin || req.headers.referer;
if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) {
// Venus-derivatives send a test prompt to check if the proxy is working.

View File

@ -0,0 +1,13 @@
import { HPMRequestCallback } from "../index";
import { config } from "../../../../config";
import { getModelFamilyForRequest } from "../../../../shared/models";
/**
* Ensures the selected model family is enabled by the proxy configuration.
**/
export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => {
const family = getModelFamilyForRequest(req);
if (!config.allowedModelFamilies.includes(family)) {
throw new Error(`Model family ${family} is not permitted on this proxy`);
}
};

View File

@ -1,8 +1,8 @@
import { fixRequestBody } from "http-proxy-middleware";
import type { ProxyRequestMiddleware } from ".";
import type { HPMRequestCallback } from "../index";
/** Finalize the rewritten request body. Must be the last rewriter. */
export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => {
export const finalizeBody: HPMRequestCallback = (proxyReq, req) => {
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
// For image generation requests, remove stream flag.
if (req.outboundApi === "openai-image") {

View File

@ -1,11 +1,11 @@
import type { ProxyRequestMiddleware } from ".";
import type { HPMRequestCallback } from "../index";
/**
* For AWS/Azure requests, the body is signed earlier in the request pipeline,
* before the proxy middleware. This function just assigns the path and headers
* to the proxy request.
*/
export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => {
export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
if (!req.signedRequest) {
throw new Error("Expected req.signedRequest to be set");
}

View File

@ -1,10 +1,10 @@
import { ProxyRequestMiddleware } from ".";
import { HPMRequestCallback } from "../index";
/**
* Removes origin and referer headers before sending the request to the API for
* privacy reasons.
**/
export const stripHeaders: ProxyRequestMiddleware = (proxyReq) => {
export const stripHeaders: HPMRequestCallback = (proxyReq) => {
proxyReq.setHeader("origin", "");
proxyReq.setHeader("referer", "");

View File

@ -29,6 +29,14 @@ type RequestPreprocessorOptions = {
/**
* Returns a middleware function that processes the request body into the given
* API format, and then sequentially runs the given additional preprocessors.
*
* These run first in the request lifecycle, a single time per request before it
* is added to the request queue. They aren't run again if the request is
* re-attempted after a rate limit.
*
* To run a preprocessor on every re-attempt, pass it to createQueueMiddleware.
* It will run after these preprocessors, but before the request is sent to
* http-proxy-middleware.
*/
export const createPreprocessorMiddleware = (
apiFormat: Parameters<typeof setApiFormat>[0],
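
A hedged wiring sketch of where this preprocessor middleware sits in a route (the option shape is assumed from the setApiFormat signature elsewhere in this diff; the router and proxy names are placeholders):

// Hypothetical route wiring, for illustration only:
exampleRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware({ inApi: "openai", outApi: "openai", service: "openai" }),
  exampleProxy // the queue-wrapped proxy for this service
);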

View File

@ -1,5 +1,5 @@
import { AzureOpenAIKey, keyPool } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";
import { AzureOpenAIKey, keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
export const addAzureKey: RequestPreprocessor = (req) => {
const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";

View File

@ -1,6 +1,6 @@
import { hasAvailableQuota } from "../../../shared/users/user-store";
import { isImageGenerationRequest, isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
import { hasAvailableQuota } from "../../../../shared/users/user-store";
import { isImageGenerationRequest, isTextGenerationRequest } from "../../common";
import { HPMRequestCallback } from "../index";
export class QuotaExceededError extends Error {
public quotaInfo: any;
@ -11,7 +11,7 @@ export class QuotaExceededError extends Error {
}
}
export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
export const applyQuotaLimits: HPMRequestCallback = (_proxyReq, req) => {
const subjectToQuota =
isTextGenerationRequest(req) || isImageGenerationRequest(req);
if (!subjectToQuota || !req.user) return;

View File

@ -1,6 +1,6 @@
import { RequestPreprocessor } from "./index";
import { countTokens } from "../../../shared/tokenization";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils";
import type { OpenAIChatMessage } from "./transform-outbound-payload";
/**

View File

@ -1,8 +1,8 @@
import { Request } from "express";
import { config } from "../../../config";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from ".";
import { UserInputError } from "../../../shared/errors";
import { config } from "../../../../config";
import { assertNever } from "../../../../shared/utils";
import { RequestPreprocessor } from "../index";
import { UserInputError } from "../../../../shared/errors";
import { OpenAIChatMessage } from "./transform-outbound-payload";
const rejectedClients = new Map<string, number>();

View File

@ -1,6 +1,6 @@
import { Request } from "express";
import { APIFormat, LLMService } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";
import { APIFormat, LLMService } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
export const setApiFormat = (api: {
inApi: Request["inboundApi"];

View File

@ -2,8 +2,8 @@ import express from "express";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import { keyPool } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { AnthropicV1CompleteSchema } from "./transform-outbound-payload";
const AMZ_HOST =

View File

@ -1,9 +1,9 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
import { RequestPreprocessor } from ".";
import { APIFormat } from "../../../shared/key-management";
import { config } from "../../../../config";
import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
import { RequestPreprocessor } from "../index";
import { APIFormat } from "../../../../shared/key-management";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;

View File

@ -1,8 +1,8 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from ".";
import { config } from "../../../../config";
import { assertNever } from "../../../../shared/utils";
import { RequestPreprocessor } from "../index";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;

View File

@ -1,42 +0,0 @@
import { Request } from "express";
import { ClientRequest } from "http";
import httpProxy from "http-proxy";
import { ProxyRequestMiddleware } from "./index";
type ProxyReqCallback = httpProxy.ProxyReqCallback<ClientRequest, Request>;
type RewriterOptions = {
beforeRewrite?: ProxyReqCallback[];
pipeline: ProxyRequestMiddleware[];
};
export const createOnProxyReqHandler = ({
beforeRewrite = [],
pipeline,
}: RewriterOptions): ProxyReqCallback => {
return (proxyReq, req, res, options) => {
// The streaming flag must be set before any other middleware runs, because
// it may influence which other middleware a particular API pipeline wants
// to run.
// Image generation requests can't be streamed.
req.isStreaming = req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;
try {
for (const validator of beforeRewrite) {
validator(proxyReq, req, res, options);
}
} catch (error) {
req.log.error(error, "Error while executing proxy request validator");
proxyReq.destroy(error);
}
try {
for (const rewriter of pipeline) {
rewriter(proxyReq, req, res, options);
}
} catch (error) {
req.log.error(error, "Error while executing proxy request rewriter");
proxyReq.destroy(error);
}
};
};

View File

@ -9,7 +9,7 @@ import {
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import { OpenAIChatMessage } from "../request/transform-outbound-payload";
import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (

View File

@ -7,11 +7,8 @@ import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
applyQuotaLimits,
blockZoomerOrigins,
createPreprocessorMiddleware,
finalizeBody,
stripHeaders,
createOnProxyReqHandler,
} from "./middleware/request";
import {
@ -113,15 +110,7 @@ const openaiImagesProxy = createQueueMiddleware({
"^/v1/chat/completions": "/v1/images/generations",
},
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
addKey,
blockZoomerOrigins,
stripHeaders,
finalizeBody,
],
}),
proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }),
proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]),
error: handleProxyError,
},

View File

@ -2,7 +2,11 @@ import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import { getOpenAIModelFamily, ModelFamily, OpenAIModelFamily } from "../shared/models";
import {
getOpenAIModelFamily,
ModelFamily,
OpenAIModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
@ -10,18 +14,17 @@ import { handleProxyError } from "./middleware/common";
import {
addKey,
addKeyForEmbeddingsRequest,
applyQuotaLimits,
blockZoomerOrigins,
createEmbeddingsPreprocessorMiddleware,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
limitCompletions,
RequestPreprocessor,
stripHeaders,
} from "./middleware/request";
import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
// https://platform.openai.com/docs/models/overview
export const KNOWN_OPENAI_MODELS = [
@ -159,14 +162,7 @@ const openaiProxy = createQueueMiddleware({
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
addKey,
limitCompletions,
blockZoomerOrigins,
stripHeaders,
finalizeBody,
],
pipeline: [addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([openaiResponseHandler]),
error: handleProxyError,
@ -181,7 +177,7 @@ const openaiEmbeddingsProxy = createProxyMiddleware({
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKeyForEmbeddingsRequest, stripHeaders, finalizeBody],
pipeline: [addKeyForEmbeddingsRequest, finalizeBody],
}),
error: handleProxyError,
},

View File

@ -9,13 +9,10 @@ import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
applyQuotaLimits,
blockZoomerOrigins,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
stripHeaders,
} from "./middleware/request";
import {
createOnProxyResHandler,
@ -149,14 +146,7 @@ const googlePalmProxy = createQueueMiddleware({
logger,
on: {
proxyReq: createOnProxyReqHandler({
beforeRewrite: [reassignPathForPalmModel],
pipeline: [
applyQuotaLimits,
addKey,
blockZoomerOrigins,
stripHeaders,
finalizeBody,
],
pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([palmResponseHandler]),
error: handleProxyError,

View File

@ -14,17 +14,8 @@
import crypto from "crypto";
import type { Handler, Request } from "express";
import { keyPool } from "../shared/key-management";
import {
getAwsBedrockModelFamily,
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGooglePalmModelFamily,
getOpenAIModelFamily,
MODEL_FAMILIES,
ModelFamily,
} from "../shared/models";
import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import { assertNever } from "../shared/utils";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
@ -132,34 +123,9 @@ export function enqueue(req: Request) {
}
}
function getPartitionForRequest(req: Request): ModelFamily {
// There is a single request queue, but it is partitioned by model family.
// Model families are typically separated on cost/rate limit boundaries so
// they should be treated as separate queues.
const model = req.body.model ?? "gpt-3.5-turbo";
// Weird special case for AWS/Azure because they serve multiple models from
// different vendors, even if currently only one is supported.
if (req.service === "aws") return getAwsBedrockModelFamily(model);
if (req.service === "azure") return getAzureOpenAIModelFamily(model);
switch (req.outboundApi) {
case "anthropic":
return getClaudeModelFamily(model);
case "openai":
case "openai-text":
case "openai-image":
return getOpenAIModelFamily(model);
case "google-palm":
return getGooglePalmModelFamily(model);
default:
assertNever(req.outboundApi);
}
}
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue
.filter((req) => getPartitionForRequest(req) === partition)
.filter((req) => getModelFamilyForRequest(req) === partition)
.sort((a, b) => {
// Certain requests are exempted from IP-based rate limiting because they
// come from a shared IP address. To prevent these requests from starving
@ -222,7 +188,7 @@ function processQueue() {
reqs.filter(Boolean).forEach((req) => {
if (req?.proceed) {
const modelFamily = getPartitionForRequest(req!);
const modelFamily = getModelFamilyForRequest(req!);
req.log.info({
retries: req.retryCount,
partition: modelFamily,
@ -279,7 +245,7 @@ let waitTimes: {
/** Adds a successful request to the list of wait times. */
export function trackWaitTime(req: Request) {
waitTimes.push({
partition: getPartitionForRequest(req),
partition: getModelFamilyForRequest(req),
start: req.startTime!,
end: req.queueOutTime ?? Date.now(),
isDeprioritized: isFromSharedIp(req),
@ -324,7 +290,7 @@ function calculateWaitTime(partition: ModelFamily) {
const currentWaits = queue
.filter((req) => {
const isSamePartition = getPartitionForRequest(req) === partition;
const isSamePartition = getModelFamilyForRequest(req) === partition;
const isNormalPriority = !isFromSharedIp(req);
return isSamePartition && isNormalPriority;
})

View File

@ -170,12 +170,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
throw new Error(`No keys available for model family '${neededFamily}'.`);
}
if (!config.allowedModelFamilies.includes(neededFamily)) {
throw new Error(
`Proxy operator has disabled model family '${neededFamily}'.`
);
}
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. We ignore rate limits from >30 seconds ago

View File

@ -1,6 +1,8 @@
// Don't import anything here, this is imported by config.ts
import pino from "pino";
import type { Request } from "express";
import { assertNever } from "./utils";
export type OpenAIModelFamily =
| "turbo"
@ -103,3 +105,38 @@ export function assertIsKnownModelFamily(
throw new Error(`Unknown model family: ${modelFamily}`);
}
}
export function getModelFamilyForRequest(req: Request): ModelFamily {
if (req.modelFamily) return req.modelFamily;
// There is a single request queue, but it is partitioned by model family.
// Model families are typically separated on cost/rate limit boundaries so
// they should be treated as separate queues.
const model = req.body.model ?? "gpt-3.5-turbo";
let modelFamily: ModelFamily;
// Weird special case for AWS/Azure because they serve multiple models from
// different vendors, even if currently only one is supported.
if (req.service === "aws") {
modelFamily = getAwsBedrockModelFamily(model);
} else if (req.service === "azure") {
modelFamily = getAzureOpenAIModelFamily(model);
} else {
switch (req.outboundApi) {
case "anthropic":
modelFamily = getClaudeModelFamily(model);
break;
case "openai":
case "openai-text":
case "openai-image":
modelFamily = getOpenAIModelFamily(model);
break;
case "google-palm":
modelFamily = getGooglePalmModelFamily(model);
break;
default:
assertNever(req.outboundApi);
}
}
return (req.modelFamily = modelFamily);
}
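
The computed family is memoized on the request, so queue partitioning, wait-time tracking, and the new checkModelFamily callback all resolve it once. A small illustrative sketch:

// Illustrative only; the concrete family depends on req.service and req.outboundApi.
const first = getModelFamilyForRequest(req);  // computes and caches req.modelFamily
const second = getModelFamilyForRequest(req); // returns the cached value
console.assert(first === second && req.modelFamily === first);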

View File

@ -2,7 +2,7 @@ import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger";
import { libSharp } from "../file-storage";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;

View File

@ -1,5 +1,5 @@
import { Request } from "express";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { assertNever } from "../utils";
import {
init as initClaude,

View File

@ -2,6 +2,7 @@ import type { HttpRequest } from "@smithy/types";
import { Express } from "express-serve-static-core";
import { APIFormat, Key, LLMService } from "../shared/key-management";
import { User } from "../shared/users/schema";
import { ModelFamily } from "../shared/models";
declare global {
namespace Express {
@ -27,6 +28,7 @@ declare global {
outputTokens?: number;
tokenizerInfo: Record<string, any>;
signedRequest: HttpRequest;
modelFamily?: ModelFamily;
}
}
}