Merge GCP Vertex AI implementation from cg-dot/oai-reverse-proxy (khanon/oai-reverse-proxy!72)

This commit is contained in:
khanon 2024-08-05 14:27:51 +00:00
parent 29ed07492e
commit 0c936e97fe
25 changed files with 1133 additions and 13 deletions

View File

@ -40,15 +40,21 @@ NODE_ENV=production
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus
# | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small
# | mistral-medium | mistral-large | aws-claude | aws-claude-opus | gcp-claude
# | gcp-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k
# | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
# 'azure-dall-e' to the list of allowed model families.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,gcp-claude,gcp-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
# Which services can be used to process prompts containing images via multimodal
# models. The following services are recognized:
# openai | anthropic | aws | azure | google-ai | mistral-ai
# openai | anthropic | aws | gcp | azure | google-ai | mistral-ai
# Do not enable this feature unless all users are trusted, as you will be liable
# for any user-submitted images containing illegal content.
# By default, no image services are allowed and image prompts are rejected.
@ -118,6 +124,7 @@ NODE_ENV=production
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_GEMINI_PRO=0
# TOKEN_QUOTA_AWS_CLAUDE=0
# TOKEN_QUOTA_GCP_CLAUDE=0
# "Tokens" for image-generation models are counted at a rate of 100000 tokens
# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
@ -142,6 +149,7 @@ NODE_ENV=production
# You can add multiple API keys by separating them with a comma.
# For AWS credentials, separate the access key ID, secret key, and region with a colon.
# For GCP credentials, separate the project ID, client email, region, and private key with a colon.
OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
@ -149,6 +157,7 @@ GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
# See `docs/azure-configuration.md` for more information; there may be additional steps required to set up Azure.
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
GCP_CREDENTIALS=project-id:client-email:region:private-key
# With proxy_key gatekeeper, the password users must provide to access the API.
# PROXY_KEY=your-secret-key

View File

@ -20,6 +20,7 @@ This project allows you to run a reverse proxy server for various LLM APIs.
- [x] [OpenAI](https://openai.com/)
- [x] [Anthropic](https://www.anthropic.com/)
- [x] [AWS Bedrock](https://aws.amazon.com/bedrock/)
- [x] [Vertex AI (GCP)](https://cloud.google.com/vertex-ai/)
- [x] [Google MakerSuite/Gemini API](https://ai.google.dev/)
- [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
- [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses

35
docs/gcp-configuration.md Normal file
View File

@ -0,0 +1,35 @@
# Configuring the proxy for Vertex AI (GCP)
The proxy supports GCP models via the `/proxy/gcp/claude` endpoint. There are a few extra steps necessary to use GCP compared to the other supported APIs.
- [Setting keys](#setting-keys)
- [Set up Vertex AI](#set-up-vertex-ai)
- [Supported model IDs](#supported-model-ids)
## Setting keys
Use the `GCP_CREDENTIALS` environment variable to set the GCP credentials.
Like other APIs, you can provide multiple credentials separated by commas. Each GCP credential, however, is a set of values: the project ID, client email, region, and private key, separated by colons (`:`).
For example:
```
GCP_CREDENTIALS=my-first-project:xxx@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,my-first-project2:xxx2@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----
```
## Set up Vertex AI
1. Go to [https://cloud.google.com/vertex-ai](https://cloud.google.com/vertex-ai) and sign up for a GCP account. ($150 of free credits without a credit card, or $300 with a credit card; credits expire after 90 days.)
2. Go to [https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com](https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com) to enable the Vertex AI API.
3. Go to [https://console.cloud.google.com/vertex-ai](https://console.cloud.google.com/vertex-ai) and navigate to Model Garden to apply for access to the Claude models.
4. Create a [Service Account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1) and grant it the "Vertex AI User" or "Vertex AI Administrator" role.
5. On the page of the service account you just created, create a new key and select "JSON". The key file will be downloaded automatically.
6. All of the required credential values can be found in the JSON file you just downloaded.
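If you want to assemble the `GCP_CREDENTIALS` value programmatically, a minimal Node.js sketch like the following works (the file name `service-account.json` and the region are assumptions; adjust for your setup):
```
// Hypothetical one-off helper (not part of the proxy): converts the
// downloaded service account JSON into a GCP_CREDENTIALS entry.
import fs from "fs";

const sa = JSON.parse(fs.readFileSync("service-account.json", "utf8"));
const region = "us-east5"; // the region where you were granted Claude access

// Strip newlines so the value fits on one line; the proxy removes the
// BEGIN/END markers from the private key on its own.
const privateKey = sa.private_key.replace(/\n/g, "");
console.log(`${sa.project_id}:${sa.client_email}:${region}:${privateKey}`);
```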
## Supported model IDs
Users can send these model IDs to the proxy to invoke the corresponding models.
- **Claude**
- `claude-3-haiku@20240307`
- `claude-3-sonnet@20240229`
- `claude-3-opus@20240229`
- `claude-3-5-sonnet@20240620`
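For example, once the proxy is running, the native Anthropic Messages endpoint can be called as follows (a sketch; `PROXY_HOST` and `PROXY_KEY` are placeholders for your deployment):
```
// Sketch of a call to the proxy's native Anthropic chat endpoint.
const response = await fetch("https://PROXY_HOST/proxy/gcp/claude/v1/messages", {
  method: "POST",
  headers: {
    "Authorization": "Bearer PROXY_KEY",
    "anthropic-version": "2023-06-01",
    "Content-Type": "application/json",
  },
  body: JSON.stringify({
    model: "claude-3-5-sonnet@20240620",
    max_tokens: 50,
    messages: [{ role: "user", content: "Hello!" }],
  }),
});
console.log(await response.json());
```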

View File

@ -230,6 +230,39 @@ Content-Type: application/json
]
}
###
# @name Proxy / GCP Claude -- Native Completion
POST {{proxy-host}}/proxy/gcp/claude/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json
{
"model": "claude-v2",
"max_tokens_to_sample": 10,
"temperature": 0,
"stream": true,
"prompt": "What is genshin impact\n\n:Assistant:"
}
###
# @name Proxy / GCP Claude -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/gcp/claude/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"max_tokens": 50,
"stream": true,
"messages": [
{
"role": "user",
"content": "What is genshin impact?"
}
]
}
###
# @name Proxy / Azure OpenAI -- Native Chat Completions
POST {{proxy-host}}/proxy/azure/openai/chat/completions

View File

@ -51,6 +51,8 @@ function getRandomModelFamily() {
"mistral-large",
"aws-claude",
"aws-claude-opus",
"gcp-claude",
"gcp-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",

View File

@ -268,7 +268,7 @@ router.post("/maintenance", (req, res) => {
let flash = { type: "", message: "" };
switch (action) {
case "recheck": {
const checkable: LLMService[] = ["openai", "anthropic", "aws", "azure"];
const checkable: LLMService[] = ["openai", "anthropic", "aws", "gcp", "azure"];
checkable.forEach((s) => keyPool.recheck(s));
const keyCount = keyPool
.list()

View File

@ -45,6 +45,13 @@ type Config = {
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
*/
awsCredentials?: string;
/**
* Comma-delimited list of GCP credentials. Each credential item should be a
* colon-delimited list of project ID, client email, region, and private key.
*
* @example `GCP_CREDENTIALS=project1:1@1.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,project2:2@2.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----`
*/
gcpCredentials?: string;
/**
* Comma-delimited list of Azure OpenAI credentials. Each credential item
* should be a colon-delimited list of Azure resource name, deployment ID, and
@ -349,7 +356,7 @@ type Config = {
*
* Defaults to no services, meaning image prompts are disabled. Use a comma-
* separated list. Available services are:
* openai,anthropic,google-ai,mistral-ai,aws,azure
* openai,anthropic,google-ai,mistral-ai,aws,gcp,azure
*/
allowedVisionServices: LLMService[];
/**
@ -383,6 +390,7 @@ export const config: Config = {
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
gcpCredentials: getEnvWithDefault("GCP_CREDENTIALS", ""),
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
@ -437,6 +445,8 @@ export const config: Config = {
"mistral-large",
"aws-claude",
"aws-claude-opus",
"gcp-claude",
"gcp-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
@ -511,6 +521,7 @@ function generateSigningKey() {
config.googleAIKey,
config.mistralAIKey,
config.awsCredentials,
config.gcpCredentials,
config.azureCredentials,
];
if (secrets.filter((s) => s).length === 0) {
@ -648,6 +659,7 @@ export const OMITTED_KEYS = [
"googleAIKey",
"mistralAIKey",
"awsCredentials",
"gcpCredentials",
"azureCredentials",
"proxyKey",
"adminKey",
@ -738,6 +750,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
"ANTHROPIC_KEY",
"GOOGLE_AI_KEY",
"AWS_CREDENTIALS",
"GCP_CREDENTIALS",
"AZURE_CREDENTIALS",
].includes(String(env))
) {

View File

@ -29,6 +29,8 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"mistral-large": "Mistral Large",
"aws-claude": "AWS Claude (Sonnet)",
"aws-claude-opus": "AWS Claude (Opus)",
"gcp-claude": "GCP Claude (Sonnet)",
"gcp-claude-opus": "GCP Claude (Opus)",
"azure-turbo": "Azure GPT-3.5 Turbo",
"azure-gpt4": "Azure GPT-4",
"azure-gpt4-32k": "Azure GPT-4 32k",

View File

@ -129,7 +129,7 @@ export function transformAnthropicChatResponseToAnthropicText(
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAnthropicTextResponseToOpenAI(
export function transformAnthropicTextResponseToOpenAI(
anthropicBody: Record<string, any>,
req: Request
): Record<string, any> {

196
src/proxy/gcp.ts Normal file
View File

@ -0,0 +1,196 @@
import { Request, RequestHandler, Response, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createPreprocessorMiddleware,
signGcpRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
import { sendErrorToClient } from "./middleware/response/error-generator";
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
if (!config.gcpCredentials) return { object: "list", data: [] };
// https://docs.anthropic.com/en/docs/about-claude/models
const variants = [
"claude-3-haiku@20240307",
"claude-3-sonnet@20240229",
"claude-3-opus@20240229",
"claude-3-5-sonnet@20240620",
];
const models = variants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "anthropic",
permission: [],
root: "claude",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const gcpResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
switch (`${req.inboundApi}<-${req.outboundApi}`) {
case "openai<-anthropic-chat":
req.log.info("Transforming Anthropic Chat back to OpenAI format");
newBody = transformAnthropicChatResponseToOpenAI(body);
break;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
const gcpProxy = createQueueMiddleware({
beforeProxy: signGcpRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([gcpResponseHandler]),
error: handleProxyError,
},
}),
});
const oaiToChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-chat", service: "gcp" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Translates OpenAI-formatted prompts to the Claude chat completion (Messages)
* format, which is the only Claude API surface exposed through GCP.
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
oaiToChatPreprocessor(req, res, next);
};
const gcpRouter = Router();
gcpRouter.get("/v1/models", handleModelRequest);
// Native Anthropic chat completion endpoint.
gcpRouter.post(
"/v1/messages",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "gcp" },
{ afterTransform: [maybeReassignModel] }
),
gcpProxy
);
// OpenAI-to-GCP Anthropic compatibility endpoint.
gcpRouter.post(
"/v1/chat/completions",
ipLimiter,
preprocessOpenAICompatRequest,
gcpProxy
);
/**
* Tries to deal with:
* - frontends sending GCP model names even when they want to use the OpenAI-
* compatible endpoint
* - frontends sending Anthropic model names that GCP doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*
* If client sends GCP model ID it will be used verbatim. Otherwise, various
* strategies are used to try to map a non-GCP model name to GCP model ID.
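*
* Illustrative mappings given the pattern below:
* "claude-3-opus-20240229" -> "claude-3-opus@20240229"
* "claude-3-5-sonnet-20240620" -> "claude-3-5-sonnet@20240620"
* "gpt-4o" (no match) -> "claude-3-sonnet@20240229" (fallback)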
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If it looks like a GCP model ID, use it as-is
if (model.startsWith("claude-") && model.includes("@")) {
return;
}
// Anthropic model names can look like:
// - claude-v1
// - claude-2.1
// - claude-3-5-sonnet-20240620-v1:0
const pattern =
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
const match = model.match(pattern);
// If there's no match, fall back to Claude 3 Sonnet as it is the most likely
// to be available on GCP.
if (!match) {
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
return;
}
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
const ver = minor ? `${major}.${minor}` : major;
switch (ver) {
case "3":
case "3.0":
if (name?.includes("opus")) {
req.body.model = "claude-3-opus@20240229";
} else if (name?.includes("haiku")) {
req.body.model = "claude-3-haiku@20240307";
} else {
req.body.model = "claude-3-sonnet@20240229";
}
return;
case "3.5":
req.body.model = "claude-3-5-sonnet@20240620";
return;
}
// Fall back to Claude 3 Sonnet
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
return;
}
export const gcp = gcpRouter;

View File

@ -15,6 +15,7 @@ export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
export { languageFilter } from "./preprocessors/language-filter";
export { setApiFormat } from "./preprocessors/set-api-format";
export { signAwsRequest } from "./preprocessors/sign-aws-request";
export { signGcpRequest } from "./preprocessors/sign-vertex-ai-request";
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
export { validateContextSize } from "./preprocessors/validate-context-size";
export { validateVision } from "./preprocessors/validate-vision";

View File

@ -83,6 +83,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
proxyReq.setHeader("api-key", azureKey);
break;
case "aws":
case "gcp":
case "google-ai":
throw new Error("add-key should not be used for this service.");
default:

View File

@ -1,7 +1,7 @@
import type { HPMRequestCallback } from "../index";
/**
* For AWS/Azure/Google requests, the body is signed earlier in the request
* For AWS/GCP/Azure/Google requests, the body is signed earlier in the request
* pipeline, before the proxy middleware. This function just assigns the path
* and headers to the proxy request.
*/

201
src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts Normal file
View File

@ -0,0 +1,201 @@
import express from "express";
import crypto from "crypto";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
export const signGcpRequest: RequestPreprocessor = async (req) => {
const serviceValid = req.service === "gcp";
if (!serviceValid) {
throw new Error("signGcpRequest called on invalid request");
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
const { model, stream } = req.body;
req.key = keyPool.get(model, "gcp");
req.log.info({ key: req.key.hash, model }, "Assigned GCP key to request");
req.isStreaming = String(stream) === "true";
// TODO: This should happen in transform-outbound-payload.ts
// TODO: Support tools
const strippedParams: Record<string, unknown> = AnthropicV1MessagesSchema.pick({
messages: true,
system: true,
max_tokens: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
stream: true,
})
.strip()
.parse(req.body);
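// Vertex AI takes the model from the URL path rather than the body; in the
// body, `anthropic_version` takes the place of the `model` field.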
strippedParams.anthropic_version = "vertex-2023-10-16";
const [accessToken, credential] = await getAccessToken(req);
const host = GCP_HOST.replace("%REGION%", credential.region);
// GCP doesn't use the anthropic-version header, but we set it to ensure the
// stream adapter selects the correct transformer.
req.headers["anthropic-version"] = "2023-06-01";
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: host,
path: `/v1/projects/${credential.projectId}/locations/${credential.region}/publishers/anthropic/models/${model}:streamRawPredict`,
headers: {
["host"]: host,
["content-type"]: "application/json",
["authorization"]: `Bearer ${accessToken}`,
},
body: JSON.stringify(strippedParams),
};
};
async function getAccessToken(
req: express.Request
): Promise<[string, Credential]> {
// TODO: access token caching to reduce latency
const credential = getCredentialParts(req);
const signedJWT = await createSignedJWT(
credential.clientEmail,
credential.privateKey
);
const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT);
if (accessToken === null) {
req.log.warn(
{ key: req.key!.hash, jwtError },
"Unable to get the access token"
);
throw new Error("The access token is invalid.");
}
return [accessToken, credential];
}
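// The two helpers below implement the OAuth 2.0 JWT bearer flow: build an
// RS256-signed JWT from the service account's client email and private key,
// then exchange it for a short-lived access token at Google's token endpoint.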
async function createSignedJWT(email: string, pkey: string): Promise<string> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
str2ab(atob(pkey)),
{
name: "RSASSA-PKCS1-v1_5",
hash: { name: "SHA-256" },
},
false,
["sign"]
);
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = {
alg: "RS256",
typ: "JWT",
};
const payload = {
iss: email,
aud: authUrl,
iat: issued,
exp: expires,
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
const signature = await crypto.subtle.sign(
"RSASSA-PKCS1-v1_5",
cryptoKey,
str2ab(unsignedToken)
);
const encodedSignature = urlSafeBase64Encode(signature);
return `${unsignedToken}.${encodedSignature}`;
}
async function exchangeJwtForAccessToken(
signed_jwt: string
): Promise<[string | null, string]> {
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: signed_jwt,
};
const r = await fetch(auth_url, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: Object.entries(params)
.map(([k, v]) => `${k}=${v}`)
.join("&"),
}).then((res) => res.json());
if (r.access_token) {
return [r.access_token, ""];
}
return [null, JSON.stringify(r)];
}
function str2ab(str: string): ArrayBuffer {
const buffer = new ArrayBuffer(str.length);
const bufferView = new Uint8Array(buffer);
for (let i = 0; i < str.length; i++) {
bufferView[i] = str.charCodeAt(i);
}
return buffer;
}
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}
type Credential = {
projectId: string;
clientEmail: string;
region: string;
privateKey: string;
};
function getCredentialParts(req: express.Request): Credential {
const [projectId, clientEmail, region, rawPrivateKey] =
req.key!.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
req.log.error(
{ key: req.key!.hash },
"GCP_CREDENTIALS isn't correctly formatted; refer to the docs"
);
throw new Error("The key assigned to this request is invalid.");
}
const privateKey = rawPrivateKey
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim();
return { projectId, clientEmail, region, privateKey };
}

View File

@ -186,6 +186,13 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
throw new HttpError(statusCode, parseError.message);
}
const service = req.key!.service;
if (service === "gcp") {
if (Array.isArray(errorPayload)) {
errorPayload = errorPayload[0];
}
}
const errorType =
errorPayload.error?.code ||
errorPayload.error?.type ||
@ -199,11 +206,15 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
// TODO: split upstream error handling into separate modules for each service,
// this is out of control.
const service = req.key!.service;
if (service === "aws") {
// Try to standardize the error format for AWS
errorPayload.error = { message: errorPayload.message, type: errorType };
delete errorPayload.message;
} else if (service === "gcp") {
// Try to standardize the error format for GCP
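// A typical GCP error body (assumed shape; cf. the RESOURCE_EXHAUSTED
// handling further below):
// { "error": { "code": 429, "message": "...", "status": "RESOURCE_EXHAUSTED" } }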
if (errorPayload.error?.code) {
errorPayload.error = {
message: errorPayload.error.message,
type: errorPayload.error.status || errorPayload.error.code,
};
}
}
if (statusCode === 400) {
@ -225,6 +236,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
break;
case "anthropic":
case "aws":
case "gcp":
await handleAnthropicAwsBadRequestError(req, errorPayload);
break;
case "google-ai":
@ -280,6 +292,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
default:
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
}
return;
case "gcp":
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
return;
}
} else if (statusCode === 429) {
switch (service) {
@ -292,6 +309,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "aws":
await handleAwsRateLimitError(req, errorPayload);
break;
case "gcp":
await handleGcpRateLimitError(req, errorPayload);
break;
case "azure":
case "mistral-ai":
await handleAzureRateLimitError(req, errorPayload);
@ -328,6 +348,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
break;
case "gcp":
errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`;
break;
case "azure":
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
break;
@ -434,6 +457,19 @@ async function handleAwsRateLimitError(
}
}
async function handleGcpRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
if (errorPayload.error?.type === "RESOURCE_EXHAUSTED") {
keyPool.markRateLimited(req.key!);
await reenqueueRequest(req);
throw new RetryableError("GCP rate-limited request re-enqueued.");
} else {
errorPayload.proxy_note = `Unrecognized 429 Too Many Requests error from GCP.`;
}
}
async function handleOpenAIRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload

View File

@ -7,6 +7,7 @@ import { anthropic } from "./anthropic";
import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { gcp } from "./gcp";
import { azure } from "./azure";
import { sendErrorToClient } from "./middleware/response/error-generator";
@ -36,6 +37,7 @@ proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/mistral-ai", addV1, mistralAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/gcp/claude", addV1, gcp);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
proxyRouter.get("*", (req, res, next) => {

View File

@ -2,6 +2,7 @@ import { config, listConfig } from "./config";
import {
AnthropicKey,
AwsBedrockKey,
GcpKey,
AzureOpenAIKey,
GoogleAIKey,
keyPool,
@ -11,6 +12,7 @@ import {
AnthropicModelFamily,
assertIsKnownModelFamily,
AwsBedrockModelFamily,
GcpModelFamily,
AzureOpenAIModelFamily,
GoogleAIModelFamily,
LLM_SERVICES,
@ -40,6 +42,7 @@ const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
k.service === "mistral-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp";
/** Stats aggregated across all keys for a given service. */
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
@ -52,7 +55,11 @@ type ModelAggregates = {
pozzed?: number;
awsLogged?: number;
awsSonnet?: number;
awsSonnet35?: number;
awsHaiku?: number;
gcpSonnet?: number;
gcpSonnet35?: number;
gcpHaiku?: number;
queued: number;
queueTime: string;
tokens: number;
@ -87,6 +94,12 @@ type AnthropicInfo = BaseFamilyInfo & {
type AwsInfo = BaseFamilyInfo & {
privacy?: string;
sonnetKeys?: number;
sonnet35Keys?: number;
haikuKeys?: number;
};
type GcpInfo = BaseFamilyInfo & {
sonnetKeys?: number;
sonnet35Keys?: number;
haikuKeys?: number;
};
@ -101,6 +114,7 @@ export type ServiceInfo = {
"google-ai"?: string;
"mistral-ai"?: string;
aws?: string;
gcp?: string;
azure?: string;
"openai-image"?: string;
"azure-image"?: string;
@ -114,6 +128,7 @@ export type ServiceInfo = {
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
& { [f in AwsBedrockModelFamily]?: AwsInfo }
& { [f in GcpModelFamily]?: GcpInfo }
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
@ -151,6 +166,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
aws: {
aws: `%BASE%/aws/claude`,
},
gcp: {
gcp: `%BASE%/gcp/claude`,
},
azure: {
azure: `%BASE%/azure/openai`,
"azure-image": `%BASE%/azure/openai`,
@ -305,6 +323,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
k.service === "mistral-ai" ? 1 : 0
);
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "gcp__keys", k.service === "gcp" ? 1 : 0);
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
let sumTokens = 0;
@ -396,6 +415,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
increment(modelStats, `aws-claude__awsSonnet35`, k.sonnet35Enabled ? 1 : 0);
increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);
// Ignore revoked keys for aws logging stats, but include keys where the
@ -405,6 +425,21 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
break;
}
case "gcp": {
if (!keyIsGcpKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
increment(modelStats, `gcp-claude__gcpSonnet`, k.sonnetEnabled ? 1 : 0);
increment(modelStats, `gcp-claude__gcpSonnet35`, k.sonnet35Enabled ? 1 : 0);
increment(modelStats, `gcp-claude__gcpHaiku`, k.haikuEnabled ? 1 : 0);
break;
}
default:
assertNever(k.service);
}
@ -416,7 +451,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const tokens = modelStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = {
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
activeKeys: modelStats.get(`${family}__active`) || 0,
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
@ -446,6 +481,7 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
case "aws":
if (family === "aws-claude") {
info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
info.sonnet35Keys = modelStats.get(`${family}__awsSonnet35`) || 0;
info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0;
const logged = modelStats.get(`${family}__awsLogged`) || 0;
if (logged > 0) {
@ -455,6 +491,13 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
}
}
break;
case "gcp":
if (family === "gcp-claude") {
info.sonnetKeys = modelStats.get(`${family}__gcpSonnet`) || 0;
info.sonnet35Keys = modelStats.get(`${family}__gcpSonnet35`) || 0;
info.haikuKeys = modelStats.get(`${family}__gcpHaiku`) || 0;
}
break;
}
}

View File

@ -122,7 +122,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
{ key: key.hash, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 10 * 1000;
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}

277
src/shared/key-management/gcp/checker.ts Normal file
View File

@ -0,0 +1,277 @@
import axios, { AxiosError } from "axios";
import crypto from "crypto";
import { KeyCheckerBase } from "../key-checker-base";
import type { GcpKey, GcpKeyProvider } from "./provider";
import { GcpModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const GCP_HOST =
process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
`https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
const TEST_MESSAGES = [
{ role: "user", content: "Hi!" },
{ role: "assistant", content: "Hello!" },
];
type UpdateFn = typeof GcpKeyProvider.prototype.update;
export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
constructor(keys: GcpKey[], updateKey: UpdateFn) {
super(keys, {
service: "gcp",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
updateKey,
});
}
protected async testKeyOrFail(key: GcpKey) {
let checks: Promise<boolean>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
checks = [
this.invokeModel("claude-3-haiku@20240307", key, true),
this.invokeModel("claude-3-sonnet@20240229", key, true),
this.invokeModel("claude-3-opus@20240229", key, true),
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
];
const [haiku, sonnet, opus, sonnet35] = await Promise.all(checks);
this.log.debug(
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
"GCP model initial tests complete."
);
const families: GcpModelFamily[] = [];
if (sonnet || sonnet35 || haiku) families.push("gcp-claude");
if (opus) families.push("gcp-claude-opus");
if (families.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
}
this.updateKey(key.hash, {
sonnetEnabled: sonnet,
haikuEnabled: haiku,
sonnet35Enabled: sonnet35,
modelFamilies: families,
});
} else {
if (key.haikuEnabled) {
await this.invokeModel("claude-3-haiku@20240307", key, false);
} else if (key.sonnetEnabled) {
await this.invokeModel("claude-3-sonnet@20240229", key, false);
} else if (key.sonnet35Enabled) {
await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
} else {
await this.invokeModel("claude-3-opus@20240229", key, false);
}
this.updateKey(key.hash, { lastChecked: Date.now() });
this.log.debug({ key: key.hash }, "GCP key check complete.");
}
this.log.info(
{
key: key.hash,
families: key.modelFamilies,
},
"Checked key."
);
}
protected handleAxiosError(key: GcpKey, error: AxiosError) {
if (error.response && GcpKeyChecker.errorIsGcpError(error)) {
const { status, data } = error.response;
if (status === 400 || status === 401 || status === 403) {
this.log.warn(
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
} else if (status === 429) {
this.log.warn(
{ key: key.hash, error: data },
"Key is rate limited. Rechecking in a minute."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 60 * 1000);
this.updateKey(key.hash, { lastChecked: next });
} else {
this.log.error(
{ key: key.hash, status, error: data },
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
}
return;
}
const { response, cause } = error;
const { headers, status, data } = response ?? {};
this.log.error(
{ key: key.hash, status, headers, data, cause, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
/**
* Attempt to invoke the given model with the given key. Returns true if the
* key has access to the model, false if it does not. Throws an error if the
* key is disabled.
*/
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
const creds = GcpKeyChecker.getCredentialsFromKey(key);
const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey)
const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT)
if (accessToken === null) {
this.log.warn(
{ key: key.hash, jwtError },
"Unable to get the access token"
);
return false;
}
const payload = {
max_tokens: 1,
messages: TEST_MESSAGES,
anthropic_version: "vertex-2023-10-16",
};
const { data, status } = await axios.post(
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
payload,
{
headers: GcpKeyChecker.getRequestHeaders(accessToken),
validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300
}
);
this.log.debug({ key: key.hash, data }, "Response from GCP");
if (initial) {
return (status >= 200 && status < 300) || (status === 429 || status === 529);
}
return true;
}
static errorIsGcpError(error: AxiosError): boolean {
const data = error.response?.data as any;
if (Array.isArray(data)) {
return data.length > 0 && Boolean(data[0]?.error?.message);
} else {
return Boolean(data?.error?.message);
}
}
static async createSignedJWT(email: string, pkey: string): Promise<string> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
GcpKeyChecker.str2ab(atob(pkey)),
{
name: "RSASSA-PKCS1-v1_5",
hash: { name: "SHA-256" },
},
false,
["sign"]
);
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = {
alg: "RS256",
typ: "JWT",
};
const payload = {
iss: email,
aud: authUrl,
iat: issued,
exp: expires,
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header));
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload));
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
const signature = await crypto.subtle.sign(
"RSASSA-PKCS1-v1_5",
cryptoKey,
GcpKeyChecker.str2ab(unsignedToken)
);
const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature);
return `${unsignedToken}.${encodedSignature}`;
}
static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> {
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: signed_jwt,
};
const r = await fetch(auth_url, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: Object.entries(params)
.map(([k, v]) => `${k}=${v}`)
.join("&"),
}).then((res) => res.json());
if (r.access_token) {
return [r.access_token, ""];
}
return [null, JSON.stringify(r)];
}
static str2ab(str: string): ArrayBuffer {
const buffer = new ArrayBuffer(str.length);
const bufferView = new Uint8Array(buffer);
for (let i = 0; i < str.length; i++) {
bufferView[i] = str.charCodeAt(i);
}
return buffer;
}
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16))));
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}
static getRequestHeaders(accessToken: string) {
return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" };
}
static getCredentialsFromKey(key: GcpKey) {
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
throw new Error("Invalid GCP key");
}
const privateKey = rawPrivateKey
.replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '')
.trim();
return { projectId, clientEmail, region, privateKey };
}
}

242
src/shared/key-management/gcp/provider.ts Normal file
View File

@ -0,0 +1,242 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { GcpKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
type GcpKeyUsage = {
[K in GcpModelFamily as `${K}Tokens`]: number;
};
export interface GcpKey extends Key, GcpKeyUsage {
readonly service: "gcp";
readonly modelFamilies: GcpModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
sonnetEnabled: boolean;
haikuEnabled: boolean;
sonnet35Enabled: boolean;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 4000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class GcpKeyProvider implements KeyProvider<GcpKey> {
readonly service = "gcp";
private keys: GcpKey[] = [];
private checker?: GcpKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.gcpCredentials?.trim();
if (!keyConfig) {
this.log.warn(
"GCP_CREDENTIALS is not set. GCP API will not be available."
);
return;
}
const bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: GcpKey = {
key,
service: this.service,
modelFamilies: ["gcp-claude"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `gcp-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
sonnetEnabled: true,
haikuEnabled: false,
sonnet35Enabled: false,
["gcp-claudeTokens"]: 0,
["gcp-claude-opusTokens"]: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded GCP keys.");
}
public init() {
if (config.checkKeys) {
this.checker = new GcpKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(model: string) {
const neededFamily = getGcpModelFamily(model);
// this is a horrible mess
// each of these should be separate model families, but adding model
// families is not low enough friction for the rate at which gcp claude
// model variants are added.
const needsSonnet35 =
model.includes("claude-3-5-sonnet") && neededFamily === "gcp-claude";
const needsSonnet =
!needsSonnet35 &&
model.includes("sonnet") &&
neededFamily === "gcp-claude";
const needsHaiku = model.includes("haiku") && neededFamily === "gcp-claude";
const availableKeys = this.keys.filter((k) => {
return (
!k.isDisabled &&
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under gcp-claude, while opus is not
(k.haikuEnabled || !needsHaiku) &&
(k.sonnet35Enabled || !needsSonnet35) &&
k.modelFamilies.includes(neededFamily)
);
});
this.log.debug(
{
model,
neededFamily,
needsSonnet,
needsHaiku,
needsSonnet35,
availableKeys: availableKeys.length,
totalKeys: this.keys.length,
},
"Selecting GCP key"
);
if (availableKeys.length === 0) {
throw new PaymentRequiredError(
`No GCP keys available for model ${model}`
);
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 2. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
public disable(key: GcpKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<GcpKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key[`${getGcpModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {
// TODO: same exact behavior for three providers, should be refactored
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
/**
* This is called when we receive a 429, which means the key already has too
* many concurrent requests running. We don't have any information on when
* those requests will resolve, so all we can do is wait a bit and try again.
* We lock the key for RATE_LIMIT_LOCKOUT milliseconds after getting a 429
* before retrying, in order to give the other requests a chance to finish.
*/
public markRateLimited(keyHash: string) {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {
this.keys.forEach(({ hash }) =>
this.update(hash, { lastChecked: 0, isDisabled: false, isRevoked: false })
);
this.checker?.scheduleNextCheck();
}
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}

View File

@ -63,4 +63,5 @@ export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GoogleAIKey } from "./google-ai/provider";
export { AwsBedrockKey } from "./aws/provider";
export { GcpKey } from "./gcp/provider";
export { AzureOpenAIKey } from "./azure/provider";

View File

@ -10,6 +10,7 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { GcpKeyProvider } from "./gcp/provider";
import { AzureOpenAIKeyProvider } from "./azure/provider";
import { MistralAIKeyProvider } from "./mistral-ai/provider";
@ -27,6 +28,7 @@ export class KeyPool {
this.keyProviders.push(new GoogleAIKeyProvider());
this.keyProviders.push(new MistralAIKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new GcpKeyProvider());
this.keyProviders.push(new AzureOpenAIKeyProvider());
}
@ -128,7 +130,11 @@ export class KeyPool {
return "openai";
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
if (!model.includes("@")) {
return "anthropic";
} else {
return "gcp";
}
} else if (model.includes("gemini")) {
// https://developers.generativeai.google.com/models/language
return "google-ai";

View File

@ -5,7 +5,7 @@ import type { Request } from "express";
/**
* The service that a model is hosted on. Distinct from `APIFormat` because some
* services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
* services have interoperable APIs (eg Anthropic/AWS/GCP, OpenAI/Azure).
*/
export type LLMService =
| "openai"
@ -13,6 +13,7 @@ export type LLMService =
| "google-ai"
| "mistral-ai"
| "aws"
| "gcp"
| "azure";
export type OpenAIModelFamily =
@ -32,6 +33,7 @@ export type MistralAIModelFamily =
// correspond to specific models. consider them rough pricing tiers.
"mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type ModelFamily =
| OpenAIModelFamily
@ -39,6 +41,7 @@ export type ModelFamily =
| GoogleAIModelFamily
| MistralAIModelFamily
| AwsBedrockModelFamily
| GcpModelFamily
| AzureOpenAIModelFamily;
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
@ -61,6 +64,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"mistral-large",
"aws-claude",
"aws-claude-opus",
"gcp-claude",
"gcp-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
@ -77,6 +82,7 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
"google-ai",
"mistral-ai",
"aws",
"gcp",
"azure",
] as const);
@ -93,6 +99,8 @@ export const MODEL_FAMILY_SERVICE: {
"claude-opus": "anthropic",
"aws-claude": "aws",
"aws-claude-opus": "aws",
"gcp-claude": "gcp",
"gcp-claude-opus": "gcp",
"azure-turbo": "azure",
"azure-gpt4": "azure",
"azure-gpt4-32k": "azure",
@ -176,6 +184,11 @@ export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
return "aws-claude";
}
export function getGcpModelFamily(model: string): GcpModelFamily {
if (model.includes("opus")) return "gcp-claude-opus";
return "gcp-claude";
}
export function getAzureOpenAIModelFamily(
model: string,
defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
@ -210,10 +223,12 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
const model = req.body.model ?? "gpt-3.5-turbo";
let modelFamily: ModelFamily;
// Weird special case for AWS/Azure because they serve multiple models from
// Weird special case for AWS/GCP/Azure because they serve multiple models from
// different vendors, even if currently only one is supported.
if (req.service === "aws") {
modelFamily = getAwsBedrockModelFamily(model);
} else if (req.service === "gcp") {
modelFamily = getGcpModelFamily(model);
} else if (req.service === "azure") {
modelFamily = getAzureOpenAIModelFamily(model);
} else {

View File

@ -30,10 +30,12 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
cost = 0.00001;
break;
case "aws-claude":
case "gcp-claude":
case "claude":
cost = 0.000008;
break;
case "aws-claude-opus":
case "gcp-claude-opus":
case "claude-opus":
cost = 0.000015;
break;

View File

@ -13,6 +13,7 @@ import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import {
getAwsBedrockModelFamily,
getGcpModelFamily,
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGoogleAIModelFamily,
@ -417,6 +418,7 @@ function getModelFamilyForQuotaUsage(
// differentiate between Azure and OpenAI variants of the same model.
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
if (model.includes("anthropic.")) return getAwsBedrockModelFamily(model);
if (model.startsWith("claude-") && model.includes("@")) return getGcpModelFamily(model);
switch (api) {
case "openai":