Adds AWS Claude Opus support (untested; I have no way to verify it against AWS)

This commit is contained in:
nai-degen 2024-04-16 11:33:44 -05:00
parent 9445110727
commit c0cd2c7549
9 changed files with 36 additions and 23 deletions

View File

@ -40,11 +40,11 @@ NODE_ENV=production
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-dall-e
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-dall-e
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
# 'azure-dall-e' to the list of allowed model families.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com

View File

@ -312,6 +312,7 @@ export const config: Config = {
"mistral-medium",
"mistral-large",
"aws-claude",
"aws-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-turbo",

View File

@ -25,6 +25,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"mistral-medium": "Mistral Medium",
"mistral-large": "Mistral Large",
"aws-claude": "AWS Claude (Sonnet)",
"aws-claude-opus": "AWS Claude (Opus)",
"azure-turbo": "Azure GPT-3.5 Turbo",
"azure-gpt4": "Azure GPT-4",
"azure-gpt4-32k": "Azure GPT-4 32k",

View File

@ -257,10 +257,16 @@ function maybeReassignModel(req: Request) {
}
// AWS offers several Claude 3 variants; pick the Bedrock model ID that matches the requested variant, defaulting to Sonnet.
const variant = match[8]; // sonnet or opus
const variant = match[8]; // sonnet, opus, or haiku
const variantVersion = match[9];
if (major === "3") {
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
if (variant.includes("opus")) {
req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
} else if (variant.includes("haiku")) {
req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
} else {
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
}
return;
}

View File

@ -63,7 +63,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
} else if (model.match(/^gpt-4(-\d{4})?-vision(-preview)?$/)) {
modelMax = 131072;
} else if (model.match(/gpt-3.5-turbo/)) {
modelMax = 4096;
modelMax = 16384;
} else if (model.match(/gpt-4-32k/)) {
modelMax = 32768;
} else if (model.match(/gpt-4/)) {
@ -82,7 +82,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = GOOGLE_AI_MAX_CONTEXT;
} else if (model.match(/^mistral-(tiny|small|medium)$/)) {
modelMax = MISTRAL_AI_MAX_CONTENT;
} else if (model.match(/^anthropic\.claude-3-sonnet/)) {
} else if (model.match(/^anthropic\.claude-3/)) {
modelMax = 200000;
} else if (model.match(/^anthropic\.claude-v2:\d/)) {
modelMax = 200000;

View File

@ -387,21 +387,22 @@ function addKeyToAggregates(k: KeyPoolKey) {
}
case "aws": {
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
const family = "aws-claude";
sumTokens += k["aws-claudeTokens"];
sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);
increment(modelStats, `${family}__awsSonnet`, k.sonnetEnabled ? 1 : 0);
increment(modelStats, `${family}__awsHaiku`, k.haikuEnabled ? 1 : 0);
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);
// Ignore disabled keys (and keys that have not been checked yet) for AWS
// logging stats, but include keys whose logging status is unknown.
const countAsLogged =
k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled";
increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);
increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
break;
}
default:

View File

@ -2,7 +2,7 @@ import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AwsBedrockModelFamily } from "../../models";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
import { AwsKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
@ -61,7 +61,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
const newKey: AwsBedrockKey = {
key,
service: this.service,
modelFamilies: ["aws-claude"],
modelFamilies: ["aws-claude", "aws-claude-opus"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
@ -78,6 +78,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
sonnetEnabled: true,
haikuEnabled: false,
["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0,
};
this.keys.push(newKey);
}
@ -157,11 +158,11 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) {
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key["aws-claudeTokens"] += tokens;
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {

View File

@ -29,7 +29,7 @@ export type MistralAIModelFamily =
| "mistral-small"
| "mistral-medium"
| "mistral-large";
export type AwsBedrockModelFamily = "aws-claude";
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type ModelFamily =
| OpenAIModelFamily
@ -55,6 +55,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"mistral-medium",
"mistral-large",
"aws-claude",
"aws-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
@ -98,6 +99,7 @@ export const MODEL_FAMILY_SERVICE: {
claude: "anthropic",
"claude-opus": "anthropic",
"aws-claude": "aws",
"aws-claude-opus": "aws",
"azure-turbo": "azure",
"azure-gpt4": "azure",
"azure-gpt4-32k": "azure",
@ -150,8 +152,8 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
}
}
export function getAwsBedrockModelFamily(model: string): ModelFamily {
if (model.includes("opus")) return "claude-opus";
/**
 * Resolves the model family used to account for an AWS Bedrock model.
 *
 * @param model - The Bedrock model identifier being requested.
 * @returns `"aws-claude-opus"` when the identifier contains the substring
 *   `"opus"`; `"aws-claude"` for every other identifier.
 */
export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
  // The check is a case-sensitive substring match on the model identifier.
  const isOpus = model.includes("opus");
  return isOpus ? "aws-claude-opus" : "aws-claude";
}

View File

@ -29,6 +29,7 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
case "claude":
cost = 0.000008;
break;
case "aws-claude-opus":
case "claude-opus":
cost = 0.000015;
break;