adds AWS Claude Chat Completions and Claude 3 Sonnet support

This commit is contained in:
nai-degen 2024-03-04 16:25:06 -06:00
parent 802d847cc6
commit 51ffca480a
9 changed files with 155 additions and 49 deletions

View File

@ -24,7 +24,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"mistral-small": "Mixtral Small", // Originally 8x7B, but that now refers to the older open-weight version. Mixtral Small is a newer closed-weight update to the 8x7B model. "mistral-small": "Mixtral Small", // Originally 8x7B, but that now refers to the older open-weight version. Mixtral Small is a newer closed-weight update to the 8x7B model.
"mistral-medium": "Mistral Medium", "mistral-medium": "Mistral Medium",
"mistral-large": "Mistral Large", "mistral-large": "Mistral Large",
"aws-claude": "AWS Claude", "aws-claude": "AWS Claude (Sonnet)",
"azure-turbo": "Azure GPT-3.5 Turbo", "azure-turbo": "Azure GPT-3.5 Turbo",
"azure-gpt4": "Azure GPT-4", "azure-gpt4": "Azure GPT-4",
"azure-gpt4-32k": "Azure GPT-4 32k", "azure-gpt4-32k": "Azure GPT-4 32k",

View File

@ -105,7 +105,7 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
res.status(200).json(body); res.status(200).json(body);
}; };
function transformAnthropicChatResponseToAnthropicText( export function transformAnthropicChatResponseToAnthropicText(
anthropicBody: Record<string, any>, anthropicBody: Record<string, any>,
req: Request req: Request
): Record<string, any> { ): Record<string, any> {

View File

@ -16,8 +16,10 @@ import {
ProxyResHandlerWithBody, ProxyResHandlerWithBody,
createOnProxyResHandler, createOnProxyResHandler,
} from "./middleware/response"; } from "./middleware/response";
import { transformAnthropicChatResponseToAnthropicText } from "./anthropic";
const LATEST_AWS_V2_MINOR_VERSION = "1"; const LATEST_AWS_V2_MINOR_VERSION = "1";
const CLAUDE_3_COMPAT_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0";
let modelsCache: any = null; let modelsCache: any = null;
let modelsCacheTime = 0; let modelsCacheTime = 0;
@ -29,10 +31,11 @@ const getModelsResponse = () => {
if (!config.awsCredentials) return { object: "list", data: [] }; if (!config.awsCredentials) return { object: "list", data: [] };
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
const variants = [ const variants = [
"anthropic.claude-v2", "anthropic.claude-v2",
"anthropic.claude-v2:1", "anthropic.claude-v2:1",
"anthropic.claude-3-sonnet-20240229-v1:0" "anthropic.claude-3-sonnet-20240229-v1:0",
]; ];
const models = variants.map((id) => ({ const models = variants.map((id) => ({
@ -73,7 +76,12 @@ const awsResponseHandler: ProxyResHandlerWithBody = async (
if (req.inboundApi === "openai") { if (req.inboundApi === "openai") {
req.log.info("Transforming AWS Claude response to OpenAI format"); req.log.info("Transforming AWS Claude response to OpenAI format");
body = transformAwsResponse(body, req); body = transformAwsTextResponseToOpenAI(body, req);
}
if (req.inboundApi === "anthropic-text") {
req.log.info("Transforming Text AWS Claude response to Chat format");
body = transformAnthropicChatResponseToAnthropicText(body, req);
} }
if (req.tokenizerInfo) { if (req.tokenizerInfo) {
@ -92,7 +100,7 @@ const awsResponseHandler: ProxyResHandlerWithBody = async (
* is only used for non-streaming requests as streaming requests are handled * is only used for non-streaming requests as streaming requests are handled
* on-the-fly. * on-the-fly.
*/ */
function transformAwsResponse( function transformAwsTextResponseToOpenAI(
awsBody: Record<string, any>, awsBody: Record<string, any>,
req: Request req: Request
): Record<string, any> { ): Record<string, any> {
@ -139,18 +147,54 @@ const awsProxy = createQueueMiddleware({
}), }),
}); });
const nativeTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const textToChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes text completion prompts to aws anthropic-chat if they need translation
* (claude-3 based models do not support the old text completion endpoint).
*/
const awsTextCompletionRouter: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
textToChatPreprocessor(req, res, next);
} else {
nativeTextPreprocessor(req, res, next);
}
};
const awsRouter = Router(); const awsRouter = Router();
awsRouter.get("/v1/models", handleModelRequest); awsRouter.get("/v1/models", handleModelRequest);
// Native(ish) Anthropic chat completion endpoint. // Native(ish) Anthropic text completion endpoint.
awsRouter.post("/v1/complete", ipLimiter, awsTextCompletionRouter, awsProxy);
// Native Anthropic chat completion endpoint.
awsRouter.post( awsRouter.post(
"/v1/complete", "/v1/messages",
ipLimiter, ipLimiter,
createPreprocessorMiddleware( createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" }, { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] } { afterTransform: [maybeReassignModel] }
), ),
awsProxy awsProxy
); );
// Temporary force-Claude3 endpoint
awsRouter.post(
"/v1/claude-3/complete",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
{
beforeTransform: [(req) => void (req.body.model = CLAUDE_3_COMPAT_MODEL)],
}
),
awsProxy
);
// OpenAI-to-AWS Anthropic compatibility endpoint. // OpenAI-to-AWS Anthropic compatibility endpoint.
awsRouter.post( awsRouter.post(
"/v1/chat/completions", "/v1/chat/completions",
@ -178,7 +222,8 @@ function maybeReassignModel(req: Request) {
return; return;
} }
const pattern = /^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?$/i; const pattern =
/^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?(-sonnet-?|-opus-?)(\d*)/i;
const match = model.match(pattern); const match = model.match(pattern);
// If there's no match, return the latest v2 model // If there's no match, return the latest v2 model
@ -187,7 +232,9 @@ function maybeReassignModel(req: Request) {
return; return;
} }
const [, , instant, , major, , minor] = match; const instant = match[2];
const major = match[4];
const minor = match[6];
if (instant) { if (instant) {
req.body.model = "anthropic.claude-instant-v1"; req.body.model = "anthropic.claude-instant-v1";
@ -210,6 +257,14 @@ function maybeReassignModel(req: Request) {
return; return;
} }
// AWS currently only supports one v3 model.
const variant = match[8]; // sonnet or opus
const variantVersion = match[9];
if (major === "3") {
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
return;
}
// Fallback to latest v2 model // Fallback to latest v2 model
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return; return;

View File

@ -15,15 +15,19 @@ const AMZ_HOST =
/** /**
* Signs an outgoing AWS request with the appropriate headers modifies the * Signs an outgoing AWS request with the appropriate headers modifies the
* request object in place to fix the path. * request object in place to fix the path.
* This happens AFTER request transformation.
*/ */
export const signAwsRequest: RequestPreprocessor = async (req) => { export const signAwsRequest: RequestPreprocessor = async (req) => {
req.key = keyPool.get("anthropic.claude-v2", "aws");
const { model, stream } = req.body; const { model, stream } = req.body;
req.key = keyPool.get(model, "aws");
req.isStreaming = stream === true || stream === "true"; req.isStreaming = stream === true || stream === "true";
let preamble = req.body.prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:"; // same as addAnthropicPreamble for non-AWS requests, but has to happen here
req.body.prompt = preamble + req.body.prompt; if (req.outboundApi === "anthropic-text") {
let preamble = req.body.prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:";
req.body.prompt = preamble + req.body.prompt;
}
// AWS uses mostly the same parameters as Anthropic, with a few removed params // AWS uses mostly the same parameters as Anthropic, with a few removed params
// and much stricter validation on unused parameters. Rather than treating it // and much stricter validation on unused parameters. Rather than treating it
@ -31,28 +35,27 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
// parameters. // parameters.
// TODO: This should happen in transform-outbound-payload.ts // TODO: This should happen in transform-outbound-payload.ts
let strippedParams: Record<string, unknown>; let strippedParams: Record<string, unknown>;
if (req.inboundApi === "anthropic-chat") { if (req.outboundApi === "anthropic-chat") {
strippedParams = AnthropicV1MessagesSchema strippedParams = AnthropicV1MessagesSchema.pick({
.pick({ messages: true,
messages: true, max_tokens: true,
max_tokens: true, stop_sequences: true,
stop_sequences: true, temperature: true,
temperature: true, top_k: true,
top_k: true, top_p: true,
top_p: true, })
})
.strip() .strip()
.parse(req.body); .parse(req.body);
strippedParams.anthropic_version = "bedrock-2023-05-31";
} else { } else {
strippedParams = AnthropicV1TextSchema strippedParams = AnthropicV1TextSchema.pick({
.pick({ prompt: true,
prompt: true, max_tokens_to_sample: true,
max_tokens_to_sample: true, stop_sequences: true,
stop_sequences: true, temperature: true,
temperature: true, top_k: true,
top_k: true, top_p: true,
top_p: true, })
})
.strip() .strip()
.parse(req.body); .parse(req.body);
} }

View File

@ -332,12 +332,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`; errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
break; break;
case "AccessDeniedException": case "AccessDeniedException":
req.log.error( const isModelAccessError = errorPayload.error?.message?.includes(
{ key: req.key?.hash, model: req.body?.model }, `access to the model with the specified model ID`
"Disabling key due to AccessDeniedException when invoking model. If credentials are valid, check IAM permissions."
); );
keyPool.disable(req.key!, "revoked"); if (!isModelAccessError) {
errorPayload.proxy_note = `API key doesn't have access to the requested resource.`; req.log.error(
{ key: req.key?.hash, model: req.body?.model },
"Disabling key due to AccessDeniedException when invoking model. If credentials are valid, check IAM permissions."
);
keyPool.disable(req.key!, "revoked");
}
errorPayload.proxy_note = `API key doesn't have access to the requested resource. Model ID: ${req.body?.model}`;
break; break;
default: default:
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`; errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;

View File

@ -49,7 +49,16 @@ export class SSEStreamAdapter extends Transform {
if (contentType === "application/json" && eventType === "chunk") { if (contentType === "application/json" && eventType === "chunk") {
const { bytes } = JSON.parse(bodyStr); const { bytes } = JSON.parse(bodyStr);
const event = Buffer.from(bytes, "base64").toString("utf8"); const event = Buffer.from(bytes, "base64").toString("utf8");
return ["event: completion", `data: ${event}`].join(`\n`); const eventObj = JSON.parse(event);
if ('completion' in eventObj) {
return ["event: completion", `data: ${event}`].join(`\n`);
} else {
return [
`event: ${eventObj.type}`,
`data: ${event}`,
].join(`\n`);
}
} }
// Intentional fallthrough, as non-JSON events may as well be errors // Intentional fallthrough, as non-JSON events may as well be errors
// noinspection FallThroughInSwitchStatementJS // noinspection FallThroughInSwitchStatementJS

View File

@ -51,6 +51,7 @@ type ModelAggregates = {
overQuota?: number; overQuota?: number;
pozzed?: number; pozzed?: number;
awsLogged?: number; awsLogged?: number;
awsSonnet?: number;
queued: number; queued: number;
queueTime: string; queueTime: string;
tokens: number; tokens: number;
@ -81,7 +82,7 @@ type AnthropicInfo = BaseFamilyInfo & {
prefilledKeys?: number; prefilledKeys?: number;
overQuotaKeys?: number; overQuotaKeys?: number;
}; };
type AwsInfo = BaseFamilyInfo & { privacy?: string }; type AwsInfo = BaseFamilyInfo & { privacy?: string; sonnetKeys?: number };
// prettier-ignore // prettier-ignore
export type ServiceInfo = { export type ServiceInfo = {
@ -133,7 +134,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
}, },
anthropic: { anthropic: {
anthropic: `%BASE%/anthropic`, anthropic: `%BASE%/anthropic`,
"anthropic-claude-3 (temporary compatibility endpoint)": `%BASE%/anthropic/claude-3`, "anthropic-claude-3 (⚠️temporary compatibility endpoint)": `%BASE%/anthropic/claude-3`,
}, },
"google-ai": { "google-ai": {
"google-ai": `%BASE%/google-ai`, "google-ai": `%BASE%/google-ai`,
@ -143,6 +144,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
}, },
aws: { aws: {
aws: `%BASE%/aws/claude`, aws: `%BASE%/aws/claude`,
"aws-claude-3 (⚠temporary compatibility endpoint)": `%BASE%/aws/claude/claude-3`,
}, },
azure: { azure: {
azure: `%BASE%/azure/openai`, azure: `%BASE%/azure/openai`,
@ -372,6 +374,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]); increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);
increment(modelStats, `${family}__awsSonnet`, k.sonnetEnabled ? 1 : 0);
// Ignore revoked keys for aws logging stats, but include keys where the // Ignore revoked keys for aws logging stats, but include keys where the
// logging status is unknown. // logging status is unknown.
@ -419,6 +422,7 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0; info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0;
break; break;
case "aws": case "aws":
info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
const logged = modelStats.get(`${family}__awsLogged`) || 0; const logged = modelStats.get(`${family}__awsLogged`) || 0;
if (logged > 0) { if (logged > 0) {
info.privacy = config.allowAwsLogging info.privacy = config.allowAwsLogging

View File

@ -15,7 +15,10 @@ const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`; `https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
const POST_INVOKE_MODEL_URL = (region: string, model: string) => const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`; `https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
const TEST_PROMPT = "\n\nHuman:\n\nAssistant:"; const TEST_MESSAGES = [
{ role: "user", content: "Hi!" },
{ role: "assistant", content: "Hello!" },
];
type AwsError = { error: {} }; type AwsError = { error: {} };
@ -47,8 +50,10 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const modelChecks: Promise<unknown>[] = []; const modelChecks: Promise<unknown>[] = [];
const isInitialCheck = !key.lastChecked; const isInitialCheck = !key.lastChecked;
if (isInitialCheck) { if (isInitialCheck) {
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
modelChecks.push(this.invokeModel("anthropic.claude-v2:1", key)); modelChecks.push(this.invokeModel("anthropic.claude-v2:1", key));
modelChecks.push(
this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key)
);
} }
await Promise.all(modelChecks); await Promise.all(modelChecks);
@ -128,12 +133,18 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const creds = AwsKeyChecker.getCredentialsFromKey(key); const creds = AwsKeyChecker.getCredentialsFromKey(key);
// This is not a valid invocation payload, but a 400 response indicates that // This is not a valid invocation payload, but a 400 response indicates that
// the principal at least has permission to invoke the model. // the principal at least has permission to invoke the model.
const payload = { max_tokens_to_sample: -1, prompt: TEST_PROMPT }; // A 403 response indicates that the model is not accessible -- if none of
// the models are accessible, the key is effectively disabled.
const payload = {
max_tokens: -1,
messages: TEST_MESSAGES,
anthropic_version: "bedrock-2023-05-31",
};
const config: AxiosRequestConfig = { const config: AxiosRequestConfig = {
method: "POST", method: "POST",
url: POST_INVOKE_MODEL_URL(creds.region, model), url: POST_INVOKE_MODEL_URL(creds.region, model),
data: payload, data: payload,
validateStatus: (status) => status === 400, validateStatus: (status) => status === 400 || status === 403,
}; };
config.headers = new AxiosHeaders({ config.headers = new AxiosHeaders({
"content-type": "application/json", "content-type": "application/json",
@ -145,10 +156,20 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0]; const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
const errorMessage = data?.message; const errorMessage = data?.message;
// We only allow one type of 403 error, and we only allow it for one model.
if (status === 403 && errorMessage?.match(/access to the model with the specified model ID/)) {
this.log.warn(
{ key: key.hash, errorType, data, status, model },
"Key does not have access to Claude 3 Sonnet."
);
this.updateKey(key.hash, { sonnetEnabled: false });
return;
}
// We're looking for a specific error type and message here // We're looking for a specific error type and message here
// "ValidationException" // "ValidationException"
const correctErrorType = errorType === "ValidationException"; const correctErrorType = errorType === "ValidationException";
const correctErrorMessage = errorMessage?.match(/max_tokens_to_sample/); const correctErrorMessage = errorMessage?.match(/max_tokens/);
if (!correctErrorType || !correctErrorMessage) { if (!correctErrorType || !correctErrorMessage) {
throw new AxiosError( throw new AxiosError(
`Unexpected error when invoking model ${model}: ${errorMessage}`, `Unexpected error when invoking model ${model}: ${errorMessage}`,

View File

@ -29,6 +29,7 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
* set. * set.
*/ */
awsLoggingStatus: "unknown" | "disabled" | "enabled"; awsLoggingStatus: "unknown" | "disabled" | "enabled";
sonnetEnabled: boolean;
} }
/** /**
@ -78,6 +79,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
.digest("hex") .digest("hex")
.slice(0, 8)}`, .slice(0, 8)}`,
lastChecked: 0, lastChecked: 0,
sonnetEnabled: true,
["aws-claudeTokens"]: 0, ["aws-claudeTokens"]: 0,
}; };
this.keys.push(newKey); this.keys.push(newKey);
@ -96,13 +98,20 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
} }
public get(_model: AwsBedrockModel) { public get(model: AwsBedrockModel) {
const availableKeys = this.keys.filter((k) => { const availableKeys = this.keys.filter((k) => {
const isNotLogged = k.awsLoggingStatus === "disabled"; const isNotLogged = k.awsLoggingStatus === "disabled";
return !k.isDisabled && (isNotLogged || config.allowAwsLogging); const needsSonnet = model.includes("sonnet");
return (
!k.isDisabled &&
(isNotLogged || config.allowAwsLogging) &&
(k.sonnetEnabled || !needsSonnet)
);
}); });
if (availableKeys.length === 0) { if (availableKeys.length === 0) {
throw new Error("No AWS Bedrock keys available"); throw new Error(
"No keys available for this model. If you are requesting Sonnet, use Claude-2 instead."
);
} }
// (largely copied from the OpenAI provider, without trial key support) // (largely copied from the OpenAI provider, without trial key support)