diff --git a/.env.example b/.env.example index 9cae73d..35f7890 100644 --- a/.env.example +++ b/.env.example @@ -40,11 +40,11 @@ NODE_ENV=production # Which model types users are allowed to access. # The following model families are recognized: -# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-dall-e +# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-dall-e # By default, all models are allowed except for 'dall-e' / 'azure-dall-e'. # To allow DALL-E image generation, uncomment the line below and add 'dall-e' or # 'azure-dall-e' to the list of allowed model families. -# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo +# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo # URLs from which requests will be blocked. # BLOCKED_ORIGINS=reddit.com,9gag.com diff --git a/src/config.ts b/src/config.ts index f5fb7e0..80ff8b5 100644 --- a/src/config.ts +++ b/src/config.ts @@ -312,6 +312,7 @@ export const config: Config = { "mistral-medium", "mistral-large", "aws-claude", + "aws-claude-opus", "azure-turbo", "azure-gpt4", "azure-gpt4-turbo", diff --git a/src/info-page.ts b/src/info-page.ts index 13f9d0a..a04cb77 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -25,6 +25,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { "mistral-medium": "Mistral Medium", "mistral-large": "Mistral Large", "aws-claude": "AWS Claude (Sonnet)", + "aws-claude-opus": "AWS Claude (Opus)", "azure-turbo": "Azure GPT-3.5 Turbo", "azure-gpt4": "Azure GPT-4", "azure-gpt4-32k": "Azure GPT-4 32k", diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts index 14d3a33..094eec5 100644 --- a/src/proxy/aws.ts +++ b/src/proxy/aws.ts @@ -257,10 +257,16 @@ function maybeReassignModel(req: Request) { } // AWS currently only supports one v3 model. - const variant = match[8]; // sonnet or opus + const variant = match[8]; // sonnet, opus, or haiku const variantVersion = match[9]; if (major === "3") { - req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0"; + if (variant.includes("opus")) { + req.body.model = "anthropic.claude-3-opus-20240229-v1:0"; + } else if (variant.includes("haiku")) { + req.body.model = "anthropic.claude-3-haiku-20240307-v1:0"; + } else { + req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0"; + } return; } diff --git a/src/proxy/middleware/request/preprocessors/validate-context-size.ts b/src/proxy/middleware/request/preprocessors/validate-context-size.ts index 0db25b3..3f911c2 100644 --- a/src/proxy/middleware/request/preprocessors/validate-context-size.ts +++ b/src/proxy/middleware/request/preprocessors/validate-context-size.ts @@ -63,7 +63,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => { } else if (model.match(/^gpt-4(-\d{4})?-vision(-preview)?$/)) { modelMax = 131072; } else if (model.match(/gpt-3.5-turbo/)) { - modelMax = 4096; + modelMax = 16384; } else if (model.match(/gpt-4-32k/)) { modelMax = 32768; } else if (model.match(/gpt-4/)) { @@ -82,7 +82,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => { modelMax = GOOGLE_AI_MAX_CONTEXT; } else if (model.match(/^mistral-(tiny|small|medium)$/)) { modelMax = MISTRAL_AI_MAX_CONTENT; - } else if (model.match(/^anthropic\.claude-3-sonnet/)) { + } else if (model.match(/^anthropic\.claude-3/)) { modelMax = 200000; } else if (model.match(/^anthropic\.claude-v2:\d/)) { modelMax = 200000; diff --git a/src/service-info.ts b/src/service-info.ts index 928b93d..7adf0fd 100644 --- a/src/service-info.ts +++ b/src/service-info.ts @@ -387,21 +387,22 @@ function addKeyToAggregates(k: KeyPoolKey) { } case "aws": { if (!keyIsAwsKey(k)) throw new Error("Invalid key type"); - const family = "aws-claude"; - sumTokens += k["aws-claudeTokens"]; - sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]); - increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); - increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]); - increment(modelStats, `${family}__awsSonnet`, k.sonnetEnabled ? 1 : 0); - increment(modelStats, `${family}__awsHaiku`, k.haikuEnabled ? 1 : 0); + k.modelFamilies.forEach((f) => { + const tokens = k[`${f}Tokens`]; + sumTokens += tokens; + sumCost += getTokenCostUsd(f, tokens); + increment(modelStats, `${f}__tokens`, tokens); + increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); + increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); + }); + increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0); + increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0); // Ignore revoked keys for aws logging stats, but include keys where the // logging status is unknown. const countAsLogged = k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled"; - increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0); - + increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0); break; } default: diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index 6b2a2f2..98bbd9c 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -2,7 +2,7 @@ import crypto from "crypto"; import { Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; -import type { AwsBedrockModelFamily } from "../../models"; +import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models"; import { AwsKeyChecker } from "./checker"; import { PaymentRequiredError } from "../../errors"; @@ -61,7 +61,7 @@ export class AwsBedrockKeyProvider implements KeyProvider { const newKey: AwsBedrockKey = { key, service: this.service, - modelFamilies: ["aws-claude"], + modelFamilies: ["aws-claude", "aws-claude-opus"], isDisabled: false, isRevoked: false, promptCount: 0, @@ -78,6 +78,7 @@ export class AwsBedrockKeyProvider implements KeyProvider { sonnetEnabled: true, haikuEnabled: false, ["aws-claudeTokens"]: 0, + ["aws-claude-opusTokens"]: 0, }; this.keys.push(newKey); } @@ -157,11 +158,11 @@ export class AwsBedrockKeyProvider implements KeyProvider { return this.keys.filter((k) => !k.isDisabled).length; } - public incrementUsage(hash: string, _model: string, tokens: number) { + public incrementUsage(hash: string, model: string, tokens: number) { const key = this.keys.find((k) => k.hash === hash); if (!key) return; key.promptCount++; - key["aws-claudeTokens"] += tokens; + key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens; } public getLockoutPeriod() { diff --git a/src/shared/models.ts b/src/shared/models.ts index 14a7fc2..d87713a 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -29,7 +29,7 @@ export type MistralAIModelFamily = | "mistral-small" | "mistral-medium" | "mistral-large"; -export type AwsBedrockModelFamily = "aws-claude"; +export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus"; export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`; export type ModelFamily = | OpenAIModelFamily @@ -55,6 +55,7 @@ export const MODEL_FAMILIES = (( "mistral-medium", "mistral-large", "aws-claude", + "aws-claude-opus", "azure-turbo", "azure-gpt4", "azure-gpt4-32k", @@ -98,6 +99,7 @@ export const MODEL_FAMILY_SERVICE: { claude: "anthropic", "claude-opus": "anthropic", "aws-claude": "aws", + "aws-claude-opus": "aws", "azure-turbo": "azure", "azure-gpt4": "azure", "azure-gpt4-32k": "azure", @@ -150,8 +152,8 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily { } } -export function getAwsBedrockModelFamily(model: string): ModelFamily { - if (model.includes("opus")) return "claude-opus"; +export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily { + if (model.includes("opus")) return "aws-claude-opus"; return "aws-claude"; } diff --git a/src/shared/stats.ts b/src/shared/stats.ts index e63b156..073c4a1 100644 --- a/src/shared/stats.ts +++ b/src/shared/stats.ts @@ -29,6 +29,7 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) { case "claude": cost = 0.000008; break; + case "aws-claude-opus": case "claude-opus": cost = 0.000015; break;