From 96fe974ad0d37c6c83379f6d8ebd531883098de9 Mon Sep 17 00:00:00 2001 From: khanon Date: Sun, 1 Sep 2024 22:55:07 +0000 Subject: [PATCH] Use AWS Inference Profiles for higher rate limits (khanon/oai-reverse-proxy!78) --- package-lock.json | 90 +++++++++++- package.json | 2 +- .../request/preprocessors/sign-aws-request.ts | 20 ++- src/shared/key-management/aws/checker.ts | 132 +++++++++++++----- src/shared/key-management/aws/provider.ts | 18 ++- src/shared/key-management/prioritize-keys.ts | 29 +++- 6 files changed, 236 insertions(+), 55 deletions(-) diff --git a/package-lock.json b/package-lock.json index a7fea0c..f3d331f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,7 +17,6 @@ "@smithy/eventstream-serde-node": "^2.1.3", "@smithy/protocol-http": "^3.2.1", "@smithy/signature-v4": "^2.1.3", - "@smithy/types": "^2.10.1", "@smithy/util-utf8": "^2.1.1", "axios": "^1.7.4", "better-sqlite3": "^10.0.0", @@ -52,6 +51,7 @@ "zod-error": "^1.5.0" }, "devDependencies": { + "@smithy/types": "^3.3.0", "@types/better-sqlite3": "^7.6.10", "@types/cookie-parser": "^1.4.3", "@types/cors": "^2.8.13", @@ -152,6 +152,17 @@ "node": ">=14.0.0" } }, + "node_modules/@aws-sdk/types/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@aws-sdk/util-utf8-browser": { "version": "3.259.0", "resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz", @@ -1328,6 +1339,17 @@ "tslib": "^2.5.0" } }, + "node_modules/@smithy/eventstream-codec/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@smithy/eventstream-serde-node": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.3.tgz", @@ -1341,6 +1363,17 @@ "node": ">=14.0.0" } }, + "node_modules/@smithy/eventstream-serde-node/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@smithy/eventstream-serde-universal": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.3.tgz", @@ -1354,6 +1387,17 @@ "node": ">=14.0.0" } }, + "node_modules/@smithy/eventstream-serde-universal/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@smithy/is-array-buffer": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.1.1.tgz", @@ -1377,6 +1421,17 @@ "node": ">=14.0.0" } }, + "node_modules/@smithy/protocol-http/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@smithy/signature-v4": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.1.3.tgz", @@ -1395,17 +1450,29 @@ "node": ">=14.0.0" } }, - "node_modules/@smithy/types": { - "version": "2.10.1", - "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.10.1.tgz", - "integrity": "sha512-hjQO+4ru4cQ58FluQvKKiyMsFg0A6iRpGm2kqdH8fniyNd2WyanoOsYJfMX/IFLuLxEoW6gnRkNZy1y6fUUhtA==", + "node_modules/@smithy/signature-v4/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", "dependencies": { - "tslib": "^2.5.0" + "tslib": "^2.6.2" }, "engines": { "node": ">=14.0.0" } }, + "node_modules/@smithy/types": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-3.3.0.tgz", + "integrity": "sha512-IxvBBCTFDHbVoK7zIxqA1ZOdc4QfM5HM7rGleCuHi7L1wnKv5Pn69xXJQ9hgxH60ZVygH9/JG0jRgtUncE3QUA==", + "dev": true, + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=16.0.0" + } + }, "node_modules/@smithy/util-buffer-from": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.1.1.tgz", @@ -1441,6 +1508,17 @@ "node": ">=14.0.0" } }, + "node_modules/@smithy/util-middleware/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@smithy/util-uri-escape": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.1.1.tgz", diff --git a/package.json b/package.json index 852824e..01abfab 100644 --- a/package.json +++ b/package.json @@ -26,7 +26,6 @@ "@smithy/eventstream-serde-node": "^2.1.3", "@smithy/protocol-http": "^3.2.1", "@smithy/signature-v4": "^2.1.3", - "@smithy/types": "^2.10.1", "@smithy/util-utf8": "^2.1.1", "axios": "^1.7.4", "better-sqlite3": "^10.0.0", @@ -61,6 +60,7 @@ "zod-error": "^1.5.0" }, "devDependencies": { + "@smithy/types": "^3.3.0", "@types/better-sqlite3": "^7.6.10", "@types/cookie-parser": "^1.4.3", "@types/cors": "^2.8.13", diff --git a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts index d27c058..5a937bc 100644 --- a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts +++ b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts @@ -6,7 +6,7 @@ import { AnthropicV1TextSchema, AnthropicV1MessagesSchema, } from "../../../../shared/api-schemas"; -import { keyPool } from "../../../../shared/key-management"; +import { AwsBedrockKey, keyPool } from "../../../../shared/key-management"; import { RequestPreprocessor } from "../index"; import { AWSMistralV1ChatCompletionsSchema, @@ -40,13 +40,21 @@ export const signAwsRequest: RequestPreprocessor = async (req) => { // set it so that the stream adapter always selects the correct transformer. req.headers["anthropic-version"] = "2023-06-01"; + // If our key has an inference profile compatible with the requested model, + // we want to use the inference profile instead of the model ID when calling + // InvokeModel as that will give us higher rate limits. + const profile = + (req.key as AwsBedrockKey).inferenceProfileIds.find((p) => + p.includes(model) + ) || model; + // Uses the AWS SDK to sign a request, then modifies our HPM proxy request // with the headers generated by the SDK. const newRequest = new HttpRequest({ method: "POST", protocol: "https:", hostname: host, - path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`, + path: `/model/${profile}/invoke${stream ? "-with-response-stream" : ""}`, headers: { ["Host"]: host, ["content-type"]: "application/json", @@ -62,7 +70,13 @@ export const signAwsRequest: RequestPreprocessor = async (req) => { const { key, body, inboundApi, outboundApi } = req; req.log.info( - { key: key.hash, model: body.model, inboundApi, outboundApi }, + { + key: key.hash, + model: body.model, + inferenceProfile: profile, + inboundApi, + outboundApi, + }, "Assigned AWS credentials to request" ); diff --git a/src/shared/key-management/aws/checker.ts b/src/shared/key-management/aws/checker.ts index b742d84..09e209a 100644 --- a/src/shared/key-management/aws/checker.ts +++ b/src/shared/key-management/aws/checker.ts @@ -1,12 +1,12 @@ import { Sha256 } from "@aws-crypto/sha256-js"; import { SignatureV4 } from "@smithy/signature-v4"; import { HttpRequest } from "@smithy/protocol-http"; -import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios"; +import axios, { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios"; import { URL } from "url"; +import { config } from "../../../config"; +import { getAwsBedrockModelFamily } from "../../models"; import { KeyCheckerBase } from "../key-checker-base"; import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider"; -import { getAwsBedrockModelFamily } from "../../models"; -import { config } from "../../../config"; type ParentModelId = string; type AliasModelId = string; @@ -24,6 +24,7 @@ const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [ ["mistral.mistral-large-2407-v1:0"], ["mistral.mistral-small-2402-v1:0"], // Seems to return 400 ]; + const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes const AMZ_HOST = @@ -31,6 +32,8 @@ const AMZ_HOST = const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`; const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) => `https://bedrock.${region}.amazonaws.com/logging/modelinvocations`; +const GET_LIST_INFERENCE_PROFILES_URL = (region: string) => + `https://bedrock.${region}.amazonaws.com/inference-profiles?maxResults=1000`; const POST_INVOKE_MODEL_URL = (region: string, model: string) => `https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`; const TEST_MESSAGES = [ @@ -40,6 +43,22 @@ const TEST_MESSAGES = [ type AwsError = { error: {} }; +type GetInferenceProfilesResponse = { + inferenceProfileSummaries: { + inferenceProfileId: string; + inferenceProfileName: string; + inferenceProfileArn: string; + description?: string; + createdAt?: string; + updatedAt?: string; + status: "ACTIVE" | unknown; + type: "SYSTEM_DEFINED" | unknown; + models: { + modelArn?: string; + }[]; + }[]; +}; + type GetLoggingConfigResponse = { loggingConfig: null | { cloudWatchConfig: null | unknown; @@ -66,38 +85,52 @@ export class AwsKeyChecker extends KeyCheckerBase { const isInitialCheck = !key.lastChecked; if (isInitialCheck) { - // Perform checks for all parent model IDs - const results = await Promise.all( - KNOWN_MODEL_IDS.filter(([model]) => - // Skip checks for models that are disabled anyway - config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model)) - ).map(async ([model, ...aliases]) => ({ - models: [model, ...aliases], - success: await this.invokeModel(model, key), - })) - ); - - // Filter out models that are disabled - const modelIds = results - .filter(({ success }) => success) - .flatMap(({ models }) => models); - - if (modelIds.length === 0) { + try { + await this.checkInferenceProfiles(key); + } catch (e) { + const asError = e as AxiosError; + const data = asError.response?.data; this.log.warn( - { key: key.hash }, - "Key does not have access to any models; disabling." + { key: key.hash, error: e.message, data }, + "Cannot list inference profiles.\n\ +Principal may be missing `AmazonBedrockFullAccess`, or has no policy allowing action `bedrock:ListInferenceProfiles` against resource `arn:aws:bedrock:*:*:inference-profile/*`.\n\ +Requests will be made without inference profiles using on-demand quotas, which may be subject to more restrictive rate limits.\n\ +See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-prereq.html." ); - return this.updateKey(key.hash, { isDisabled: true }); } - - this.updateKey(key.hash, { - modelIds, - modelFamilies: Array.from( - new Set(modelIds.map(getAwsBedrockModelFamily)) - ), - }); } + // Perform checks for all parent model IDs + const results = await Promise.all( + KNOWN_MODEL_IDS.filter(([model]) => + // Skip checks for models that are disabled anyway + config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model)) + ).map(async ([model, ...aliases]) => ({ + models: [model, ...aliases], + success: await this.invokeModel(model, key), + })) + ); + + // Filter out models that are disabled + const modelIds = results + .filter(({ success }) => success) + .flatMap(({ models }) => models); + + if (modelIds.length === 0) { + this.log.warn( + { key: key.hash }, + "Key does not have access to any models; disabling." + ); + return this.updateKey(key.hash, { isDisabled: true }); + } + + this.updateKey(key.hash, { + modelIds, + modelFamilies: Array.from( + new Set(modelIds.map(getAwsBedrockModelFamily)) + ), + }); + this.log.info( { key: key.hash, @@ -222,6 +255,10 @@ export class AwsKeyChecker extends KeyCheckerBase { status === 403 && errorMessage?.match(/access to the model with the specified model ID/) ) { + this.log.debug( + { key: key.hash, model, errorType, data, status, headers }, + "Model is not available (principal does not have access)." + ); return false; } @@ -230,7 +267,7 @@ export class AwsKeyChecker extends KeyCheckerBase { if (status === 404) { this.log.debug( { region: creds.region, model, key: key.hash }, - "Model not supported in this AWS region." + "Model is not available (not supported in this AWS region)." ); return false; } @@ -242,14 +279,14 @@ export class AwsKeyChecker extends KeyCheckerBase { if (!correctErrorType || !correctErrorMessage) { this.log.debug( { key: key.hash, model, errorType, data, status }, - "AWS InvokeModel test unsuccessful." + "Model is not available (request rejected)." ); return false; } this.log.debug( { key: key.hash, model, errorType, data, status }, - "AWS InvokeModel test successful." + "Model is available." ); return true; } @@ -283,7 +320,7 @@ export class AwsKeyChecker extends KeyCheckerBase { if (status === 403 || status === 404) { this.log.debug( { key: key.hash, model, errorType, data, status }, - "AWS InvokeModel test returned 403 or 404." + "Model is not available (no access or unsupported region)." ); return false; } @@ -293,18 +330,38 @@ export class AwsKeyChecker extends KeyCheckerBase { if (isBadRequest && !isValidationError) { this.log.debug( { key: key.hash, model, errorType, data, status, headers }, - "AWS InvokeModel test returned 400 but not a validation error." + "Model is not available (request rejected)." ); return false; } this.log.debug( { key: key.hash, model, errorType, data, status }, - "AWS InvokeModel test successful." + "Model is available." ); return true; } + private async checkInferenceProfiles(key: AwsBedrockKey) { + const creds = AwsKeyChecker.getCredentialsFromKey(key); + const req: AxiosRequestConfig = { + method: "GET", + url: GET_LIST_INFERENCE_PROFILES_URL(creds.region), + headers: { accept: "application/json" }, + }; + await AwsKeyChecker.signRequestForAws(req, key); + const { data } = await axios.request(req); + const { inferenceProfileSummaries } = data; + const profileIds = inferenceProfileSummaries.map( + (p) => p.inferenceProfileId + ); + this.log.debug( + { key: key.hash, profileIds, region: creds.region }, + "Inference profiles found." + ); + this.updateKey(key.hash, { inferenceProfileIds: profileIds }); + } + private async checkLoggingConfiguration(key: AwsBedrockKey) { if (config.allowAwsLogging) { // Don't check logging status if we're allowing it to reduce API calls. @@ -373,7 +430,8 @@ export class AwsKeyChecker extends KeyCheckerBase { method, protocol: "https:", hostname: url.hostname, - path: url.pathname + url.search, + path: url.pathname, + query: Object.fromEntries(url.searchParams), headers: { Host: url.hostname, ...plainHeaders }, }); diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index 76d10b5..c6f0c6a 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -22,6 +22,7 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage { */ awsLoggingStatus: "unknown" | "disabled" | "enabled"; modelIds: string[]; + inferenceProfileIds: string[]; } /** @@ -72,6 +73,7 @@ export class AwsBedrockKeyProvider implements KeyProvider { .slice(0, 8)}`, lastChecked: 0, modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"], + inferenceProfileIds: [], ["aws-claudeTokens"]: 0, ["aws-claude-opusTokens"]: 0, ["aws-mistral-tinyTokens"]: 0, @@ -135,7 +137,21 @@ export class AwsBedrockKeyProvider implements KeyProvider { ); } - const selectedKey = prioritizeKeys(availableKeys)[0]; + /** + * Comparator for prioritizing keys on inference profile compatibility. + * Requests made via inference profiles have higher rate limits so we want + * to use keys with compatible inference profiles first. + */ + const hasInferenceProfile = ( + a: AwsBedrockKey, + b: AwsBedrockKey + ) => { + const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model)); + const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model)); + return aMatch - bMatch; + }; + + const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0]; selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; diff --git a/src/shared/key-management/prioritize-keys.ts b/src/shared/key-management/prioritize-keys.ts index cf52995..60729ae 100644 --- a/src/shared/key-management/prioritize-keys.ts +++ b/src/shared/key-management/prioritize-keys.ts @@ -1,12 +1,22 @@ import { Key } from "./index"; -export function prioritizeKeys(keys: T[]) { - // Sorts keys from highest priority to lowest priority, where priority is: - // 1. Keys which are not rate limited - // a. If all keys were rate limited recently, select the least-recently - // rate limited key. - // 2. Keys which have not been used in the longest time - +/** + * Given a list of keys, returns a new list of keys sorted from highest to + * lowest priority. Keys are prioritized in the following order: + * + * 1. Keys which are not rate limited + * a. If all keys were rate limited recently, select the least-recently + * rate limited key. + * b. Otherwise, select the first key. + * 2. Keys which have not been used in the longest time + * 3. Keys according to the custom comparator, if provided + * @param keys The list of keys to sort + * @param customComparator A custom comparator function to use for sorting + */ +export function prioritizeKeys( + keys: T[], + customComparator?: (a: T, b: T) => number +) { const now = Date.now(); return keys.sort((a, b) => { @@ -19,6 +29,11 @@ export function prioritizeKeys(keys: T[]) { return a.rateLimitedAt - b.rateLimitedAt; } + if (customComparator) { + const result = customComparator(a, b); + if (result !== 0) return result; + } + return a.lastUsed - b.lastUsed; }); }