Merge GCP Vertex AI implementation from cg-dot/oai-reverse-proxy (khanon/oai-reverse-proxy!72)
This commit is contained in:
parent 29ed07492e
commit 0c936e97fe

.env.example | 15
@@ -40,15 +40,21 @@ NODE_ENV=production
 # Which model types users are allowed to access.
 # The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus
+# | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small
+# | mistral-medium | mistral-large | aws-claude | aws-claude-opus | gcp-claude
+# | gcp-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k
+# | azure-gpt4-turbo | azure-gpt4o | azure-dall-e

 # By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
 # To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
 # 'azure-dall-e' to the list of allowed model families.
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,gcp-claude,gcp-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o

 # Which services can be used to process prompts containing images via multimodal
 # models. The following services are recognized:
-# openai | anthropic | aws | azure | google-ai | mistral-ai
+# openai | anthropic | aws | gcp | azure | google-ai | mistral-ai
 # Do not enable this feature unless all users are trusted, as you will be liable
 # for any user-submitted images containing illegal content.
 # By default, no image services are allowed and image prompts are rejected.
@@ -118,6 +124,7 @@ NODE_ENV=production
 # TOKEN_QUOTA_CLAUDE=0
 # TOKEN_QUOTA_GEMINI_PRO=0
 # TOKEN_QUOTA_AWS_CLAUDE=0
+# TOKEN_QUOTA_GCP_CLAUDE=0
 # "Tokens" for image-generation models are counted at a rate of 100000 tokens
 # per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
 # DALL-E 3 costs around US$0.10 per image (10000 tokens).
@@ -142,6 +149,7 @@ NODE_ENV=production
 # You can add multiple API keys by separating them with a comma.
 # For AWS credentials, separate the access key ID, secret key, and region with a colon.
+# For GCP credentials, separate the project ID, client email, region, and private key with a colon.
 OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
@@ -149,6 +157,7 @@ GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
 # See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure.
 AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
+GCP_CREDENTIALS=project-id:client-email:region:private-key

 # With proxy_key gatekeeper, the password users must provide to access the API.
 # PROXY_KEY=your-secret-key
@@ -20,6 +20,7 @@ This project allows you to run a reverse proxy server for various LLM APIs.
 - [x] [OpenAI](https://openai.com/)
 - [x] [Anthropic](https://www.anthropic.com/)
 - [x] [AWS Bedrock](https://aws.amazon.com/bedrock/)
+- [x] [Vertex AI (GCP)](https://cloud.google.com/vertex-ai/)
 - [x] [Google MakerSuite/Gemini API](https://ai.google.dev/)
 - [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
 - [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses
@@ -0,0 +1,35 @@
# Configuring the proxy for Vertex AI (GCP)

The proxy supports GCP models via the `/proxy/gcp/claude` endpoint. There are a few extra steps necessary to use GCP compared to the other supported APIs.

- [Setting keys](#setting-keys)
- [Setting up Vertex AI](#setting-up-vertex-ai)
- [Supported model IDs](#supported-model-ids)

## Setting keys

Use the `GCP_CREDENTIALS` environment variable to set the GCP API keys.

Like other APIs, you can provide multiple keys separated by commas. Each GCP key, however, is a set of credentials including the project ID, client email, region, and private key. These are separated by a colon (`:`).

For example:

```
GCP_CREDENTIALS=my-first-project:xxx@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,my-first-project2:xxx2@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----
```

## Setting up Vertex AI

1. Go to [https://cloud.google.com/vertex-ai](https://cloud.google.com/vertex-ai) and sign up for a GCP account ($150 of free credits without a credit card, or $300 with one; credits expire after 90 days).
2. Go to [https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com](https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com) to enable the Vertex AI API.
3. Go to [https://console.cloud.google.com/vertex-ai](https://console.cloud.google.com/vertex-ai) and navigate to Model Garden to apply for access to the Claude models.
4. Create a [Service Account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1) and make sure to grant it the "Vertex AI User" or "Vertex AI Administrator" role.
5. On the page of the service account you just created, create a new key and select "JSON". The JSON file will be downloaded automatically.
6. The required credentials are in the JSON file you just downloaded (see the sketch below for assembling them into `GCP_CREDENTIALS`).
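The downloaded service-account JSON contains `project_id`, `client_email`, and `private_key` fields; the region is not in the file and must be chosen by you. A minimal sketch (not part of this commit; the script name and usage are illustrative) for assembling those fields into the `GCP_CREDENTIALS` format:

```
// build-gcp-credentials.ts — illustrative helper, not part of this commit.
// Usage: npx ts-node build-gcp-credentials.ts service-account.json us-east5
import { readFileSync } from "fs";

const [jsonPath, region] = process.argv.slice(2);
const sa = JSON.parse(readFileSync(jsonPath, "utf8"));

// GCP_CREDENTIALS=project-id:client-email:region:private-key
// The proxy strips the PEM header/footer and newlines from the key itself,
// so the private_key field can be passed through with newlines escaped.
const credential = [
  sa.project_id,
  sa.client_email,
  region,
  sa.private_key.replace(/\n/g, "\\n"),
].join(":");

console.log(`GCP_CREDENTIALS=${credential}`);
```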
## Supported model IDs

Users can send these model IDs to the proxy to invoke the corresponding models.

- **Claude**
  - `claude-3-haiku@20240307`
  - `claude-3-sonnet@20240229`
  - `claude-3-opus@20240229`
  - `claude-3-5-sonnet@20240620`
@@ -230,6 +230,39 @@ Content-Type: application/json
   ]
 }

+###
+# @name Proxy / GCP Claude -- Native Completion
+POST {{proxy-host}}/proxy/gcp/claude/v1/complete
+Authorization: Bearer {{proxy-key}}
+anthropic-version: 2023-01-01
+Content-Type: application/json
+
+{
+  "model": "claude-v2",
+  "max_tokens_to_sample": 10,
+  "temperature": 0,
+  "stream": true,
+  "prompt": "\n\nHuman: What is genshin impact?\n\nAssistant:"
+}
+
+###
+# @name Proxy / GCP Claude -- OpenAI-to-Anthropic API Translation
+POST {{proxy-host}}/proxy/gcp/claude/chat/completions
+Authorization: Bearer {{proxy-key}}
+Content-Type: application/json
+
+{
+  "model": "gpt-3.5-turbo",
+  "max_tokens": 50,
+  "stream": true,
+  "messages": [
+    {
+      "role": "user",
+      "content": "What is genshin impact?"
+    }
+  ]
+}
+
 ###
 # @name Proxy / Azure OpenAI -- Native Chat Completions
 POST {{proxy-host}}/proxy/azure/openai/chat/completions
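The GCP router added later in this diff also exposes a native Anthropic Messages route. A request against it would look like the following (a sketch in the same .http format, not part of the commit):

```
###
# @name Proxy / GCP Claude -- Native Messages (sketch, not in this commit)
POST {{proxy-host}}/proxy/gcp/claude/v1/messages
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-06-01
Content-Type: application/json

{
  "model": "claude-3-sonnet@20240229",
  "max_tokens": 50,
  "messages": [
    { "role": "user", "content": "Hi!" }
  ]
}
```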
@@ -51,6 +51,8 @@ function getRandomModelFamily() {
   "mistral-large",
   "aws-claude",
   "aws-claude-opus",
+  "gcp-claude",
+  "gcp-claude-opus",
   "azure-turbo",
   "azure-gpt4",
   "azure-gpt4-32k",
@@ -268,7 +268,7 @@ router.post("/maintenance", (req, res) => {
   let flash = { type: "", message: "" };
   switch (action) {
     case "recheck": {
-      const checkable: LLMService[] = ["openai", "anthropic", "aws", "azure"];
+      const checkable: LLMService[] = ["openai", "anthropic", "aws", "gcp", "azure"];
       checkable.forEach((s) => keyPool.recheck(s));
       const keyCount = keyPool
         .list()
@@ -45,6 +45,13 @@ type Config = {
    * @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
    */
   awsCredentials?: string;
+  /**
+   * Comma-delimited list of GCP credentials. Each credential item should be a
+   * colon-delimited list of project ID, client email, region, and private key.
+   *
+   * @example `GCP_CREDENTIALS=project1:1@1.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,project2:2@2.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----`
+   */
+  gcpCredentials?: string;
   /**
    * Comma-delimited list of Azure OpenAI credentials. Each credential item
    * should be a colon-delimited list of Azure resource name, deployment ID, and
@@ -349,7 +356,7 @@ type Config = {
    *
    * Defaults to no services, meaning image prompts are disabled. Use a comma-
    * separated list. Available services are:
-   * openai,anthropic,google-ai,mistral-ai,aws,azure
+   * openai,anthropic,google-ai,mistral-ai,aws,gcp,azure
    */
  allowedVisionServices: LLMService[];
  /**
@@ -383,6 +390,7 @@ export const config: Config = {
   googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
   mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
   awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
+  gcpCredentials: getEnvWithDefault("GCP_CREDENTIALS", ""),
   azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
   proxyKey: getEnvWithDefault("PROXY_KEY", ""),
   adminKey: getEnvWithDefault("ADMIN_KEY", ""),
@@ -437,6 +445,8 @@ export const config: Config = {
   "mistral-large",
   "aws-claude",
   "aws-claude-opus",
+  "gcp-claude",
+  "gcp-claude-opus",
   "azure-turbo",
   "azure-gpt4",
   "azure-gpt4-32k",
@@ -511,6 +521,7 @@ function generateSigningKey() {
     config.googleAIKey,
     config.mistralAIKey,
     config.awsCredentials,
+    config.gcpCredentials,
     config.azureCredentials,
   ];
   if (secrets.filter((s) => s).length === 0) {
@@ -648,6 +659,7 @@ export const OMITTED_KEYS = [
   "googleAIKey",
   "mistralAIKey",
   "awsCredentials",
+  "gcpCredentials",
   "azureCredentials",
   "proxyKey",
   "adminKey",
@@ -738,6 +750,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
       "ANTHROPIC_KEY",
       "GOOGLE_AI_KEY",
       "AWS_CREDENTIALS",
+      "GCP_CREDENTIALS",
       "AZURE_CREDENTIALS",
     ].includes(String(env))
   ) {
@@ -29,6 +29,8 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
   "mistral-large": "Mistral Large",
   "aws-claude": "AWS Claude (Sonnet)",
   "aws-claude-opus": "AWS Claude (Opus)",
+  "gcp-claude": "GCP Claude (Sonnet)",
+  "gcp-claude-opus": "GCP Claude (Opus)",
   "azure-turbo": "Azure GPT-3.5 Turbo",
   "azure-gpt4": "Azure GPT-4",
   "azure-gpt4-32k": "Azure GPT-4 32k",
@@ -129,7 +129,7 @@ export function transformAnthropicChatResponseToAnthropicText(
  * is only used for non-streaming requests as streaming requests are handled
  * on-the-fly.
  */
-function transformAnthropicTextResponseToOpenAI(
+export function transformAnthropicTextResponseToOpenAI(
   anthropicBody: Record<string, any>,
   req: Request
 ): Record<string, any> {
@@ -0,0 +1,196 @@
import { Request, RequestHandler, Response, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  createPreprocessorMiddleware,
  signGcpRequest,
  finalizeSignedRequest,
  createOnProxyReqHandler,
} from "./middleware/request";
import {
  ProxyResHandlerWithBody,
  createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
import { sendErrorToClient } from "./middleware/response/error-generator";

const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";

let modelsCache: any = null;
let modelsCacheTime = 0;

const getModelsResponse = () => {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }

  if (!config.gcpCredentials) return { object: "list", data: [] };

  // https://docs.anthropic.com/en/docs/about-claude/models
  const variants = [
    "claude-3-haiku@20240307",
    "claude-3-sonnet@20240229",
    "claude-3-opus@20240229",
    "claude-3-5-sonnet@20240620",
  ];

  const models = variants.map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "anthropic",
    permission: [],
    root: "claude",
    parent: null,
  }));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();

  return modelsCache;
};

const handleModelRequest: RequestHandler = (_req, res) => {
  res.status(200).json(getModelsResponse());
};

/** Only used for non-streaming requests. */
const gcpResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  let newBody = body;
  switch (`${req.inboundApi}<-${req.outboundApi}`) {
    case "openai<-anthropic-chat":
      req.log.info("Transforming Anthropic Chat back to OpenAI format");
      newBody = transformAnthropicChatResponseToOpenAI(body);
      break;
  }

  res.status(200).json({ ...newBody, proxy: body.proxy });
};

const gcpProxy = createQueueMiddleware({
  beforeProxy: signGcpRequest,
  proxyMiddleware: createProxyMiddleware({
    target: "bad-target-will-be-rewritten",
    router: ({ signedRequest }) => {
      if (!signedRequest) throw new Error("Must sign request before proxying");
      return `${signedRequest.protocol}//${signedRequest.hostname}`;
    },
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
      proxyRes: createOnProxyResHandler([gcpResponseHandler]),
      error: handleProxyError,
    },
  }),
});

const oaiToChatPreprocessor = createPreprocessorMiddleware(
  { inApi: "openai", outApi: "anthropic-chat", service: "gcp" },
  { afterTransform: [maybeReassignModel] }
);

/**
 * Routes an OpenAI prompt to either the legacy Claude text completion endpoint
 * or the new Claude chat completion endpoint, based on the requested model.
 */
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
  oaiToChatPreprocessor(req, res, next);
};

const gcpRouter = Router();
gcpRouter.get("/v1/models", handleModelRequest);
// Native Anthropic chat completion endpoint.
gcpRouter.post(
  "/v1/messages",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "gcp" },
    { afterTransform: [maybeReassignModel] }
  ),
  gcpProxy
);

// OpenAI-to-GCP Anthropic compatibility endpoint.
gcpRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  preprocessOpenAICompatRequest,
  gcpProxy
);

/**
 * Tries to deal with:
 * - frontends sending GCP model names even when they want to use the OpenAI-
 *   compatible endpoint
 * - frontends sending Anthropic model names that GCP doesn't recognize
 * - frontends sending OpenAI model names because they expect the proxy to
 *   translate them
 *
 * If the client sends a GCP model ID it will be used verbatim. Otherwise,
 * various strategies are used to try to map a non-GCP model name to a GCP
 * model ID.
 */
function maybeReassignModel(req: Request) {
  const model = req.body.model;

  // If it looks like a GCP model, use it as-is
  // if (model.includes("anthropic.claude")) {
  if (model.startsWith("claude-") && model.includes("@")) {
    return;
  }

  // Anthropic model names can look like:
  // - claude-v1
  // - claude-2.1
  // - claude-3-5-sonnet-20240620-v1:0
  const pattern =
    /^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
  const match = model.match(pattern);

  // If there's no match, fall back to Claude 3 Sonnet as it is most likely to
  // be available on GCP.
  if (!match) {
    req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
    return;
  }

  const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;

  const ver = minor ? `${major}.${minor}` : major;
  switch (ver) {
    case "3":
    case "3.0":
      if (name?.includes("opus")) {
        req.body.model = "claude-3-opus@20240229";
      } else if (name?.includes("haiku")) {
        req.body.model = "claude-3-haiku@20240307";
      } else {
        req.body.model = "claude-3-sonnet@20240229";
      }
      return;
    case "3.5":
      req.body.model = "claude-3-5-sonnet@20240620";
      return;
  }

  // Fall back to Claude 3 Sonnet
  req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
  return;
}

export const gcp = gcpRouter;
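To make the routing above concrete, here is how `maybeReassignModel` resolves a few representative client-sent names under its regex and version switch (worked examples derived from the code, not additional code in the commit):

```
// "claude-3-opus@20240229"      -> kept verbatim (already a GCP model ID)
// "claude-3-opus-20240229"      -> "claude-3-opus@20240229"
// "claude-3-haiku-20240307"     -> "claude-3-haiku@20240307"
// "claude-3-5-sonnet-20240620"  -> "claude-3-5-sonnet@20240620"
// "claude-2.1"                  -> "claude-3-sonnet@20240229" (no 3.x case; falls back)
// "gpt-4" (no regex match)      -> "claude-3-sonnet@20240229" (fallback)
```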
@@ -15,6 +15,7 @@ export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
 export { languageFilter } from "./preprocessors/language-filter";
 export { setApiFormat } from "./preprocessors/set-api-format";
 export { signAwsRequest } from "./preprocessors/sign-aws-request";
+export { signGcpRequest } from "./preprocessors/sign-vertex-ai-request";
 export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
 export { validateContextSize } from "./preprocessors/validate-context-size";
 export { validateVision } from "./preprocessors/validate-vision";
@@ -83,6 +83,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
       proxyReq.setHeader("api-key", azureKey);
       break;
     case "aws":
+    case "gcp":
     case "google-ai":
       throw new Error("add-key should not be used for this service.");
     default:
@@ -1,7 +1,7 @@
 import type { HPMRequestCallback } from "../index";

 /**
- * For AWS/Azure/Google requests, the body is signed earlier in the request
+ * For AWS/GCP/Azure/Google requests, the body is signed earlier in the request
  * pipeline, before the proxy middleware. This function just assigns the path
  * and headers to the proxy request.
  */
@@ -0,0 +1,201 @@
import express from "express";
import crypto from "crypto";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";

const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";

export const signGcpRequest: RequestPreprocessor = async (req) => {
  const serviceValid = req.service === "gcp";
  if (!serviceValid) {
    throw new Error("signGcpRequest called on invalid request");
  }

  if (!req.body?.model) {
    throw new Error("You must specify a model with your request.");
  }

  const { model, stream } = req.body;
  req.key = keyPool.get(model, "gcp");

  req.log.info({ key: req.key.hash, model }, "Assigned GCP key to request");

  req.isStreaming = String(stream) === "true";

  // TODO: This should happen in transform-outbound-payload.ts
  // TODO: Support tools
  const strippedParams: Record<string, unknown> = AnthropicV1MessagesSchema.pick({
    messages: true,
    system: true,
    max_tokens: true,
    stop_sequences: true,
    temperature: true,
    top_k: true,
    top_p: true,
    stream: true,
  })
    .strip()
    .parse(req.body);
  strippedParams.anthropic_version = "vertex-2023-10-16";

  const [accessToken, credential] = await getAccessToken(req);

  const host = GCP_HOST.replace("%REGION%", credential.region);
  // GCP doesn't use the anthropic-version header, but we set it to ensure the
  // stream adapter selects the correct transformer.
  req.headers["anthropic-version"] = "2023-06-01";

  req.signedRequest = {
    method: "POST",
    protocol: "https:",
    hostname: host,
    path: `/v1/projects/${credential.projectId}/locations/${credential.region}/publishers/anthropic/models/${model}:streamRawPredict`,
    headers: {
      ["host"]: host,
      ["content-type"]: "application/json",
      ["authorization"]: `Bearer ${accessToken}`,
    },
    body: JSON.stringify(strippedParams),
  };
};

async function getAccessToken(
  req: express.Request
): Promise<[string, Credential]> {
  // TODO: access token caching to reduce latency
  const credential = getCredentialParts(req);
  const signedJWT = await createSignedJWT(
    credential.clientEmail,
    credential.privateKey
  );
  const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT);
  if (accessToken === null) {
    req.log.warn(
      { key: req.key!.hash, jwtError },
      "Unable to get the access token"
    );
    throw new Error("The access token is invalid.");
  }
  return [accessToken, credential];
}

async function createSignedJWT(email: string, pkey: string): Promise<string> {
  const cryptoKey = await crypto.subtle.importKey(
    "pkcs8",
    str2ab(atob(pkey)),
    {
      name: "RSASSA-PKCS1-v1_5",
      hash: { name: "SHA-256" },
    },
    false,
    ["sign"]
  );

  const authUrl = "https://www.googleapis.com/oauth2/v4/token";
  const issued = Math.floor(Date.now() / 1000);
  const expires = issued + 600;

  const header = {
    alg: "RS256",
    typ: "JWT",
  };

  const payload = {
    iss: email,
    aud: authUrl,
    iat: issued,
    exp: expires,
    scope: "https://www.googleapis.com/auth/cloud-platform",
  };

  const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
  const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));

  const unsignedToken = `${encodedHeader}.${encodedPayload}`;

  const signature = await crypto.subtle.sign(
    "RSASSA-PKCS1-v1_5",
    cryptoKey,
    str2ab(unsignedToken)
  );

  const encodedSignature = urlSafeBase64Encode(signature);
  return `${unsignedToken}.${encodedSignature}`;
}

async function exchangeJwtForAccessToken(
  signed_jwt: string
): Promise<[string | null, string]> {
  const auth_url = "https://www.googleapis.com/oauth2/v4/token";
  const params = {
    grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
    assertion: signed_jwt,
  };

  const r = await fetch(auth_url, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: Object.entries(params)
      .map(([k, v]) => `${k}=${v}`)
      .join("&"),
  }).then((res) => res.json());

  if (r.access_token) {
    return [r.access_token, ""];
  }

  return [null, JSON.stringify(r)];
}

function str2ab(str: string): ArrayBuffer {
  const buffer = new ArrayBuffer(str.length);
  const bufferView = new Uint8Array(buffer);
  for (let i = 0; i < str.length; i++) {
    bufferView[i] = str.charCodeAt(i);
  }
  return buffer;
}

function urlSafeBase64Encode(data: string | ArrayBuffer): string {
  let base64: string;
  if (typeof data === "string") {
    base64 = btoa(
      encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
        String.fromCharCode(parseInt("0x" + p1, 16))
      )
    );
  } else {
    base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
  }
  return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}

type Credential = {
  projectId: string;
  clientEmail: string;
  region: string;
  privateKey: string;
};

function getCredentialParts(req: express.Request): Credential {
  const [projectId, clientEmail, region, rawPrivateKey] =
    req.key!.key.split(":");
  if (!projectId || !clientEmail || !region || !rawPrivateKey) {
    req.log.error(
      { key: req.key!.hash },
      "GCP_CREDENTIALS isn't correctly formatted; refer to the docs"
    );
    throw new Error("The key assigned to this request is invalid.");
  }

  const privateKey = rawPrivateKey
    .replace(
      /-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
      ""
    )
    .trim();

  return { projectId, clientEmail, region, privateKey };
}
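The `TODO: access token caching` note above is the main latency cost here, since every request performs a JWT sign and a token exchange. A minimal sketch of what such a cache could look like, keyed on the proxy key hash (the helper name and the conservative 9-minute lifetime are assumptions, not part of the commit):

```
// Sketch only: cache tokens per key hash and refresh before the JWT's
// 600-second validity window runs out.
const tokenCache = new Map<string, { token: string; expiresAt: number }>();

async function getCachedAccessToken(
  req: express.Request
): Promise<[string, Credential]> {
  const credential = getCredentialParts(req);
  const cached = tokenCache.get(req.key!.hash);
  if (cached && cached.expiresAt > Date.now()) {
    return [cached.token, credential];
  }
  const [token] = await getAccessToken(req);
  // Assume a lifetime slightly under the JWT's 10 minutes to stay safe.
  tokenCache.set(req.key!.hash, {
    token,
    expiresAt: Date.now() + 9 * 60 * 1000,
  });
  return [token, credential];
}
```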
@@ -186,6 +186,13 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     throw new HttpError(statusCode, parseError.message);
   }

+  const service = req.key!.service;
+  if (service === "gcp") {
+    if (Array.isArray(errorPayload)) {
+      errorPayload = errorPayload[0];
+    }
+  }
+
   const errorType =
     errorPayload.error?.code ||
     errorPayload.error?.type ||
@@ -199,11 +206,15 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   // TODO: split upstream error handling into separate modules for each service,
   // this is out of control.

-  const service = req.key!.service;
   if (service === "aws") {
     // Try to standardize the error format for AWS
     errorPayload.error = { message: errorPayload.message, type: errorType };
     delete errorPayload.message;
+  } else if (service === "gcp") {
+    // Try to standardize the error format for GCP
+    if (errorPayload.error?.code) {
+      // GCP-style error object
+      errorPayload.error = {
+        message: errorPayload.error.message,
+        type: errorPayload.error.status || errorPayload.error.code,
+      };
+    }
   }

   if (statusCode === 400) {
@@ -225,6 +236,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       break;
     case "anthropic":
     case "aws":
+    case "gcp":
       await handleAnthropicAwsBadRequestError(req, errorPayload);
       break;
     case "google-ai":
@@ -280,6 +292,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       default:
         errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
       }
       return;
+    case "gcp":
+      keyPool.disable(req.key!, "revoked");
+      errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
+      return;
     }
   } else if (statusCode === 429) {
     switch (service) {
@@ -292,6 +309,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       case "aws":
         await handleAwsRateLimitError(req, errorPayload);
         break;
+      case "gcp":
+        await handleGcpRateLimitError(req, errorPayload);
+        break;
       case "azure":
       case "mistral-ai":
         await handleAzureRateLimitError(req, errorPayload);
@@ -328,6 +348,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       case "aws":
         errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
         break;
+      case "gcp":
+        errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`;
+        break;
       case "azure":
         errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
         break;
@@ -434,6 +457,19 @@ async function handleAwsRateLimitError(
   }
 }

+async function handleGcpRateLimitError(
+  req: Request,
+  errorPayload: ProxiedErrorPayload
+) {
+  if (errorPayload.error?.type === "RESOURCE_EXHAUSTED") {
+    keyPool.markRateLimited(req.key!);
+    await reenqueueRequest(req);
+    throw new RetryableError("GCP rate-limited request re-enqueued.");
+  } else {
+    errorPayload.proxy_note = `Unrecognized 429 Too Many Requests error from GCP.`;
+  }
+}
+
 async function handleOpenAIRateLimitError(
   req: Request,
   errorPayload: ProxiedErrorPayload
@@ -7,6 +7,7 @@ import { anthropic } from "./anthropic";
 import { googleAI } from "./google-ai";
 import { mistralAI } from "./mistral-ai";
 import { aws } from "./aws";
+import { gcp } from "./gcp";
 import { azure } from "./azure";
 import { sendErrorToClient } from "./middleware/response/error-generator";

@@ -36,6 +37,7 @@ proxyRouter.use("/anthropic", addV1, anthropic);
 proxyRouter.use("/google-ai", addV1, googleAI);
 proxyRouter.use("/mistral-ai", addV1, mistralAI);
 proxyRouter.use("/aws/claude", addV1, aws);
+proxyRouter.use("/gcp/claude", addV1, gcp);
 proxyRouter.use("/azure/openai", addV1, azure);
 // Redirect browser requests to the homepage.
 proxyRouter.get("*", (req, res, next) => {
@@ -2,6 +2,7 @@ import { config, listConfig } from "./config";
 import {
   AnthropicKey,
   AwsBedrockKey,
+  GcpKey,
   AzureOpenAIKey,
   GoogleAIKey,
   keyPool,
@@ -11,6 +12,7 @@ import {
   AnthropicModelFamily,
   assertIsKnownModelFamily,
   AwsBedrockModelFamily,
+  GcpModelFamily,
   AzureOpenAIModelFamily,
   GoogleAIModelFamily,
   LLM_SERVICES,
@@ -40,6 +42,7 @@ const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
 const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
   k.service === "mistral-ai";
 const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
+const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp";

 /** Stats aggregated across all keys for a given service. */
 type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
@@ -52,7 +55,11 @@ type ModelAggregates = {
   pozzed?: number;
   awsLogged?: number;
   awsSonnet?: number;
+  awsSonnet35?: number;
   awsHaiku?: number;
+  gcpSonnet?: number;
+  gcpSonnet35?: number;
+  gcpHaiku?: number;
   queued: number;
   queueTime: string;
   tokens: number;
@@ -87,6 +94,12 @@ type AnthropicInfo = BaseFamilyInfo & {
 type AwsInfo = BaseFamilyInfo & {
   privacy?: string;
   sonnetKeys?: number;
+  sonnet35Keys?: number;
   haikuKeys?: number;
 };
+type GcpInfo = BaseFamilyInfo & {
+  sonnetKeys?: number;
+  sonnet35Keys?: number;
+  haikuKeys?: number;
+};

@@ -101,6 +114,7 @@ export type ServiceInfo = {
   "google-ai"?: string;
   "mistral-ai"?: string;
   aws?: string;
+  gcp?: string;
   azure?: string;
   "openai-image"?: string;
   "azure-image"?: string;
@@ -114,6 +128,7 @@ export type ServiceInfo = {
 } & { [f in OpenAIModelFamily]?: OpenAIInfo }
   & { [f in AnthropicModelFamily]?: AnthropicInfo }
   & { [f in AwsBedrockModelFamily]?: AwsInfo }
+  & { [f in GcpModelFamily]?: GcpInfo }
   & { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo }
   & { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
   & { [f in MistralAIModelFamily]?: BaseFamilyInfo };
@@ -151,6 +166,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
   aws: {
     aws: `%BASE%/aws/claude`,
   },
+  gcp: {
+    gcp: `%BASE%/gcp/claude`,
+  },
   azure: {
     azure: `%BASE%/azure/openai`,
     "azure-image": `%BASE%/azure/openai`,
@@ -305,6 +323,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
     k.service === "mistral-ai" ? 1 : 0
   );
   increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
+  increment(serviceStats, "gcp__keys", k.service === "gcp" ? 1 : 0);
   increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);

   let sumTokens = 0;
@@ -396,6 +415,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
         increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
       });
       increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
+      increment(modelStats, `aws-claude__awsSonnet35`, k.sonnet35Enabled ? 1 : 0);
       increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);

       // Ignore revoked keys for aws logging stats, but include keys where the
@@ -405,6 +425,21 @@ function addKeyToAggregates(k: KeyPoolKey) {
       increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
       break;
     }
+    case "gcp": {
+      if (!keyIsGcpKey(k)) throw new Error("Invalid key type");
+      k.modelFamilies.forEach((f) => {
+        const tokens = k[`${f}Tokens`];
+        sumTokens += tokens;
+        sumCost += getTokenCostUsd(f, tokens);
+        increment(modelStats, `${f}__tokens`, tokens);
+        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
+        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
+      });
+      increment(modelStats, `gcp-claude__gcpSonnet`, k.sonnetEnabled ? 1 : 0);
+      increment(modelStats, `gcp-claude__gcpSonnet35`, k.sonnet35Enabled ? 1 : 0);
+      increment(modelStats, `gcp-claude__gcpHaiku`, k.haikuEnabled ? 1 : 0);
+      break;
+    }
     default:
       assertNever(k.service);
   }
@@ -416,7 +451,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
 function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
   const tokens = modelStats.get(`${family}__tokens`) || 0;
   const cost = getTokenCostUsd(family, tokens);
-  let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = {
+  let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
     usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
     activeKeys: modelStats.get(`${family}__active`) || 0,
     revokedKeys: modelStats.get(`${family}__revoked`) || 0,
@@ -446,6 +481,7 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
     case "aws":
       if (family === "aws-claude") {
         info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
+        info.sonnet35Keys = modelStats.get(`${family}__awsSonnet35`) || 0;
         info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0;
         const logged = modelStats.get(`${family}__awsLogged`) || 0;
         if (logged > 0) {
@@ -455,6 +491,13 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
         }
       }
       break;
+    case "gcp":
+      if (family === "gcp-claude") {
+        info.sonnetKeys = modelStats.get(`${family}__gcpSonnet`) || 0;
+        info.sonnet35Keys = modelStats.get(`${family}__gcpSonnet35`) || 0;
+        info.haikuKeys = modelStats.get(`${family}__gcpHaiku`) || 0;
+      }
+      break;
   }
 }
@@ -122,7 +122,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
       { key: key.hash, error: error.message },
       "Network error while checking key; trying this key again in a minute."
     );
-    const oneMinute = 10 * 1000;
+    const oneMinute = 60 * 1000;
     const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
     this.updateKey(key.hash, { lastChecked: next });
   }
@@ -0,0 +1,277 @@
import axios, { AxiosError } from "axios";
import crypto from "crypto";
import { KeyCheckerBase } from "../key-checker-base";
import type { GcpKey, GcpKeyProvider } from "./provider";
import { GcpModelFamily } from "../../models";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
  `https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
const TEST_MESSAGES = [
  { role: "user", content: "Hi!" },
  { role: "assistant", content: "Hello!" },
];

type UpdateFn = typeof GcpKeyProvider.prototype.update;

export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
  constructor(keys: GcpKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "gcp",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      updateKey,
    });
  }

  protected async testKeyOrFail(key: GcpKey) {
    let checks: Promise<boolean>[] = [];
    const isInitialCheck = !key.lastChecked;
    if (isInitialCheck) {
      checks = [
        this.invokeModel("claude-3-haiku@20240307", key, true),
        this.invokeModel("claude-3-sonnet@20240229", key, true),
        this.invokeModel("claude-3-opus@20240229", key, true),
        this.invokeModel("claude-3-5-sonnet@20240620", key, true),
      ];

      // Destructure in the same order the checks were enqueued above.
      const [haiku, sonnet, opus, sonnet35] = await Promise.all(checks);

      this.log.debug(
        { key: key.hash, sonnet, haiku, opus, sonnet35 },
        "GCP model initial tests complete."
      );

      const families: GcpModelFamily[] = [];
      if (sonnet || sonnet35 || haiku) families.push("gcp-claude");
      if (opus) families.push("gcp-claude-opus");

      if (families.length === 0) {
        this.log.warn(
          { key: key.hash },
          "Key does not have access to any models; disabling."
        );
        return this.updateKey(key.hash, { isDisabled: true });
      }

      this.updateKey(key.hash, {
        sonnetEnabled: sonnet,
        haikuEnabled: haiku,
        sonnet35Enabled: sonnet35,
        modelFamilies: families,
      });
    } else {
      if (key.haikuEnabled) {
        await this.invokeModel("claude-3-haiku@20240307", key, false);
      } else if (key.sonnetEnabled) {
        await this.invokeModel("claude-3-sonnet@20240229", key, false);
      } else if (key.sonnet35Enabled) {
        await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
      } else {
        await this.invokeModel("claude-3-opus@20240229", key, false);
      }

      this.updateKey(key.hash, { lastChecked: Date.now() });
      this.log.debug({ key: key.hash }, "GCP key check complete.");
    }

    this.log.info(
      { key: key.hash, families: key.modelFamilies },
      "Checked key."
    );
  }

  protected handleAxiosError(key: GcpKey, error: AxiosError) {
    if (error.response && GcpKeyChecker.errorIsGcpError(error)) {
      const { status, data } = error.response;
      if (status === 400 || status === 401 || status === 403) {
        this.log.warn(
          { key: key.hash, error: data },
          "Key is invalid or revoked. Disabling key."
        );
        this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
      } else if (status === 429) {
        this.log.warn(
          { key: key.hash, error: data },
          "Key is rate limited. Rechecking in a minute."
        );
        const next = Date.now() - (KEY_CHECK_PERIOD - 60 * 1000);
        this.updateKey(key.hash, { lastChecked: next });
      } else {
        this.log.error(
          { key: key.hash, status, error: data },
          "Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
        );
        this.updateKey(key.hash, { lastChecked: Date.now() });
      }
      return;
    }
    const { response, cause } = error;
    const { headers, status, data } = response ?? {};
    this.log.error(
      { key: key.hash, status, headers, data, cause, error: error.message },
      "Network error while checking key; trying this key again in a minute."
    );
    const oneMinute = 60 * 1000;
    const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
    this.updateKey(key.hash, { lastChecked: next });
  }

  /**
   * Attempts to invoke the given model with the given key. Returns true if the
   * key has access to the model, false if it does not. Throws an error if the
   * key is disabled.
   */
  private async invokeModel(model: string, key: GcpKey, initial: boolean) {
    const creds = GcpKeyChecker.getCredentialsFromKey(key);
    const signedJWT = await GcpKeyChecker.createSignedJWT(
      creds.clientEmail,
      creds.privateKey
    );
    const [accessToken, jwtError] =
      await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
    if (accessToken === null) {
      this.log.warn(
        { key: key.hash, jwtError },
        "Unable to get the access token"
      );
      return false;
    }
    const payload = {
      max_tokens: 1,
      messages: TEST_MESSAGES,
      anthropic_version: "vertex-2023-10-16",
    };
    const { data, status } = await axios.post(
      POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
      payload,
      {
        headers: GcpKeyChecker.getRequestHeaders(accessToken),
        validateStatus: initial
          ? () => true
          : (status: number) => status >= 200 && status < 300,
      }
    );
    this.log.debug({ key: key.hash, data }, "Response from GCP");

    if (initial) {
      return (
        (status >= 200 && status < 300) || status === 429 || status === 529
      );
    }

    return true;
  }

  static errorIsGcpError(error: AxiosError): error is AxiosError {
    const data = error.response?.data as any;
    if (Array.isArray(data)) {
      return data.length > 0 && !!data[0]?.error?.message;
    } else {
      return !!data?.error?.message;
    }
  }

  static async createSignedJWT(email: string, pkey: string): Promise<string> {
    const cryptoKey = await crypto.subtle.importKey(
      "pkcs8",
      GcpKeyChecker.str2ab(atob(pkey)),
      {
        name: "RSASSA-PKCS1-v1_5",
        hash: { name: "SHA-256" },
      },
      false,
      ["sign"]
    );

    const authUrl = "https://www.googleapis.com/oauth2/v4/token";
    const issued = Math.floor(Date.now() / 1000);
    const expires = issued + 600;

    const header = { alg: "RS256", typ: "JWT" };

    const payload = {
      iss: email,
      aud: authUrl,
      iat: issued,
      exp: expires,
      scope: "https://www.googleapis.com/auth/cloud-platform",
    };

    const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
      JSON.stringify(header)
    );
    const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
      JSON.stringify(payload)
    );

    const unsignedToken = `${encodedHeader}.${encodedPayload}`;

    const signature = await crypto.subtle.sign(
      "RSASSA-PKCS1-v1_5",
      cryptoKey,
      GcpKeyChecker.str2ab(unsignedToken)
    );

    const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature);
    return `${unsignedToken}.${encodedSignature}`;
  }

  static async exchangeJwtForAccessToken(
    signed_jwt: string
  ): Promise<[string | null, string]> {
    const auth_url = "https://www.googleapis.com/oauth2/v4/token";
    const params = {
      grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
      assertion: signed_jwt,
    };

    const r = await fetch(auth_url, {
      method: "POST",
      headers: { "Content-Type": "application/x-www-form-urlencoded" },
      body: Object.entries(params)
        .map(([k, v]) => `${k}=${v}`)
        .join("&"),
    }).then((res) => res.json());

    if (r.access_token) {
      return [r.access_token, ""];
    }

    return [null, JSON.stringify(r)];
  }

  static str2ab(str: string): ArrayBuffer {
    const buffer = new ArrayBuffer(str.length);
    const bufferView = new Uint8Array(buffer);
    for (let i = 0; i < str.length; i++) {
      bufferView[i] = str.charCodeAt(i);
    }
    return buffer;
  }

  static urlSafeBase64Encode(data: string | ArrayBuffer): string {
    let base64: string;
    if (typeof data === "string") {
      base64 = btoa(
        encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
          String.fromCharCode(parseInt("0x" + p1, 16))
        )
      );
    } else {
      base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
    }
    return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
  }

  static getRequestHeaders(accessToken: string) {
    return {
      Authorization: `Bearer ${accessToken}`,
      "Content-Type": "application/json",
    };
  }

  static getCredentialsFromKey(key: GcpKey) {
    const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
    if (!projectId || !clientEmail || !region || !rawPrivateKey) {
      throw new Error("Invalid GCP key");
    }
    const privateKey = rawPrivateKey
      .replace(
        /-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
        ""
      )
      .trim();

    return { projectId, clientEmail, region, privateKey };
  }
}
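For reference, with the sample credentials from the docs above, `POST_STREAM_RAW_URL` resolves to a concrete endpoint like this (pure substitution; the project name comes from the docs example):

```
// POST_STREAM_RAW_URL("my-first-project", "us-east5", "claude-3-haiku@20240307") ===
// "https://us-east5-aiplatform.googleapis.com/v1/projects/my-first-project" +
// "/locations/us-east5/publishers/anthropic/models/claude-3-haiku@20240307:streamRawPredict"
```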
@@ -0,0 +1,242 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { GcpKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";

type GcpKeyUsage = {
  [K in GcpModelFamily as `${K}Tokens`]: number;
};

export interface GcpKey extends Key, GcpKeyUsage {
  readonly service: "gcp";
  readonly modelFamilies: GcpModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  sonnetEnabled: boolean;
  haikuEnabled: boolean;
  sonnet35Enabled: boolean;
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 4000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 500;

export class GcpKeyProvider implements KeyProvider<GcpKey> {
  readonly service = "gcp";

  private keys: GcpKey[] = [];
  private checker?: GcpKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.gcpCredentials?.trim();
    if (!keyConfig) {
      this.log.warn(
        "GCP_CREDENTIALS is not set. GCP API will not be available."
      );
      return;
    }
    const bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: GcpKey = {
        key,
        service: this.service,
        modelFamilies: ["gcp-claude"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        hash: `gcp-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        sonnetEnabled: true,
        haikuEnabled: false,
        sonnet35Enabled: false,
        ["gcp-claudeTokens"]: 0,
        ["gcp-claude-opusTokens"]: 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded GCP keys.");
  }

  public init() {
    if (config.checkKeys) {
      this.checker = new GcpKeyChecker(this.keys, this.update.bind(this));
      this.checker.start();
    }
  }

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(model: string) {
    const neededFamily = getGcpModelFamily(model);

    // this is a horrible mess
    // each of these should be separate model families, but adding model
    // families is not low enough friction for the rate at which gcp claude
    // model variants are added.
    const needsSonnet35 =
      model.includes("claude-3-5-sonnet") && neededFamily === "gcp-claude";
    const needsSonnet =
      !needsSonnet35 &&
      model.includes("sonnet") &&
      neededFamily === "gcp-claude";
    const needsHaiku = model.includes("haiku") && neededFamily === "gcp-claude";

    const availableKeys = this.keys.filter((k) => {
      return (
        !k.isDisabled &&
        (k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under gcp-claude, while opus is not
        (k.haikuEnabled || !needsHaiku) &&
        (k.sonnet35Enabled || !needsSonnet35) &&
        k.modelFamilies.includes(neededFamily)
      );
    });

    this.log.debug(
      {
        model,
        neededFamily,
        needsSonnet,
        needsHaiku,
        needsSonnet35,
        availableKeys: availableKeys.length,
        totalKeys: this.keys.length,
      },
      "Selecting GCP key"
    );

    if (availableKeys.length === 0) {
      throw new PaymentRequiredError(
        `No GCP keys available for model ${model}`
      );
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }

  public disable(key: GcpKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<GcpKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    key[`${getGcpModelFamily(model)}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    // TODO: same exact behavior for three providers, should be refactored
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }

  /**
   * This is called when we receive a 429, which means there are already
   * several concurrent requests running on this key. We don't have any
   * information on when these requests will resolve, so all we can do is wait
   * a bit and try again. We will lock the key for RATE_LIMIT_LOCKOUT
   * milliseconds after getting a 429 before retrying, in order to give the
   * other requests a chance to finish.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
  }

  public recheck() {
    this.keys.forEach(({ hash }) =>
      this.update(hash, { lastChecked: 0, isDisabled: false, isRevoked: false })
    );
    this.checker?.scheduleNextCheck();
  }

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
@@ -63,4 +63,5 @@ export { AnthropicKey } from "./anthropic/provider";
 export { OpenAIKey } from "./openai/provider";
 export { GoogleAIKey } from "./google-ai/provider";
 export { AwsBedrockKey } from "./aws/provider";
+export { GcpKey } from "./gcp/provider";
 export { AzureOpenAIKey } from "./azure/provider";
@@ -10,6 +10,7 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
 import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
 import { GoogleAIKeyProvider } from "./google-ai/provider";
 import { AwsBedrockKeyProvider } from "./aws/provider";
+import { GcpKeyProvider } from "./gcp/provider";
 import { AzureOpenAIKeyProvider } from "./azure/provider";
 import { MistralAIKeyProvider } from "./mistral-ai/provider";

@@ -27,6 +28,7 @@ export class KeyPool {
     this.keyProviders.push(new GoogleAIKeyProvider());
     this.keyProviders.push(new MistralAIKeyProvider());
     this.keyProviders.push(new AwsBedrockKeyProvider());
+    this.keyProviders.push(new GcpKeyProvider());
     this.keyProviders.push(new AzureOpenAIKeyProvider());
   }

@@ -128,7 +130,11 @@ export class KeyPool {
       return "openai";
     } else if (model.startsWith("claude-")) {
       // https://console.anthropic.com/docs/api/reference#parameters
-      return "anthropic";
+      if (!model.includes("@")) {
+        return "anthropic";
+      } else {
+        return "gcp";
+      }
     } else if (model.includes("gemini")) {
       // https://developers.generativeai.google.com/models/language
       return "google-ai";
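The `@` check above is what disambiguates GCP's versioned Claude IDs from Anthropic's own model names; worked examples (derived from the code above, not additional code in the commit):

```
// "claude-3-opus-20240229"  -> "anthropic" (no "@")
// "claude-3-opus@20240229"  -> "gcp"
// "gemini-pro"              -> "google-ai"
```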
@@ -5,7 +5,7 @@ import type { Request } from "express";

 /**
  * The service that a model is hosted on. Distinct from `APIFormat` because some
- * services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
+ * services have interoperable APIs (eg Anthropic/AWS/GCP, OpenAI/Azure).
  */
 export type LLMService =
   | "openai"
@@ -13,6 +13,7 @@ export type LLMService =
   | "google-ai"
   | "mistral-ai"
   | "aws"
+  | "gcp"
   | "azure";

 export type OpenAIModelFamily =
@@ -32,6 +33,7 @@ export type MistralAIModelFamily =
   // correspond to specific models. consider them rough pricing tiers.
   "mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
 export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
+export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
 export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
 export type ModelFamily =
   | OpenAIModelFamily
@@ -39,6 +41,7 @@ export type ModelFamily =
   | GoogleAIModelFamily
   | MistralAIModelFamily
   | AwsBedrockModelFamily
+  | GcpModelFamily
   | AzureOpenAIModelFamily;

 export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
@@ -61,6 +64,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   "mistral-large",
   "aws-claude",
   "aws-claude-opus",
+  "gcp-claude",
+  "gcp-claude-opus",
   "azure-turbo",
   "azure-gpt4",
   "azure-gpt4-32k",
@@ -77,6 +82,7 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
   "google-ai",
   "mistral-ai",
   "aws",
+  "gcp",
   "azure",
 ] as const);

@@ -93,6 +99,8 @@ export const MODEL_FAMILY_SERVICE: {
   "claude-opus": "anthropic",
   "aws-claude": "aws",
   "aws-claude-opus": "aws",
+  "gcp-claude": "gcp",
+  "gcp-claude-opus": "gcp",
   "azure-turbo": "azure",
   "azure-gpt4": "azure",
   "azure-gpt4-32k": "azure",
@@ -176,6 +184,11 @@ export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
   return "aws-claude";
 }

+export function getGcpModelFamily(model: string): GcpModelFamily {
+  if (model.includes("opus")) return "gcp-claude-opus";
+  return "gcp-claude";
+}
+
 export function getAzureOpenAIModelFamily(
   model: string,
   defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
@@ -210,10 +223,12 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
   const model = req.body.model ?? "gpt-3.5-turbo";
   let modelFamily: ModelFamily;

-  // Weird special case for AWS/Azure because they serve multiple models from
+  // Weird special case for AWS/GCP/Azure because they serve multiple models from
   // different vendors, even if currently only one is supported.
   if (req.service === "aws") {
     modelFamily = getAwsBedrockModelFamily(model);
+  } else if (req.service === "gcp") {
+    modelFamily = getGcpModelFamily(model);
   } else if (req.service === "azure") {
     modelFamily = getAzureOpenAIModelFamily(model);
   } else {
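A quick worked example of the family mapping defined in `getGcpModelFamily` above:

```
getGcpModelFamily("claude-3-opus@20240229");     // "gcp-claude-opus"
getGcpModelFamily("claude-3-5-sonnet@20240620"); // "gcp-claude"
getGcpModelFamily("claude-3-haiku@20240307");    // "gcp-claude"
```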
@@ -30,10 +30,12 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
       cost = 0.00001;
       break;
     case "aws-claude":
+    case "gcp-claude":
     case "claude":
       cost = 0.000008;
       break;
     case "aws-claude-opus":
+    case "gcp-claude-opus":
     case "claude-opus":
       cost = 0.000015;
       break;
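At these per-token rates, and assuming the function multiplies the rate by the token count as its name and signature suggest, a million tokens works out to:

```
// gcp-claude / aws-claude / claude:                1,000,000 * $0.000008 = $8.00
// gcp-claude-opus / aws-claude-opus / claude-opus: 1,000,000 * $0.000015 = $15.00
```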
@@ -13,6 +13,7 @@ import { v4 as uuid } from "uuid";
 import { config, getFirebaseApp } from "../../config";
 import {
   getAwsBedrockModelFamily,
+  getGcpModelFamily,
   getAzureOpenAIModelFamily,
   getClaudeModelFamily,
   getGoogleAIModelFamily,
@@ -417,6 +418,7 @@ function getModelFamilyForQuotaUsage(
   // differentiate between Azure and OpenAI variants of the same model.
   if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
   if (model.includes("anthropic.")) return getAwsBedrockModelFamily(model);
+  if (model.startsWith("claude-") && model.includes("@"))
+    return getGcpModelFamily(model);

   switch (api) {
     case "openai":