properly enforce allowedModelFamilies; refactor HPM proxyReq handlers
parent 12276a1f59
commit 94d4efe9bb
@@ -7,12 +7,9 @@ import { ipLimiter } from "./rate-limit";
 import { handleProxyError } from "./middleware/common";
 import {
   addKey,
-  applyQuotaLimits,
   addAnthropicPreamble,
-  blockZoomerOrigins,
   createPreprocessorMiddleware,
   finalizeBody,
-  stripHeaders,
   createOnProxyReqHandler,
 } from "./middleware/request";
 import {
@@ -137,14 +134,7 @@ const anthropicProxy = createQueueMiddleware({
     logger,
     on: {
       proxyReq: createOnProxyReqHandler({
-        pipeline: [
-          applyQuotaLimits,
-          addKey,
-          addAnthropicPreamble,
-          blockZoomerOrigins,
-          stripHeaders,
-          finalizeBody,
-        ],
+        pipeline: [addKey, addAnthropicPreamble, finalizeBody],
       }),
       proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
       error: handleProxyError,

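The same reshaping repeats for every provider route below: the cross-cutting callbacks (applyQuotaLimits, blockZoomerOrigins, stripHeaders, plus the new checkModelFamily) move out of each route's pipeline and into createOnProxyReqHandler itself, so a route now lists only its provider-specific steps. A minimal sketch of the composition, with simplified stand-in types rather than the real http-proxy signature:

type Ctx = { body: Record<string, unknown> };
type Callback = (ctx: Ctx) => void;

const checkModelFamily: Callback = () => {}; // would throw if the family is disabled
const applyQuotaLimits: Callback = () => {}; // would throw QuotaExceededError

const createOnProxyReqHandler = (pipeline: Callback[]) => {
  const all = [checkModelFamily, applyQuotaLimits, ...pipeline];
  return (ctx: Ctx) => all.forEach((fn) => fn(ctx));
};

// A route declares only its provider-specific steps:
const anthropicHandler = createOnProxyReqHandler([
  (ctx) => void (ctx.body.model ??= "claude-2"), // stands in for addKey/addAnthropicPreamble/finalizeBody
]);
anthropicHandler({ body: {} });
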
@@ -7,19 +7,18 @@ import { createQueueMiddleware } from "./queue";
 import { ipLimiter } from "./rate-limit";
 import { handleProxyError } from "./middleware/common";
 import {
-  applyQuotaLimits,
   createPreprocessorMiddleware,
-  stripHeaders,
   signAwsRequest,
   finalizeSignedRequest,
   createOnProxyReqHandler,
-  blockZoomerOrigins,
 } from "./middleware/request";
 import {
   ProxyResHandlerWithBody,
   createOnProxyResHandler,
 } from "./middleware/response";
 
+const LATEST_AWS_V2_MINOR_VERSION = "1";
+
 let modelsCache: any = null;
 let modelsCacheTime = 0;
 
@@ -133,14 +132,7 @@ const awsProxy = createQueueMiddleware({
     selfHandleResponse: true,
     logger,
     on: {
-      proxyReq: createOnProxyReqHandler({
-        pipeline: [
-          applyQuotaLimits,
-          blockZoomerOrigins,
-          stripHeaders,
-          finalizeSignedRequest,
-        ],
-      }),
+      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
       proxyRes: createOnProxyResHandler([awsResponseHandler]),
       error: handleProxyError,
     },
@@ -178,20 +170,15 @@ awsRouter.post(
  * - frontends sending OpenAI model names because they expect the proxy to
  *   translate them
  */
-const LATEST_AWS_V2_MINOR_VERSION = '1';
-
 function maybeReassignModel(req: Request) {
   const model = req.body.model;
 
-  // If the string already includes "anthropic.claude", return it unmodified
+  // If client already specified an AWS Claude model ID, use it
   if (model.includes("anthropic.claude")) {
     return;
   }
 
-  // Define a regular expression pattern to match the Claude version strings
   const pattern = /^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?$/i;
-
-  // Execute the pattern on the model string
   const match = model.match(pattern);
 
   // If there's no match, return the latest v2 model
@@ -200,34 +187,30 @@ function maybeReassignModel(req: Request) {
     return;
   }
 
-  // Extract parts of the version string
-  const [, , instant, v, major, , minor] = match;
+  const [, , instant, , major, , minor] = match;
 
-  // If 'instant' is part of the version, return the fixed instant model string
   if (instant) {
-    req.body.model = 'anthropic.claude-instant-v1';
+    req.body.model = "anthropic.claude-instant-v1";
     return;
   }
 
-  // If the major version is '1', return the fixed v1 model string
-  if (major === '1') {
-    req.body.model = 'anthropic.claude-v1';
+  // There's only one v1 model
+  if (major === "1") {
+    req.body.model = "anthropic.claude-v1";
     return;
   }
 
-  // If the major version is '2'
-  if (major === '2') {
-    // If the minor version is explicitly '0', return "anthropic.claude-v2" which is claude-2.0
-    if (minor === '0') {
-      req.body.model = 'anthropic.claude-v2';
+  // Try to map Anthropic API v2 models to AWS v2 models
+  if (major === "2") {
+    if (minor === "0") {
+      req.body.model = "anthropic.claude-v2";
       return;
     }
-    // Otherwise, return the v2 model string with the latest minor version
     req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
     return;
   }
 
-  // If none of the above conditions are met, return the latest v2 model by default
+  // Fallback to latest v2 model
   req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
   return;
 }

@@ -13,19 +13,15 @@ import { createQueueMiddleware } from "./queue";
 import { ipLimiter } from "./rate-limit";
 import { handleProxyError } from "./middleware/common";
 import {
-  applyQuotaLimits,
-  blockZoomerOrigins,
+  addAzureKey,
   createOnProxyReqHandler,
   createPreprocessorMiddleware,
   finalizeSignedRequest,
-  limitCompletions,
-  stripHeaders,
 } from "./middleware/request";
 import {
   createOnProxyResHandler,
   ProxyResHandlerWithBody,
 } from "./middleware/response";
-import { addAzureKey } from "./middleware/request/add-azure-key";
 
 let modelsCache: any = null;
 let modelsCacheTime = 0;
@@ -109,15 +105,7 @@ const azureOpenAIProxy = createQueueMiddleware({
     selfHandleResponse: true,
     logger,
     on: {
-      proxyReq: createOnProxyReqHandler({
-        pipeline: [
-          applyQuotaLimits,
-          limitCompletions,
-          blockZoomerOrigins,
-          stripHeaders,
-          finalizeSignedRequest,
-        ],
-      }),
+      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
       proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
       error: handleProxyError,
     },

@@ -4,7 +4,7 @@ import { ZodError } from "zod";
 import { generateErrorMessage } from "zod-error";
 import { buildFakeSse } from "../../shared/streaming";
 import { assertNever } from "../../shared/utils";
-import { QuotaExceededError } from "./request/apply-quota-limits";
+import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
 
 const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
 const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";

@@ -2,29 +2,30 @@ import type { Request } from "express";
 import type { ClientRequest } from "http";
 import type { ProxyReqCallback } from "http-proxy";
 
-export { createOnProxyReqHandler } from "./rewrite";
+export { createOnProxyReqHandler } from "./onproxyreq-factory";
 export {
   createPreprocessorMiddleware,
   createEmbeddingsPreprocessorMiddleware,
-} from "./preprocess";
+} from "./preprocessor-factory";
 
 // Express middleware (runs before http-proxy-middleware, can be async)
-export { applyQuotaLimits } from "./apply-quota-limits";
-export { validateContextSize } from "./validate-context-size";
-export { countPromptTokens } from "./count-prompt-tokens";
-export { languageFilter } from "./language-filter";
-export { setApiFormat } from "./set-api-format";
-export { signAwsRequest } from "./sign-aws-request";
-export { transformOutboundPayload } from "./transform-outbound-payload";
+export { addAzureKey } from "./preprocessors/add-azure-key";
+export { applyQuotaLimits } from "./preprocessors/apply-quota-limits";
+export { validateContextSize } from "./preprocessors/validate-context-size";
+export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
+export { languageFilter } from "./preprocessors/language-filter";
+export { setApiFormat } from "./preprocessors/set-api-format";
+export { signAwsRequest } from "./preprocessors/sign-aws-request";
+export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
 
-// HPM middleware (runs on onProxyReq, cannot be async)
-export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
-export { addAnthropicPreamble } from "./add-anthropic-preamble";
-export { blockZoomerOrigins } from "./block-zoomer-origins";
-export { finalizeBody } from "./finalize-body";
-export { finalizeSignedRequest } from "./finalize-signed-request";
-export { limitCompletions } from "./limit-completions";
-export { stripHeaders } from "./strip-headers";
+// http-proxy-middleware callbacks (runs on onProxyReq, cannot be async)
+export { addKey, addKeyForEmbeddingsRequest } from "./onproxyreq/add-key";
+export { addAnthropicPreamble } from "./onproxyreq/add-anthropic-preamble";
+export { blockZoomerOrigins } from "./onproxyreq/block-zoomer-origins";
+export { checkModelFamily } from "./onproxyreq/check-model-family";
+export { finalizeBody } from "./onproxyreq/finalize-body";
+export { finalizeSignedRequest } from "./onproxyreq/finalize-signed-request";
+export { stripHeaders } from "./onproxyreq/strip-headers";
 
 /**
  * Middleware that runs prior to the request being handled by http-proxy-
@@ -43,7 +44,7 @@ export { stripHeaders } from "./strip-headers";
 export type RequestPreprocessor = (req: Request) => void | Promise<void>;
 
 /**
- * Middleware that runs immediately before the request is sent to the API in
+ * Callbacks that run immediately before the request is sent to the API in
  * response to http-proxy-middleware's `proxyReq` event.
  *
  * Async functions cannot be used here as HPM's event emitter is not async and
@@ -53,7 +54,7 @@ export type RequestPreprocessor = (req: Request) => void | Promise<void>;
  * first attempt is rate limited and the request is automatically retried by the
  * request queue middleware.
  */
-export type ProxyRequestMiddleware = ProxyReqCallback<ClientRequest, Request>;
+export type HPMRequestCallback = ProxyReqCallback<ClientRequest, Request>;
 
 export const forceModel = (model: string) => (req: Request) =>
   void (req.body.model = model);

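The renamed types make the two lifecycle phases explicit: RequestPreprocessor is Express middleware that may be async and runs once per request, while HPMRequestCallback runs inside http-proxy-middleware's synchronous proxyReq event on every attempt. The two signatures, as declared in this file:

import type { Request } from "express";
import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";

// Runs once, before the request is queued; may await (tokenization, signing, etc.)
export type RequestPreprocessor = (req: Request) => void | Promise<void>;

// Runs on every proxy attempt, including queue retries; must stay synchronous
// because HPM's event emitter does not await its listeners.
export type HPMRequestCallback = ProxyReqCallback<ClientRequest, Request>;
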
@@ -1,16 +0,0 @@
-import { isTextGenerationRequest } from "../common";
-import { ProxyRequestMiddleware } from ".";
-
-/**
- * Don't allow multiple text completions to be requested to prevent abuse.
- * OpenAI-only, Anthropic provides no such parameter.
- **/
-export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => {
-  if (isTextGenerationRequest(req) && req.outboundApi === "openai") {
-    const originalN = req.body?.n || 1;
-    req.body.n = 1;
-    if (originalN !== req.body.n) {
-      req.log.warn(`Limiting completion choices from ${originalN} to 1`);
-    }
-  }
-};

@@ -0,0 +1,43 @@
+import {
+  applyQuotaLimits,
+  blockZoomerOrigins,
+  checkModelFamily,
+  HPMRequestCallback,
+  stripHeaders,
+} from "./index";
+
+type ProxyReqHandlerFactoryOptions = { pipeline: HPMRequestCallback[] };
+
+/**
+ * Returns an http-proxy-middleware request handler that runs the given set of
+ * onProxyReq callback functions in sequence.
+ *
+ * These will run each time a request is proxied, including on automatic retries
+ * by the queue after encountering a rate limit.
+ */
+export const createOnProxyReqHandler = ({
+  pipeline,
+}: ProxyReqHandlerFactoryOptions): HPMRequestCallback => {
+  const callbackPipeline = [
+    checkModelFamily,
+    applyQuotaLimits,
+    blockZoomerOrigins,
+    stripHeaders,
+    ...pipeline,
+  ];
+  return (proxyReq, req, res, options) => {
+    // The streaming flag must be set before any other onProxyReq handler runs,
+    // as it may influence the behavior of subsequent handlers.
+    // Image generation requests can't be streamed.
+    req.isStreaming = req.body.stream === true || req.body.stream === "true";
+    req.body.stream = req.isStreaming;
+
+    try {
+      for (const fn of callbackPipeline) {
+        fn(proxyReq, req, res, options);
+      }
+    } catch (error) {
+      proxyReq.destroy(error);
+    }
+  };
+};

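Note the error path above: HPM's proxyReq event is synchronous, so a callback rejects a request by throwing, and the factory converts the throw into proxyReq.destroy(error), which should surface through the proxy's `error` handler (handleProxyError in the routes above). A rough sketch of that control flow, using a bare EventEmitter as a stand-in for the real ClientRequest:

import { EventEmitter } from "events";

// Stand-in for http.ClientRequest: destroying the request emits "error" on it,
// which is the hook the error handler would respond to.
const proxyReq = new EventEmitter();
const destroy = (e: Error) => proxyReq.emit("error", e);

proxyReq.on("error", (e: Error) => console.log("error handler sees:", e.message));

try {
  throw new Error("Model family gpt4-32k is not permitted on this proxy");
} catch (error) {
  destroy(error as Error);
}
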
@@ -1,13 +1,13 @@
-import { AnthropicKey, Key } from "../../../shared/key-management";
-import { isTextGenerationRequest } from "../common";
-import { ProxyRequestMiddleware } from ".";
+import { AnthropicKey, Key } from "../../../../shared/key-management";
+import { isTextGenerationRequest } from "../../common";
+import { HPMRequestCallback } from "../index";
 
 /**
  * Some keys require the prompt to start with `\n\nHuman:`. There is no way to
  * know this without trying to send the request and seeing if it fails. If a
  * key is marked as requiring a preamble, it will be added here.
  */
-export const addAnthropicPreamble: ProxyRequestMiddleware = (
+export const addAnthropicPreamble: HPMRequestCallback = (
   _proxyReq,
   req
 ) => {
@@ -1,10 +1,10 @@
-import { Key, OpenAIKey, keyPool } from "../../../shared/key-management";
-import { isEmbeddingsRequest } from "../common";
-import { ProxyRequestMiddleware } from ".";
-import { assertNever } from "../../../shared/utils";
+import { Key, OpenAIKey, keyPool } from "../../../../shared/key-management";
+import { isEmbeddingsRequest } from "../../common";
+import { HPMRequestCallback } from "../index";
+import { assertNever } from "../../../../shared/utils";
 
 /** Add a key that can service this request to the request object. */
-export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
+export const addKey: HPMRequestCallback = (proxyReq, req) => {
   let assignedKey: Key;
 
   if (!req.inboundApi || !req.outboundApi) {
@@ -97,7 +97,7 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
  * Special case for embeddings requests which don't go through the normal
  * request pipeline.
  */
-export const addKeyForEmbeddingsRequest: ProxyRequestMiddleware = (
+export const addKeyForEmbeddingsRequest: HPMRequestCallback = (
   proxyReq,
   req
 ) => {
@@ -1,4 +1,4 @@
-import { ProxyRequestMiddleware } from ".";
+import { HPMRequestCallback } from "../index";
 
 const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");
 
@@ -13,7 +13,7 @@ class ForbiddenError extends Error {
  * Blocks requests from Janitor AI users with a fake, scary error message so I
  * stop getting emails asking for tech support.
  */
-export const blockZoomerOrigins: ProxyRequestMiddleware = (_proxyReq, req) => {
+export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => {
   const origin = req.headers.origin || req.headers.referer;
   if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) {
     // Venus-derivatives send a test prompt to check if the proxy is working.
@@ -0,0 +1,13 @@
+import { HPMRequestCallback } from "../index";
+import { config } from "../../../../config";
+import { getModelFamilyForRequest } from "../../../../shared/models";
+
+/**
+ * Ensures the selected model family is enabled by the proxy configuration.
+ **/
+export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => {
+  const family = getModelFamilyForRequest(req);
+  if (!config.allowedModelFamilies.includes(family)) {
+    throw new Error(`Model family ${family} is not permitted on this proxy`);
+  }
+};

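checkModelFamily is the piece that actually enforces allowedModelFamilies now: it runs in the shared callback pipeline for every proxied request, rather than only inside the OpenAI key provider (whose duplicate check is deleted further down). A sketch of the behavior with a stand-in config; the real values come from config.ts, and the family names here are an illustrative subset:

type ModelFamily = "turbo" | "gpt4" | "claude" | "aws-claude";
const allowedModelFamilies: ModelFamily[] = ["turbo", "claude"];

function checkModelFamily(family: ModelFamily) {
  if (!allowedModelFamilies.includes(family)) {
    throw new Error(`Model family ${family} is not permitted on this proxy`);
  }
}

checkModelFamily("claude"); // passes
checkModelFamily("gpt4");   // throws -> proxyReq.destroy -> error handler
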
@@ -1,8 +1,8 @@
 import { fixRequestBody } from "http-proxy-middleware";
-import type { ProxyRequestMiddleware } from ".";
+import type { HPMRequestCallback } from "../index";
 
 /** Finalize the rewritten request body. Must be the last rewriter. */
-export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => {
+export const finalizeBody: HPMRequestCallback = (proxyReq, req) => {
   if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
     // For image generation requests, remove stream flag.
     if (req.outboundApi === "openai-image") {
@@ -1,11 +1,11 @@
-import type { ProxyRequestMiddleware } from ".";
+import type { HPMRequestCallback } from "../index";
 
 /**
  * For AWS/Azure requests, the body is signed earlier in the request pipeline,
  * before the proxy middleware. This function just assigns the path and headers
  * to the proxy request.
  */
-export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => {
+export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
   if (!req.signedRequest) {
     throw new Error("Expected req.signedRequest to be set");
   }
@@ -1,10 +1,10 @@
-import { ProxyRequestMiddleware } from ".";
+import { HPMRequestCallback } from "../index";
 
 /**
  * Removes origin and referer headers before sending the request to the API for
  * privacy reasons.
  **/
-export const stripHeaders: ProxyRequestMiddleware = (proxyReq) => {
+export const stripHeaders: HPMRequestCallback = (proxyReq) => {
   proxyReq.setHeader("origin", "");
   proxyReq.setHeader("referer", "");
 
@@ -29,6 +29,14 @@ type RequestPreprocessorOptions = {
 /**
  * Returns a middleware function that processes the request body into the given
  * API format, and then sequentially runs the given additional preprocessors.
+ *
+ * These run first in the request lifecycle, a single time per request before it
+ * is added to the request queue. They aren't run again if the request is
+ * re-attempted after a rate limit.
+ *
+ * To run a preprocessor on every re-attempt, pass it to createQueueMiddleware.
+ * It will run after these preprocessors, but before the request is sent to
+ * http-proxy-middleware.
  */
 export const createPreprocessorMiddleware = (
   apiFormat: Parameters<typeof setApiFormat>[0],

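The expanded comment pins down the lifecycle: preprocessors run once, before the request enters the queue; onProxyReq callbacks run on every attempt, including retries after a rate limit. Schematically (hypothetical helper names, not the project's API):

type Req = { attempts: number };
const runPreprocessors = async (_req: Req) => {};   // once per request
const runOnProxyReqCallbacks = (_req: Req) => {};   // every attempt, synchronous
const sendUpstream = async (_req: Req) => "ok";
const isRateLimited = (e: unknown) => e instanceof Error && e.message === "429";

async function handle(req: Req): Promise<string> {
  await runPreprocessors(req);          // happens exactly once
  for (;;) {
    try {
      runOnProxyReqCallbacks(req);      // happens on each (re)attempt
      return await sendUpstream(req);
    } catch (e) {
      if (!isRateLimited(e)) throw e;
      req.attempts++;                   // queue retries rate-limited attempts
    }
  }
}
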
@@ -1,5 +1,5 @@
-import { AzureOpenAIKey, keyPool } from "../../../shared/key-management";
-import { RequestPreprocessor } from ".";
+import { AzureOpenAIKey, keyPool } from "../../../../shared/key-management";
+import { RequestPreprocessor } from "../index";
 
 export const addAzureKey: RequestPreprocessor = (req) => {
   const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
@@ -1,6 +1,6 @@
-import { hasAvailableQuota } from "../../../shared/users/user-store";
-import { isImageGenerationRequest, isTextGenerationRequest } from "../common";
-import { ProxyRequestMiddleware } from ".";
+import { hasAvailableQuota } from "../../../../shared/users/user-store";
+import { isImageGenerationRequest, isTextGenerationRequest } from "../../common";
+import { HPMRequestCallback } from "../index";
 
 export class QuotaExceededError extends Error {
   public quotaInfo: any;
@@ -11,7 +11,7 @@ export class QuotaExceededError extends Error {
   }
 }
 
-export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
+export const applyQuotaLimits: HPMRequestCallback = (_proxyReq, req) => {
   const subjectToQuota =
     isTextGenerationRequest(req) || isImageGenerationRequest(req);
   if (!subjectToQuota || !req.user) return;
@@ -1,6 +1,6 @@
-import { RequestPreprocessor } from "./index";
-import { countTokens } from "../../../shared/tokenization";
-import { assertNever } from "../../../shared/utils";
+import { RequestPreprocessor } from "../index";
+import { countTokens } from "../../../../shared/tokenization";
+import { assertNever } from "../../../../shared/utils";
 import type { OpenAIChatMessage } from "./transform-outbound-payload";
 
 /**
@@ -1,8 +1,8 @@
 import { Request } from "express";
-import { config } from "../../../config";
-import { assertNever } from "../../../shared/utils";
-import { RequestPreprocessor } from ".";
-import { UserInputError } from "../../../shared/errors";
+import { config } from "../../../../config";
+import { assertNever } from "../../../../shared/utils";
+import { RequestPreprocessor } from "../index";
+import { UserInputError } from "../../../../shared/errors";
 import { OpenAIChatMessage } from "./transform-outbound-payload";
 
 const rejectedClients = new Map<string, number>();
@@ -1,6 +1,6 @@
 import { Request } from "express";
-import { APIFormat, LLMService } from "../../../shared/key-management";
-import { RequestPreprocessor } from ".";
+import { APIFormat, LLMService } from "../../../../shared/key-management";
+import { RequestPreprocessor } from "../index";
 
 export const setApiFormat = (api: {
   inApi: Request["inboundApi"];
@@ -2,8 +2,8 @@ import express from "express";
 import { Sha256 } from "@aws-crypto/sha256-js";
 import { SignatureV4 } from "@smithy/signature-v4";
 import { HttpRequest } from "@smithy/protocol-http";
-import { keyPool } from "../../../shared/key-management";
-import { RequestPreprocessor } from ".";
+import { keyPool } from "../../../../shared/key-management";
+import { RequestPreprocessor } from "../index";
 import { AnthropicV1CompleteSchema } from "./transform-outbound-payload";
 
 const AMZ_HOST =
@@ -1,9 +1,9 @@
 import { Request } from "express";
 import { z } from "zod";
-import { config } from "../../../config";
-import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
-import { RequestPreprocessor } from ".";
-import { APIFormat } from "../../../shared/key-management";
+import { config } from "../../../../config";
+import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
+import { RequestPreprocessor } from "../index";
+import { APIFormat } from "../../../../shared/key-management";
 
 const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
 const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
@@ -1,8 +1,8 @@
 import { Request } from "express";
 import { z } from "zod";
-import { config } from "../../../config";
-import { assertNever } from "../../../shared/utils";
-import { RequestPreprocessor } from ".";
+import { config } from "../../../../config";
+import { assertNever } from "../../../../shared/utils";
+import { RequestPreprocessor } from "../index";
 
 const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
 const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
@@ -1,42 +0,0 @@
-import { Request } from "express";
-import { ClientRequest } from "http";
-import httpProxy from "http-proxy";
-import { ProxyRequestMiddleware } from "./index";
-
-type ProxyReqCallback = httpProxy.ProxyReqCallback<ClientRequest, Request>;
-type RewriterOptions = {
-  beforeRewrite?: ProxyReqCallback[];
-  pipeline: ProxyRequestMiddleware[];
-};
-
-export const createOnProxyReqHandler = ({
-  beforeRewrite = [],
-  pipeline,
-}: RewriterOptions): ProxyReqCallback => {
-  return (proxyReq, req, res, options) => {
-    // The streaming flag must be set before any other middleware runs, because
-    // it may influence which other middleware a particular API pipeline wants
-    // to run.
-    // Image generation requests can't be streamed.
-    req.isStreaming = req.body.stream === true || req.body.stream === "true";
-    req.body.stream = req.isStreaming;
-
-    try {
-      for (const validator of beforeRewrite) {
-        validator(proxyReq, req, res, options);
-      }
-    } catch (error) {
-      req.log.error(error, "Error while executing proxy request validator");
-      proxyReq.destroy(error);
-    }
-
-    try {
-      for (const rewriter of pipeline) {
-        rewriter(proxyReq, req, res, options);
-      }
-    } catch (error) {
-      req.log.error(error, "Error while executing proxy request rewriter");
-      proxyReq.destroy(error);
-    }
-  };
-};

@@ -9,7 +9,7 @@ import {
 } from "../common";
 import { ProxyResHandlerWithBody } from ".";
 import { assertNever } from "../../../shared/utils";
-import { OpenAIChatMessage } from "../request/transform-outbound-payload";
+import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
 
 /** If prompt logging is enabled, enqueues the prompt for logging. */
 export const logPrompt: ProxyResHandlerWithBody = async (

@@ -7,11 +7,8 @@ import { ipLimiter } from "./rate-limit";
 import { handleProxyError } from "./middleware/common";
 import {
   addKey,
-  applyQuotaLimits,
-  blockZoomerOrigins,
   createPreprocessorMiddleware,
   finalizeBody,
-  stripHeaders,
   createOnProxyReqHandler,
 } from "./middleware/request";
 import {
@@ -113,15 +110,7 @@ const openaiImagesProxy = createQueueMiddleware({
     "^/v1/chat/completions": "/v1/images/generations",
   },
   on: {
-    proxyReq: createOnProxyReqHandler({
-      pipeline: [
-        applyQuotaLimits,
-        addKey,
-        blockZoomerOrigins,
-        stripHeaders,
-        finalizeBody,
-      ],
-    }),
+    proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }),
     proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]),
     error: handleProxyError,
   },

@@ -2,7 +2,11 @@ import { RequestHandler, Router } from "express";
 import { createProxyMiddleware } from "http-proxy-middleware";
 import { config } from "../config";
 import { keyPool } from "../shared/key-management";
-import { getOpenAIModelFamily, ModelFamily, OpenAIModelFamily } from "../shared/models";
+import {
+  getOpenAIModelFamily,
+  ModelFamily,
+  OpenAIModelFamily,
+} from "../shared/models";
 import { logger } from "../logger";
 import { createQueueMiddleware } from "./queue";
 import { ipLimiter } from "./rate-limit";
@@ -10,18 +14,17 @@ import { handleProxyError } from "./middleware/common";
 import {
   addKey,
   addKeyForEmbeddingsRequest,
-  applyQuotaLimits,
-  blockZoomerOrigins,
   createEmbeddingsPreprocessorMiddleware,
   createOnProxyReqHandler,
   createPreprocessorMiddleware,
   finalizeBody,
   forceModel,
-  limitCompletions,
   RequestPreprocessor,
-  stripHeaders,
 } from "./middleware/request";
-import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";
+import {
+  createOnProxyResHandler,
+  ProxyResHandlerWithBody,
+} from "./middleware/response";
 
 // https://platform.openai.com/docs/models/overview
 export const KNOWN_OPENAI_MODELS = [
@@ -159,14 +162,7 @@ const openaiProxy = createQueueMiddleware({
     logger,
     on: {
       proxyReq: createOnProxyReqHandler({
-        pipeline: [
-          applyQuotaLimits,
-          addKey,
-          limitCompletions,
-          blockZoomerOrigins,
-          stripHeaders,
-          finalizeBody,
-        ],
+        pipeline: [addKey, finalizeBody],
       }),
       proxyRes: createOnProxyResHandler([openaiResponseHandler]),
       error: handleProxyError,
@@ -181,7 +177,7 @@ const openaiEmbeddingsProxy = createProxyMiddleware({
   logger,
   on: {
     proxyReq: createOnProxyReqHandler({
-      pipeline: [addKeyForEmbeddingsRequest, stripHeaders, finalizeBody],
+      pipeline: [addKeyForEmbeddingsRequest, finalizeBody],
     }),
     error: handleProxyError,
   },

@@ -9,13 +9,10 @@ import { ipLimiter } from "./rate-limit";
 import { handleProxyError } from "./middleware/common";
 import {
   addKey,
-  applyQuotaLimits,
-  blockZoomerOrigins,
   createOnProxyReqHandler,
   createPreprocessorMiddleware,
   finalizeBody,
   forceModel,
-  stripHeaders,
 } from "./middleware/request";
 import {
   createOnProxyResHandler,
@@ -149,14 +146,7 @@ const googlePalmProxy = createQueueMiddleware({
     logger,
     on: {
       proxyReq: createOnProxyReqHandler({
-        beforeRewrite: [reassignPathForPalmModel],
-        pipeline: [
-          applyQuotaLimits,
-          addKey,
-          blockZoomerOrigins,
-          stripHeaders,
-          finalizeBody,
-        ],
+        pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
       }),
       proxyRes: createOnProxyResHandler([palmResponseHandler]),
       error: handleProxyError,

@@ -14,17 +14,8 @@
 import crypto from "crypto";
 import type { Handler, Request } from "express";
 import { keyPool } from "../shared/key-management";
-import {
-  getAwsBedrockModelFamily,
-  getAzureOpenAIModelFamily,
-  getClaudeModelFamily,
-  getGooglePalmModelFamily,
-  getOpenAIModelFamily,
-  MODEL_FAMILIES,
-  ModelFamily,
-} from "../shared/models";
+import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
 import { buildFakeSse, initializeSseStream } from "../shared/streaming";
-import { assertNever } from "../shared/utils";
 import { logger } from "../logger";
 import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
 import { RequestPreprocessor } from "./middleware/request";
@@ -132,34 +123,9 @@ export function enqueue(req: Request) {
   }
 }
 
-function getPartitionForRequest(req: Request): ModelFamily {
-  // There is a single request queue, but it is partitioned by model family.
-  // Model families are typically separated on cost/rate limit boundaries so
-  // they should be treated as separate queues.
-  const model = req.body.model ?? "gpt-3.5-turbo";
-
-  // Weird special case for AWS/Azure because they serve multiple models from
-  // different vendors, even if currently only one is supported.
-  if (req.service === "aws") return getAwsBedrockModelFamily(model);
-  if (req.service === "azure") return getAzureOpenAIModelFamily(model);
-
-  switch (req.outboundApi) {
-    case "anthropic":
-      return getClaudeModelFamily(model);
-    case "openai":
-    case "openai-text":
-    case "openai-image":
-      return getOpenAIModelFamily(model);
-    case "google-palm":
-      return getGooglePalmModelFamily(model);
-    default:
-      assertNever(req.outboundApi);
-  }
-}
-
 function getQueueForPartition(partition: ModelFamily): Request[] {
   return queue
-    .filter((req) => getPartitionForRequest(req) === partition)
+    .filter((req) => getModelFamilyForRequest(req) === partition)
     .sort((a, b) => {
       // Certain requests are exempted from IP-based rate limiting because they
       // come from a shared IP address. To prevent these requests from starving
@@ -222,7 +188,7 @@ function processQueue() {
 
   reqs.filter(Boolean).forEach((req) => {
     if (req?.proceed) {
-      const modelFamily = getPartitionForRequest(req!);
+      const modelFamily = getModelFamilyForRequest(req!);
       req.log.info({
         retries: req.retryCount,
         partition: modelFamily,
@@ -279,7 +245,7 @@ let waitTimes: {
 /** Adds a successful request to the list of wait times. */
 export function trackWaitTime(req: Request) {
   waitTimes.push({
-    partition: getPartitionForRequest(req),
+    partition: getModelFamilyForRequest(req),
     start: req.startTime!,
     end: req.queueOutTime ?? Date.now(),
     isDeprioritized: isFromSharedIp(req),
@@ -324,7 +290,7 @@ function calculateWaitTime(partition: ModelFamily) {
 
   const currentWaits = queue
     .filter((req) => {
-      const isSamePartition = getPartitionForRequest(req) === partition;
+      const isSamePartition = getModelFamilyForRequest(req) === partition;
       const isNormalPriority = !isFromSharedIp(req);
       return isSamePartition && isNormalPriority;
     })

@@ -170,12 +170,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
       throw new Error(`No keys available for model family '${neededFamily}'.`);
     }
 
-    if (!config.allowedModelFamilies.includes(neededFamily)) {
-      throw new Error(
-        `Proxy operator has disabled model family '${neededFamily}'.`
-      );
-    }
-
     // Select a key, from highest priority to lowest priority:
     // 1. Keys which are not rate limited
     //    a. We ignore rate limits from >30 seconds ago

@@ -1,6 +1,8 @@
 // Don't import anything here, this is imported by config.ts
 
 import pino from "pino";
+import type { Request } from "express";
+import { assertNever } from "./utils";
 
 export type OpenAIModelFamily =
   | "turbo"
@@ -103,3 +105,38 @@ export function assertIsKnownModelFamily(
     throw new Error(`Unknown model family: ${modelFamily}`);
   }
 }
+
+export function getModelFamilyForRequest(req: Request): ModelFamily {
+  if (req.modelFamily) return req.modelFamily;
+  // There is a single request queue, but it is partitioned by model family.
+  // Model families are typically separated on cost/rate limit boundaries so
+  // they should be treated as separate queues.
+  const model = req.body.model ?? "gpt-3.5-turbo";
+  let modelFamily: ModelFamily;
+
+  // Weird special case for AWS/Azure because they serve multiple models from
+  // different vendors, even if currently only one is supported.
+  if (req.service === "aws") {
+    modelFamily = getAwsBedrockModelFamily(model);
+  } else if (req.service === "azure") {
+    modelFamily = getAzureOpenAIModelFamily(model);
+  } else {
+    switch (req.outboundApi) {
+      case "anthropic":
+        modelFamily = getClaudeModelFamily(model);
+        break;
+      case "openai":
+      case "openai-text":
+      case "openai-image":
+        modelFamily = getOpenAIModelFamily(model);
+        break;
+      case "google-palm":
+        modelFamily = getGooglePalmModelFamily(model);
+        break;
+      default:
+        assertNever(req.outboundApi);
+    }
+  }
+
+  return (req.modelFamily = modelFamily);
+}

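getModelFamilyForRequest memoizes its result on the new req.modelFamily field (declared in the type augmentation at the end of this diff), so the queue's repeated partition lookups and checkModelFamily resolve the family once per request. The pattern in isolation, with a stand-in mapping:

type Req = { body: { model?: string }; modelFamily?: string };

function getModelFamilyForRequest(req: Req): string {
  if (req.modelFamily) return req.modelFamily;
  const model = req.body.model ?? "gpt-3.5-turbo";
  const family = model.includes("claude") ? "claude" : "turbo"; // stand-in mapping
  return (req.modelFamily = family);
}

const req: Req = { body: { model: "claude-2" } };
getModelFamilyForRequest(req); // computes "claude" and caches it on req
getModelFamilyForRequest(req); // returns the cached value
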
@@ -2,7 +2,7 @@ import { Tiktoken } from "tiktoken/lite";
 import cl100k_base from "tiktoken/encoders/cl100k_base.json";
 import { logger } from "../../logger";
 import { libSharp } from "../file-storage";
-import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
+import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
 
 const log = logger.child({ module: "tokenizer", service: "openai" });
 const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;

@@ -1,5 +1,5 @@
 import { Request } from "express";
-import type { OpenAIChatMessage } from "../../proxy/middleware/request/transform-outbound-payload";
+import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
 import { assertNever } from "../utils";
 import {
   init as initClaude,

@@ -2,6 +2,7 @@ import type { HttpRequest } from "@smithy/types";
 import { Express } from "express-serve-static-core";
 import { APIFormat, Key, LLMService } from "../shared/key-management";
 import { User } from "../shared/users/schema";
+import { ModelFamily } from "../shared/models";
 
 declare global {
   namespace Express {
@@ -27,6 +28,7 @@ declare global {
       outputTokens?: number;
       tokenizerInfo: Record<string, any>;
       signedRequest: HttpRequest;
+      modelFamily?: ModelFamily;
     }
   }
 }