From 51ffca480af7f69427dca1a369f33a73ff2dcc59 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Mon, 4 Mar 2024 16:25:06 -0600 Subject: [PATCH] adds AWS Claude Chat Completions and Claude 3 Sonnet support --- src/info-page.ts | 2 +- src/proxy/anthropic.ts | 2 +- src/proxy/aws.ts | 71 ++++++++++++++++--- .../request/preprocessors/sign-aws-request.ts | 49 +++++++------ src/proxy/middleware/response/index.ts | 15 ++-- .../response/streaming/sse-stream-adapter.ts | 11 ++- src/service-info.ts | 8 ++- src/shared/key-management/aws/checker.ts | 31 ++++++-- src/shared/key-management/aws/provider.ts | 15 +++- 9 files changed, 155 insertions(+), 49 deletions(-) diff --git a/src/info-page.ts b/src/info-page.ts index c14780c..f2e2eb1 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -24,7 +24,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { "mistral-small": "Mixtral Small", // Originally 8x7B, but that now refers to the older open-weight version. Mixtral Small is a newer closed-weight update to the 8x7B model. "mistral-medium": "Mistral Medium", "mistral-large": "Mistral Large", - "aws-claude": "AWS Claude", + "aws-claude": "AWS Claude (Sonnet)", "azure-turbo": "Azure GPT-3.5 Turbo", "azure-gpt4": "Azure GPT-4", "azure-gpt4-32k": "Azure GPT-4 32k", diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts index 93b4926..a17122a 100644 --- a/src/proxy/anthropic.ts +++ b/src/proxy/anthropic.ts @@ -105,7 +105,7 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async ( res.status(200).json(body); }; -function transformAnthropicChatResponseToAnthropicText( +export function transformAnthropicChatResponseToAnthropicText( anthropicBody: Record, req: Request ): Record { diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts index d79155f..ee3fcde 100644 --- a/src/proxy/aws.ts +++ b/src/proxy/aws.ts @@ -16,8 +16,10 @@ import { ProxyResHandlerWithBody, createOnProxyResHandler, } from "./middleware/response"; +import { transformAnthropicChatResponseToAnthropicText } from "./anthropic"; const LATEST_AWS_V2_MINOR_VERSION = "1"; +const CLAUDE_3_COMPAT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"; let modelsCache: any = null; let modelsCacheTime = 0; @@ -29,10 +31,11 @@ const getModelsResponse = () => { if (!config.awsCredentials) return { object: "list", data: [] }; + // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html const variants = [ "anthropic.claude-v2", "anthropic.claude-v2:1", - "anthropic.claude-3-sonnet-20240229-v1:0" + "anthropic.claude-3-sonnet-20240229-v1:0", ]; const models = variants.map((id) => ({ @@ -73,7 +76,12 @@ const awsResponseHandler: ProxyResHandlerWithBody = async ( if (req.inboundApi === "openai") { req.log.info("Transforming AWS Claude response to OpenAI format"); - body = transformAwsResponse(body, req); + body = transformAwsTextResponseToOpenAI(body, req); + } + + if (req.inboundApi === "anthropic-text") { + req.log.info("Transforming Text AWS Claude response to Chat format"); + body = transformAnthropicChatResponseToAnthropicText(body, req); } if (req.tokenizerInfo) { @@ -92,7 +100,7 @@ const awsResponseHandler: ProxyResHandlerWithBody = async ( * is only used for non-streaming requests as streaming requests are handled * on-the-fly. */ -function transformAwsResponse( +function transformAwsTextResponseToOpenAI( awsBody: Record, req: Request ): Record { @@ -139,18 +147,54 @@ const awsProxy = createQueueMiddleware({ }), }); +const nativeTextPreprocessor = createPreprocessorMiddleware( + { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +const textToChatPreprocessor = createPreprocessorMiddleware( + { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +/** + * Routes text completion prompts to aws anthropic-chat if they need translation + * (claude-3 based models do not support the old text completion endpoint). + */ +const awsTextCompletionRouter: RequestHandler = (req, res, next) => { + if (req.body.model?.includes("claude-3")) { + textToChatPreprocessor(req, res, next); + } else { + nativeTextPreprocessor(req, res, next); + } +}; + const awsRouter = Router(); awsRouter.get("/v1/models", handleModelRequest); -// Native(ish) Anthropic chat completion endpoint. +// Native(ish) Anthropic text completion endpoint. +awsRouter.post("/v1/complete", ipLimiter, awsTextCompletionRouter, awsProxy); +// Native Anthropic chat completion endpoint. awsRouter.post( - "/v1/complete", + "/v1/messages", ipLimiter, createPreprocessorMiddleware( - { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" }, + { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" }, { afterTransform: [maybeReassignModel] } ), awsProxy ); +// Temporary force-Claude3 endpoint +awsRouter.post( + "/v1/claude-3/complete", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" }, + { + beforeTransform: [(req) => void (req.body.model = CLAUDE_3_COMPAT_MODEL)], + } + ), + awsProxy +); // OpenAI-to-AWS Anthropic compatibility endpoint. awsRouter.post( "/v1/chat/completions", @@ -178,7 +222,8 @@ function maybeReassignModel(req: Request) { return; } - const pattern = /^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?$/i; + const pattern = + /^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?(-sonnet-?|-opus-?)(\d*)/i; const match = model.match(pattern); // If there's no match, return the latest v2 model @@ -187,7 +232,9 @@ function maybeReassignModel(req: Request) { return; } - const [, , instant, , major, , minor] = match; + const instant = match[2]; + const major = match[4]; + const minor = match[6]; if (instant) { req.body.model = "anthropic.claude-instant-v1"; @@ -210,6 +257,14 @@ function maybeReassignModel(req: Request) { return; } + // AWS currently only supports one v3 model. + const variant = match[8]; // sonnet or opus + const variantVersion = match[9]; + if (major === "3") { + req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0"; + return; + } + // Fallback to latest v2 model req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; return; diff --git a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts index afe3e02..766f7d2 100644 --- a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts +++ b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts @@ -15,15 +15,19 @@ const AMZ_HOST = /** * Signs an outgoing AWS request with the appropriate headers modifies the * request object in place to fix the path. + * This happens AFTER request transformation. */ export const signAwsRequest: RequestPreprocessor = async (req) => { - req.key = keyPool.get("anthropic.claude-v2", "aws"); - const { model, stream } = req.body; + req.key = keyPool.get(model, "aws"); + req.isStreaming = stream === true || stream === "true"; - let preamble = req.body.prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:"; - req.body.prompt = preamble + req.body.prompt; + // same as addAnthropicPreamble for non-AWS requests, but has to happen here + if (req.outboundApi === "anthropic-text") { + let preamble = req.body.prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:"; + req.body.prompt = preamble + req.body.prompt; + } // AWS uses mostly the same parameters as Anthropic, with a few removed params // and much stricter validation on unused parameters. Rather than treating it @@ -31,28 +35,27 @@ export const signAwsRequest: RequestPreprocessor = async (req) => { // parameters. // TODO: This should happen in transform-outbound-payload.ts let strippedParams: Record; - if (req.inboundApi === "anthropic-chat") { - strippedParams = AnthropicV1MessagesSchema - .pick({ - messages: true, - max_tokens: true, - stop_sequences: true, - temperature: true, - top_k: true, - top_p: true, - }) + if (req.outboundApi === "anthropic-chat") { + strippedParams = AnthropicV1MessagesSchema.pick({ + messages: true, + max_tokens: true, + stop_sequences: true, + temperature: true, + top_k: true, + top_p: true, + }) .strip() .parse(req.body); + strippedParams.anthropic_version = "bedrock-2023-05-31"; } else { - strippedParams = AnthropicV1TextSchema - .pick({ - prompt: true, - max_tokens_to_sample: true, - stop_sequences: true, - temperature: true, - top_k: true, - top_p: true, - }) + strippedParams = AnthropicV1TextSchema.pick({ + prompt: true, + max_tokens_to_sample: true, + stop_sequences: true, + temperature: true, + top_k: true, + top_p: true, + }) .strip() .parse(req.body); } diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index 3b8ca22..cb8044b 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -332,12 +332,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`; break; case "AccessDeniedException": - req.log.error( - { key: req.key?.hash, model: req.body?.model }, - "Disabling key due to AccessDeniedException when invoking model. If credentials are valid, check IAM permissions." + const isModelAccessError = errorPayload.error?.message?.includes( + `access to the model with the specified model ID` ); - keyPool.disable(req.key!, "revoked"); - errorPayload.proxy_note = `API key doesn't have access to the requested resource.`; + if (!isModelAccessError) { + req.log.error( + { key: req.key?.hash, model: req.body?.model }, + "Disabling key due to AccessDeniedException when invoking model. If credentials are valid, check IAM permissions." + ); + keyPool.disable(req.key!, "revoked"); + } + errorPayload.proxy_note = `API key doesn't have access to the requested resource. Model ID: ${req.body?.model}`; break; default: errorPayload.proxy_note = `Received 403 error. Key may be invalid.`; diff --git a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts index fb914a3..83c3f1e 100644 --- a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts +++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts @@ -49,7 +49,16 @@ export class SSEStreamAdapter extends Transform { if (contentType === "application/json" && eventType === "chunk") { const { bytes } = JSON.parse(bodyStr); const event = Buffer.from(bytes, "base64").toString("utf8"); - return ["event: completion", `data: ${event}`].join(`\n`); + const eventObj = JSON.parse(event); + + if ('completion' in eventObj) { + return ["event: completion", `data: ${event}`].join(`\n`); + } else { + return [ + `event: ${eventObj.type}`, + `data: ${event}`, + ].join(`\n`); + } } // Intentional fallthrough, as non-JSON events may as well be errors // noinspection FallThroughInSwitchStatementJS diff --git a/src/service-info.ts b/src/service-info.ts index ba1a0c3..7d09ba2 100644 --- a/src/service-info.ts +++ b/src/service-info.ts @@ -51,6 +51,7 @@ type ModelAggregates = { overQuota?: number; pozzed?: number; awsLogged?: number; + awsSonnet?: number; queued: number; queueTime: string; tokens: number; @@ -81,7 +82,7 @@ type AnthropicInfo = BaseFamilyInfo & { prefilledKeys?: number; overQuotaKeys?: number; }; -type AwsInfo = BaseFamilyInfo & { privacy?: string }; +type AwsInfo = BaseFamilyInfo & { privacy?: string; sonnetKeys?: number }; // prettier-ignore export type ServiceInfo = { @@ -133,7 +134,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record } = { }, anthropic: { anthropic: `%BASE%/anthropic`, - "anthropic-claude-3 (temporary compatibility endpoint)": `%BASE%/anthropic/claude-3`, + "anthropic-claude-3 (⚠️temporary compatibility endpoint)": `%BASE%/anthropic/claude-3`, }, "google-ai": { "google-ai": `%BASE%/google-ai`, @@ -143,6 +144,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record } = { }, aws: { aws: `%BASE%/aws/claude`, + "aws-claude-3 (⚠️temporary compatibility endpoint)": `%BASE%/aws/claude/claude-3`, }, azure: { azure: `%BASE%/azure/openai`, @@ -372,6 +374,7 @@ function addKeyToAggregates(k: KeyPoolKey) { increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]); + increment(modelStats, `${family}__awsSonnet`, k.sonnetEnabled ? 1 : 0); // Ignore revoked keys for aws logging stats, but include keys where the // logging status is unknown. @@ -419,6 +422,7 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0; break; case "aws": + info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0; const logged = modelStats.get(`${family}__awsLogged`) || 0; if (logged > 0) { info.privacy = config.allowAwsLogging diff --git a/src/shared/key-management/aws/checker.ts b/src/shared/key-management/aws/checker.ts index 1364089..ab58e8d 100644 --- a/src/shared/key-management/aws/checker.ts +++ b/src/shared/key-management/aws/checker.ts @@ -15,7 +15,10 @@ const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) => `https://bedrock.${region}.amazonaws.com/logging/modelinvocations`; const POST_INVOKE_MODEL_URL = (region: string, model: string) => `https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`; -const TEST_PROMPT = "\n\nHuman:\n\nAssistant:"; +const TEST_MESSAGES = [ + { role: "user", content: "Hi!" }, + { role: "assistant", content: "Hello!" }, +]; type AwsError = { error: {} }; @@ -47,8 +50,10 @@ export class AwsKeyChecker extends KeyCheckerBase { const modelChecks: Promise[] = []; const isInitialCheck = !key.lastChecked; if (isInitialCheck) { - modelChecks.push(this.invokeModel("anthropic.claude-v2", key)); modelChecks.push(this.invokeModel("anthropic.claude-v2:1", key)); + modelChecks.push( + this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key) + ); } await Promise.all(modelChecks); @@ -128,12 +133,18 @@ export class AwsKeyChecker extends KeyCheckerBase { const creds = AwsKeyChecker.getCredentialsFromKey(key); // This is not a valid invocation payload, but a 400 response indicates that // the principal at least has permission to invoke the model. - const payload = { max_tokens_to_sample: -1, prompt: TEST_PROMPT }; + // A 403 response indicates that the model is not accessible -- if none of + // the models are accessible, the key is effectively disabled. + const payload = { + max_tokens: -1, + messages: TEST_MESSAGES, + anthropic_version: "bedrock-2023-05-31", + }; const config: AxiosRequestConfig = { method: "POST", url: POST_INVOKE_MODEL_URL(creds.region, model), data: payload, - validateStatus: (status) => status === 400, + validateStatus: (status) => status === 400 || status === 403, }; config.headers = new AxiosHeaders({ "content-type": "application/json", @@ -145,10 +156,20 @@ export class AwsKeyChecker extends KeyCheckerBase { const errorType = (headers["x-amzn-errortype"] as string).split(":")[0]; const errorMessage = data?.message; + // We only allow one type of 403 error, and we only allow it for one model. + if (status === 403 && errorMessage?.match(/access to the model with the specified model ID/)) { + this.log.warn( + { key: key.hash, errorType, data, status, model }, + "Key does not have access to Claude 3 Sonnet." + ); + this.updateKey(key.hash, { sonnetEnabled: false }); + return; + } + // We're looking for a specific error type and message here // "ValidationException" const correctErrorType = errorType === "ValidationException"; - const correctErrorMessage = errorMessage?.match(/max_tokens_to_sample/); + const correctErrorMessage = errorMessage?.match(/max_tokens/); if (!correctErrorType || !correctErrorMessage) { throw new AxiosError( `Unexpected error when invoking model ${model}: ${errorMessage}`, diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index bf1e950..c874138 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -29,6 +29,7 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage { * set. */ awsLoggingStatus: "unknown" | "disabled" | "enabled"; + sonnetEnabled: boolean; } /** @@ -78,6 +79,7 @@ export class AwsBedrockKeyProvider implements KeyProvider { .digest("hex") .slice(0, 8)}`, lastChecked: 0, + sonnetEnabled: true, ["aws-claudeTokens"]: 0, }; this.keys.push(newKey); @@ -96,13 +98,20 @@ export class AwsBedrockKeyProvider implements KeyProvider { return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); } - public get(_model: AwsBedrockModel) { + public get(model: AwsBedrockModel) { const availableKeys = this.keys.filter((k) => { const isNotLogged = k.awsLoggingStatus === "disabled"; - return !k.isDisabled && (isNotLogged || config.allowAwsLogging); + const needsSonnet = model.includes("sonnet"); + return ( + !k.isDisabled && + (isNotLogged || config.allowAwsLogging) && + (k.sonnetEnabled || !needsSonnet) + ); }); if (availableKeys.length === 0) { - throw new Error("No AWS Bedrock keys available"); + throw new Error( + "No keys available for this model. If you are requesting Sonnet, use Claude-2 instead." + ); } // (largely copied from the OpenAI provider, without trial key support)