From 0c936e97fed28e709f6ca17c68604b0cd62af189 Mon Sep 17 00:00:00 2001 From: khanon Date: Mon, 5 Aug 2024 14:27:51 +0000 Subject: [PATCH] Merge GCP Vertex AI implementation from cg-dot/oai-reverse-proxy (khanon/oai-reverse-proxy!72) --- .env.example | 15 +- README.md | 1 + docs/gcp-configuration.md | 35 +++ scripts/oai-reverse-proxy.http | 33 +++ scripts/seed-events.ts | 2 + src/admin/web/manage.ts | 2 +- src/config.ts | 15 +- src/info-page.ts | 2 + src/proxy/anthropic.ts | 2 +- src/proxy/gcp.ts | 196 +++++++++++++ src/proxy/middleware/request/index.ts | 1 + .../middleware/request/onproxyreq/add-key.ts | 1 + .../onproxyreq/finalize-signed-request.ts | 2 +- .../preprocessors/sign-vertex-ai-request.ts | 201 +++++++++++++ src/proxy/middleware/response/index.ts | 38 ++- src/proxy/routes.ts | 2 + src/service-info.ts | 45 ++- .../key-management/anthropic/checker.ts | 2 +- src/shared/key-management/gcp/checker.ts | 277 ++++++++++++++++++ src/shared/key-management/gcp/provider.ts | 242 +++++++++++++++ src/shared/key-management/index.ts | 1 + src/shared/key-management/key-pool.ts | 8 +- src/shared/models.ts | 19 +- src/shared/stats.ts | 2 + src/shared/users/user-store.ts | 2 + 25 files changed, 1133 insertions(+), 13 deletions(-) create mode 100644 docs/gcp-configuration.md create mode 100644 src/proxy/gcp.ts create mode 100644 src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts create mode 100644 src/shared/key-management/gcp/checker.ts create mode 100644 src/shared/key-management/gcp/provider.ts diff --git a/.env.example b/.env.example index 046c053..228be1a 100644 --- a/.env.example +++ b/.env.example @@ -40,15 +40,21 @@ NODE_ENV=production # Which model types users are allowed to access. 
# The following model families are recognized: -# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e + +# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus +# | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small +# | mistral-medium | mistral-large | aws-claude | aws-claude-opus | gcp-claude +# | gcp-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k +# | azure-gpt4-turbo | azure-gpt4o | azure-dall-e + # By default, all models are allowed except for 'dall-e' / 'azure-dall-e'. # To allow DALL-E image generation, uncomment the line below and add 'dall-e' or # 'azure-dall-e' to the list of allowed model families. -# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o +# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,gcp-claude,gcp-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o # Which services can be used to process prompts containing images via multimodal # models. The following services are recognized: -# openai | anthropic | aws | azure | google-ai | mistral-ai +# openai | anthropic | aws | gcp | azure | google-ai | mistral-ai # Do not enable this feature unless all users are trusted, as you will be liable # for any user-submitted images containing illegal content. # By default, no image services are allowed and image prompts are rejected. 
@@ -118,6 +124,7 @@ NODE_ENV=production # TOKEN_QUOTA_CLAUDE=0 # TOKEN_QUOTA_GEMINI_PRO=0 # TOKEN_QUOTA_AWS_CLAUDE=0 +# TOKEN_QUOTA_GCP_CLAUDE=0 # "Tokens" for image-generation models are counted at a rate of 100000 tokens # per US$1.00 generated, which is similar to the cost of GPT-4 Turbo. # DALL-E 3 costs around US$0.10 per image (10000 tokens). @@ -142,6 +149,7 @@ NODE_ENV=production # You can add multiple API keys by separating them with a comma. # For AWS credentials, separate the access key ID, secret key, and region with a colon. +# For GCP credentials, separate the project ID, client email, region, and private key with a colon. OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx @@ -149,6 +157,7 @@ GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2 # See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure. AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key +GCP_CREDENTIALS=project-id:client-email:region:private-key # With proxy_key gatekeeper, the password users must provide to access the API. # PROXY_KEY=your-secret-key diff --git a/README.md b/README.md index f44e948..e7a3310 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ This project allows you to run a reverse proxy server for various LLM APIs. 
- [x] [OpenAI](https://openai.com/) - [x] [Anthropic](https://www.anthropic.com/) - [x] [AWS Bedrock](https://aws.amazon.com/bedrock/) + - [x] [Vertex AI (GCP)](https://cloud.google.com/vertex-ai/) - [x] [Google MakerSuite/Gemini API](https://ai.google.dev/) - [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) - [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses diff --git a/docs/gcp-configuration.md b/docs/gcp-configuration.md new file mode 100644 index 0000000..bf5ceb2 --- /dev/null +++ b/docs/gcp-configuration.md @@ -0,0 +1,35 @@ +# Configuring the proxy for Vertex AI (GCP) + +The proxy supports GCP models via the `/proxy/gcp/claude` endpoint. There are a few extra steps necessary to use GCP compared to the other supported APIs. + +- [Setting keys](#setting-keys) +- [Setup Vertex AI](#setup-vertex-ai) +- [Supported model IDs](#supported-model-ids) + +## Setting keys + +Use the `GCP_CREDENTIALS` environment variable to set the GCP API keys. + +Like other APIs, you can provide multiple keys separated by commas. Each GCP key, however, is a set of credentials including the project id, client email, region and private key. These are separated by a colon (`:`). + +For example: + +``` +GCP_CREDENTIALS=my-first-project:xxx@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,my-first-project2:xxx2@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY----- +``` + +## Setup Vertex AI +1. Go to [https://cloud.google.com/vertex-ai](https://cloud.google.com/vertex-ai) and sign up for a GCP account. ($150 free credits without credit card or $300 free credits with credit card, credits expire in 90 days) +2. Go to [https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com](https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com) to enable Vertex AI API. +3. 
Go to [https://console.cloud.google.com/vertex-ai](https://console.cloud.google.com/vertex-ai) and navigate to Model Garden to apply for access to the Claude models. +4. Create a [Service Account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1) , and make sure to grant the role of "Vertex AI User" or "Vertex AI Administrator". +5. On the service account page you just created, create a new key and select "JSON". The JSON file will be downloaded automatically. +6. The required credential is in the JSON file you just downloaded. + +## Supported model IDs +Users can send these model IDs to the proxy to invoke the corresponding models. +- **Claude** + - `claude-3-haiku@20240307` + - `claude-3-sonnet@20240229` + - `claude-3-opus@20240229` + - `claude-3-5-sonnet@20240620` \ No newline at end of file diff --git a/scripts/oai-reverse-proxy.http b/scripts/oai-reverse-proxy.http index 381fb2b..18b0942 100644 --- a/scripts/oai-reverse-proxy.http +++ b/scripts/oai-reverse-proxy.http @@ -230,6 +230,39 @@ Content-Type: application/json ] } +### +# @name Proxy / GCP Claude -- Native Completion +POST {{proxy-host}}/proxy/gcp/claude/v1/complete +Authorization: Bearer {{proxy-key}} +anthropic-version: 2023-01-01 +Content-Type: application/json + +{ + "model": "claude-v2", + "max_tokens_to_sample": 10, + "temperature": 0, + "stream": true, + "prompt": "What is genshin impact\n\n:Assistant:" +} + +### +# @name Proxy / GCP Claude -- OpenAI-to-Anthropic API Translation +POST {{proxy-host}}/proxy/gcp/claude/chat/completions +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo", + "max_tokens": 50, + "stream": true, + "messages": [ + { + "role": "user", + "content": "What is genshin impact?" 
+ } + ] +} + ### # @name Proxy / Azure OpenAI -- Native Chat Completions POST {{proxy-host}}/proxy/azure/openai/chat/completions diff --git a/scripts/seed-events.ts b/scripts/seed-events.ts index 26e6f80..328d3b4 100644 --- a/scripts/seed-events.ts +++ b/scripts/seed-events.ts @@ -51,6 +51,8 @@ function getRandomModelFamily() { "mistral-large", "aws-claude", "aws-claude-opus", + "gcp-claude", + "gcp-claude-opus", "azure-turbo", "azure-gpt4", "azure-gpt4-32k", diff --git a/src/admin/web/manage.ts b/src/admin/web/manage.ts index bdeab5d..029bbac 100644 --- a/src/admin/web/manage.ts +++ b/src/admin/web/manage.ts @@ -268,7 +268,7 @@ router.post("/maintenance", (req, res) => { let flash = { type: "", message: "" }; switch (action) { case "recheck": { - const checkable: LLMService[] = ["openai", "anthropic", "aws", "azure"]; + const checkable: LLMService[] = ["openai", "anthropic", "aws", "gcp","azure"]; checkable.forEach((s) => keyPool.recheck(s)); const keyCount = keyPool .list() diff --git a/src/config.ts b/src/config.ts index 9050a19..8318522 100644 --- a/src/config.ts +++ b/src/config.ts @@ -45,6 +45,13 @@ type Config = { * @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2` */ awsCredentials?: string; + /** + * Comma-delimited list of GCP credentials. Each credential item should be a + * colon-delimited list of project ID, client email, region, and private key. + * + * @example `GCP_CREDENTIALS=project1:1@1.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,project2:2@2.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----` + */ + gcpCredentials?: string; /** * Comma-delimited list of Azure OpenAI credentials. Each credential item * should be a colon-delimited list of Azure resource name, deployment ID, and @@ -349,7 +356,7 @@ type Config = { * * Defaults to no services, meaning image prompts are disabled. Use a comma- * separated list. 
Available services are: - * openai,anthropic,google-ai,mistral-ai,aws,azure + * openai,anthropic,google-ai,mistral-ai,aws,gcp,azure */ allowedVisionServices: LLMService[]; /** @@ -383,6 +390,7 @@ export const config: Config = { googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""), mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""), awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""), + gcpCredentials: getEnvWithDefault("GCP_CREDENTIALS", ""), azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""), adminKey: getEnvWithDefault("ADMIN_KEY", ""), @@ -437,6 +445,8 @@ export const config: Config = { "mistral-large", "aws-claude", "aws-claude-opus", + "gcp-claude", + "gcp-claude-opus", "azure-turbo", "azure-gpt4", "azure-gpt4-32k", @@ -511,6 +521,7 @@ function generateSigningKey() { config.googleAIKey, config.mistralAIKey, config.awsCredentials, + config.gcpCredentials, config.azureCredentials, ]; if (secrets.filter((s) => s).length === 0) { @@ -648,6 +659,7 @@ export const OMITTED_KEYS = [ "googleAIKey", "mistralAIKey", "awsCredentials", + "gcpCredentials", "azureCredentials", "proxyKey", "adminKey", @@ -738,6 +750,7 @@ function getEnvWithDefault(env: string | string[], defaultValue: T): T { "ANTHROPIC_KEY", "GOOGLE_AI_KEY", "AWS_CREDENTIALS", + "GCP_CREDENTIALS", "AZURE_CREDENTIALS", ].includes(String(env)) ) { diff --git a/src/info-page.ts b/src/info-page.ts index ce2f7c2..2f0a7f1 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -29,6 +29,8 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { "mistral-large": "Mistral Large", "aws-claude": "AWS Claude (Sonnet)", "aws-claude-opus": "AWS Claude (Opus)", + "gcp-claude": "GCP Claude (Sonnet)", + "gcp-claude-opus": "GCP Claude (Opus)", "azure-turbo": "Azure GPT-3.5 Turbo", "azure-gpt4": "Azure GPT-4", "azure-gpt4-32k": "Azure GPT-4 32k", diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts index d537800..643efef 100644 --- 
a/src/proxy/anthropic.ts +++ b/src/proxy/anthropic.ts @@ -129,7 +129,7 @@ export function transformAnthropicChatResponseToAnthropicText( * is only used for non-streaming requests as streaming requests are handled * on-the-fly. */ -function transformAnthropicTextResponseToOpenAI( +export function transformAnthropicTextResponseToOpenAI( anthropicBody: Record, req: Request ): Record { diff --git a/src/proxy/gcp.ts b/src/proxy/gcp.ts new file mode 100644 index 0000000..a9b622b --- /dev/null +++ b/src/proxy/gcp.ts @@ -0,0 +1,196 @@ +import { Request, RequestHandler, Response, Router } from "express"; +import { createProxyMiddleware } from "http-proxy-middleware"; +import { v4 } from "uuid"; +import { config } from "../config"; +import { logger } from "../logger"; +import { createQueueMiddleware } from "./queue"; +import { ipLimiter } from "./rate-limit"; +import { handleProxyError } from "./middleware/common"; +import { + createPreprocessorMiddleware, + signGcpRequest, + finalizeSignedRequest, + createOnProxyReqHandler, +} from "./middleware/request"; +import { + ProxyResHandlerWithBody, + createOnProxyResHandler, +} from "./middleware/response"; +import { transformAnthropicChatResponseToOpenAI } from "./anthropic"; +import { sendErrorToClient } from "./middleware/response/error-generator"; + +const LATEST_GCP_SONNET_MINOR_VERSION = "20240229"; + +let modelsCache: any = null; +let modelsCacheTime = 0; + +const getModelsResponse = () => { + if (new Date().getTime() - modelsCacheTime < 1000 * 60) { + return modelsCache; + } + + if (!config.gcpCredentials) return { object: "list", data: [] }; + + // https://docs.anthropic.com/en/docs/about-claude/models + const variants = [ + "claude-3-haiku@20240307", + "claude-3-sonnet@20240229", + "claude-3-opus@20240229", + "claude-3-5-sonnet@20240620", + ]; + + const models = variants.map((id) => ({ + id, + object: "model", + created: new Date().getTime(), + owned_by: "anthropic", + permission: [], + root: "claude", + parent: null, + 
})); + + modelsCache = { object: "list", data: models }; + modelsCacheTime = new Date().getTime(); + + return modelsCache; +}; + +const handleModelRequest: RequestHandler = (_req, res) => { + res.status(200).json(getModelsResponse()); +}; + +/** Only used for non-streaming requests. */ +const gcpResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + let newBody = body; + switch (`${req.inboundApi}<-${req.outboundApi}`) { + case "openai<-anthropic-chat": + req.log.info("Transforming Anthropic Chat back to OpenAI format"); + newBody = transformAnthropicChatResponseToOpenAI(body); + break; + } + + res.status(200).json({ ...newBody, proxy: body.proxy }); +}; + +const gcpProxy = createQueueMiddleware({ + beforeProxy: signGcpRequest, + proxyMiddleware: createProxyMiddleware({ + target: "bad-target-will-be-rewritten", + router: ({ signedRequest }) => { + if (!signedRequest) throw new Error("Must sign request before proxying"); + return `${signedRequest.protocol}//${signedRequest.hostname}`; + }, + changeOrigin: true, + selfHandleResponse: true, + logger, + on: { + proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), + proxyRes: createOnProxyResHandler([gcpResponseHandler]), + error: handleProxyError, + }, + }), +}); + +const oaiToChatPreprocessor = createPreprocessorMiddleware( + { inApi: "openai", outApi: "anthropic-chat", service: "gcp" }, + { afterTransform: [maybeReassignModel] } +); + +/** + * Translates an OpenAI-formatted prompt into the Anthropic Chat format. Unlike + * the AWS implementation, all OpenAI-compatible requests are routed to the + * Claude chat completion endpoint; there is no legacy text-completion path. + */ +const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => { + oaiToChatPreprocessor(req, res, next); +}; + +const gcpRouter = Router(); +gcpRouter.get("/v1/models", handleModelRequest); +// Native Anthropic chat completion endpoint. 
+gcpRouter.post( + "/v1/messages", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "gcp" }, + { afterTransform: [maybeReassignModel] } + ), + gcpProxy +); + +// OpenAI-to-GCP Anthropic compatibility endpoint. +gcpRouter.post( + "/v1/chat/completions", + ipLimiter, + preprocessOpenAICompatRequest, + gcpProxy +); + +/** + * Tries to deal with: + * - frontends sending GCP model names even when they want to use the OpenAI- + * compatible endpoint + * - frontends sending Anthropic model names that GCP doesn't recognize + * - frontends sending OpenAI model names because they expect the proxy to + * translate them + * + * If client sends GCP model ID it will be used verbatim. Otherwise, various + * strategies are used to try to map a non-GCP model name to GCP model ID. + */ +function maybeReassignModel(req: Request) { + const model = req.body.model; + + // If it looks like an GCP model, use it as-is + // if (model.includes("anthropic.claude")) { + if (model.startsWith("claude-") && model.includes("@")) { + return; + } + + // Anthropic model names can look like: + // - claude-v1 + // - claude-2.1 + // - claude-3-5-sonnet-20240620-v1:0 + const pattern = + /^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i; + const match = model.match(pattern); + + // If there's no match, fallback to Claude3 Sonnet as it is most likely to be + // available on GCP. + if (!match) { + req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`; + return; + } + + const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match; + + const ver = minor ? 
`${major}.${minor}` : major; + switch (ver) { + case "3": + case "3.0": + if (name.includes("opus")) { + req.body.model = "claude-3-opus@20240229"; + } else if (name.includes("haiku")) { + req.body.model = "claude-3-haiku@20240307"; + } else { + req.body.model = "claude-3-sonnet@20240229"; + } + return; + case "3.5": + req.body.model = "claude-3-5-sonnet@20240620"; + return; + } + + // Fallback to Claude3 Sonnet + req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`; + return; +} + +export const gcp = gcpRouter; diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts index f11c4f3..e2aba44 100644 --- a/src/proxy/middleware/request/index.ts +++ b/src/proxy/middleware/request/index.ts @@ -15,6 +15,7 @@ export { countPromptTokens } from "./preprocessors/count-prompt-tokens"; export { languageFilter } from "./preprocessors/language-filter"; export { setApiFormat } from "./preprocessors/set-api-format"; export { signAwsRequest } from "./preprocessors/sign-aws-request"; +export { signGcpRequest } from "./preprocessors/sign-vertex-ai-request"; export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload"; export { validateContextSize } from "./preprocessors/validate-context-size"; export { validateVision } from "./preprocessors/validate-vision"; diff --git a/src/proxy/middleware/request/onproxyreq/add-key.ts b/src/proxy/middleware/request/onproxyreq/add-key.ts index 902a080..27b2dc3 100644 --- a/src/proxy/middleware/request/onproxyreq/add-key.ts +++ b/src/proxy/middleware/request/onproxyreq/add-key.ts @@ -83,6 +83,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => { proxyReq.setHeader("api-key", azureKey); break; case "aws": + case "gcp": case "google-ai": throw new Error("add-key should not be used for this service."); default: diff --git a/src/proxy/middleware/request/onproxyreq/finalize-signed-request.ts b/src/proxy/middleware/request/onproxyreq/finalize-signed-request.ts index 
b885492..1380d1e 100644 --- a/src/proxy/middleware/request/onproxyreq/finalize-signed-request.ts +++ b/src/proxy/middleware/request/onproxyreq/finalize-signed-request.ts @@ -1,7 +1,7 @@ import type { HPMRequestCallback } from "../index"; /** - * For AWS/Azure/Google requests, the body is signed earlier in the request + * For AWS/GCP/Azure/Google requests, the body is signed earlier in the request * pipeline, before the proxy middleware. This function just assigns the path * and headers to the proxy request. */ diff --git a/src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts b/src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts new file mode 100644 index 0000000..5b9fcda --- /dev/null +++ b/src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts @@ -0,0 +1,201 @@ +import express from "express"; +import crypto from "crypto"; +import { keyPool } from "../../../../shared/key-management"; +import { RequestPreprocessor } from "../index"; +import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas"; + +const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com"; + +export const signGcpRequest: RequestPreprocessor = async (req) => { + const serviceValid = req.service === "gcp"; + if (!serviceValid) { + throw new Error("addVertexAIKey called on invalid request"); + } + + if (!req.body?.model) { + throw new Error("You must specify a model with your request."); + } + + const { model, stream } = req.body; + req.key = keyPool.get(model, "gcp"); + + req.log.info({ key: req.key.hash, model }, "Assigned GCP key to request"); + + req.isStreaming = String(stream) === "true"; + + // TODO: This should happen in transform-outbound-payload.ts + // TODO: Support tools + let strippedParams: Record; + strippedParams = AnthropicV1MessagesSchema.pick({ + messages: true, + system: true, + max_tokens: true, + stop_sequences: true, + temperature: true, + top_k: true, + top_p: true, + stream: true, + }) + .strip() + 
.parse(req.body); + strippedParams.anthropic_version = "vertex-2023-10-16"; + + const [accessToken, credential] = await getAccessToken(req); + + const host = GCP_HOST.replace("%REGION%", credential.region); + // GCP doesn't use the anthropic-version header, but we set it to ensure the + // stream adapter selects the correct transformer. + req.headers["anthropic-version"] = "2023-06-01"; + + req.signedRequest = { + method: "POST", + protocol: "https:", + hostname: host, + path: `/v1/projects/${credential.projectId}/locations/${credential.region}/publishers/anthropic/models/${model}:streamRawPredict`, + headers: { + ["host"]: host, + ["content-type"]: "application/json", + ["authorization"]: `Bearer ${accessToken}`, + }, + body: JSON.stringify(strippedParams), + }; +}; + +async function getAccessToken( + req: express.Request +): Promise<[string, Credential]> { + // TODO: access token caching to reduce latency + const credential = getCredentialParts(req); + const signedJWT = await createSignedJWT( + credential.clientEmail, + credential.privateKey + ); + const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT); + if (accessToken === null) { + req.log.warn( + { key: req.key!.hash, jwtError }, + "Unable to get the access token" + ); + throw new Error("The access token is invalid."); + } + return [accessToken, credential]; +} + +async function createSignedJWT(email: string, pkey: string): Promise { + let cryptoKey = await crypto.subtle.importKey( + "pkcs8", + str2ab(atob(pkey)), + { + name: "RSASSA-PKCS1-v1_5", + hash: { name: "SHA-256" }, + }, + false, + ["sign"] + ); + + const authUrl = "https://www.googleapis.com/oauth2/v4/token"; + const issued = Math.floor(Date.now() / 1000); + const expires = issued + 600; + + const header = { + alg: "RS256", + typ: "JWT", + }; + + const payload = { + iss: email, + aud: authUrl, + iat: issued, + exp: expires, + scope: "https://www.googleapis.com/auth/cloud-platform", + }; + + const encodedHeader = 
urlSafeBase64Encode(JSON.stringify(header)); + const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload)); + + const unsignedToken = `${encodedHeader}.${encodedPayload}`; + + const signature = await crypto.subtle.sign( + "RSASSA-PKCS1-v1_5", + cryptoKey, + str2ab(unsignedToken) + ); + + const encodedSignature = urlSafeBase64Encode(signature); + return `${unsignedToken}.${encodedSignature}`; +} + +async function exchangeJwtForAccessToken( + signed_jwt: string +): Promise<[string | null, string]> { + const auth_url = "https://www.googleapis.com/oauth2/v4/token"; + const params = { + grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", + assertion: signed_jwt, + }; + + const r = await fetch(auth_url, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: Object.entries(params) + .map(([k, v]) => `${k}=${v}`) + .join("&"), + }).then((res) => res.json()); + + if (r.access_token) { + return [r.access_token, ""]; + } + + return [null, JSON.stringify(r)]; +} + +function str2ab(str: string): ArrayBuffer { + const buffer = new ArrayBuffer(str.length); + const bufferView = new Uint8Array(buffer); + for (let i = 0; i < str.length; i++) { + bufferView[i] = str.charCodeAt(i); + } + return buffer; +} + +function urlSafeBase64Encode(data: string | ArrayBuffer): string { + let base64: string; + if (typeof data === "string") { + base64 = btoa( + encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => + String.fromCharCode(parseInt("0x" + p1, 16)) + ) + ); + } else { + base64 = btoa(String.fromCharCode(...new Uint8Array(data))); + } + return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); +} + +type Credential = { + projectId: string; + clientEmail: string; + region: string; + privateKey: string; +}; + +function getCredentialParts(req: express.Request): Credential { + const [projectId, clientEmail, region, rawPrivateKey] = + req.key!.key.split(":"); + if (!projectId || !clientEmail || !region || 
!rawPrivateKey) { + req.log.error( + { key: req.key!.hash }, + "GCP_CREDENTIALS isn't correctly formatted; refer to the docs" + ); + throw new Error("The key assigned to this request is invalid."); + } + + const privateKey = rawPrivateKey + .replace( + /-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, + "" + ) + .trim(); + + return { projectId, clientEmail, region, privateKey }; +} diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index 1774953..52ec043 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -186,6 +186,13 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( throw new HttpError(statusCode, parseError.message); } + const service = req.key!.service; + if (service === "gcp") { + if (Array.isArray(errorPayload)) { + errorPayload = errorPayload[0]; + } + } + const errorType = errorPayload.error?.code || errorPayload.error?.type || @@ -199,11 +206,15 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( // TODO: split upstream error handling into separate modules for each service, // this is out of control. 
- const service = req.key!.service; if (service === "aws") { // Try to standardize the error format for AWS errorPayload.error = { message: errorPayload.message, type: errorType }; delete errorPayload.message; + } else if (service === "gcp") { + // Try to standardize the error format for GCP + if (errorPayload.error?.code) { // GCP Error + errorPayload.error = { message: errorPayload.error.message, type: errorPayload.error.status || errorPayload.error.code }; + } } if (statusCode === 400) { @@ -225,6 +236,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( break; case "anthropic": case "aws": + case "gcp": await handleAnthropicAwsBadRequestError(req, errorPayload); break; case "google-ai": @@ -280,6 +292,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( default: errorPayload.proxy_note = `Received 403 error. Key may be invalid.`; } + return; + case "gcp": + keyPool.disable(req.key!, "revoked"); + errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`; + return; } } else if (statusCode === 429) { switch (service) { @@ -292,6 +309,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( case "aws": await handleAwsRateLimitError(req, errorPayload); break; + case "gcp": + await handleGcpRateLimitError(req, errorPayload); + break; case "azure": case "mistral-ai": await handleAzureRateLimitError(req, errorPayload); @@ -328,6 +348,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( case "aws": errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`; break; + case "gcp": + errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`; + break; case "azure": errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`; break; @@ -434,6 +457,19 @@ async function handleAwsRateLimitError( } } +async function handleGcpRateLimitError( + req: Request, + 
errorPayload: ProxiedErrorPayload +) { + if (errorPayload.error?.type === "RESOURCE_EXHAUSTED") { + keyPool.markRateLimited(req.key!); + await reenqueueRequest(req); + throw new RetryableError("GCP rate-limited request re-enqueued."); + } else { + errorPayload.proxy_note = `Unrecognized 429 Too Many Requests error from GCP.`; + } +} + async function handleOpenAIRateLimitError( req: Request, errorPayload: ProxiedErrorPayload diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 281cf55..9932ef2 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -7,6 +7,7 @@ import { anthropic } from "./anthropic"; import { googleAI } from "./google-ai"; import { mistralAI } from "./mistral-ai"; import { aws } from "./aws"; +import { gcp } from "./gcp"; import { azure } from "./azure"; import { sendErrorToClient } from "./middleware/response/error-generator"; @@ -36,6 +37,7 @@ proxyRouter.use("/anthropic", addV1, anthropic); proxyRouter.use("/google-ai", addV1, googleAI); proxyRouter.use("/mistral-ai", addV1, mistralAI); proxyRouter.use("/aws/claude", addV1, aws); +proxyRouter.use("/gcp/claude", addV1, gcp); proxyRouter.use("/azure/openai", addV1, azure); // Redirect browser requests to the homepage. 
proxyRouter.get("*", (req, res, next) => { diff --git a/src/service-info.ts b/src/service-info.ts index ea62d58..5f3752c 100644 --- a/src/service-info.ts +++ b/src/service-info.ts @@ -2,6 +2,7 @@ import { config, listConfig } from "./config"; import { AnthropicKey, AwsBedrockKey, + GcpKey, AzureOpenAIKey, GoogleAIKey, keyPool, @@ -11,6 +12,7 @@ import { AnthropicModelFamily, assertIsKnownModelFamily, AwsBedrockModelFamily, + GcpModelFamily, AzureOpenAIModelFamily, GoogleAIModelFamily, LLM_SERVICES, @@ -40,6 +42,7 @@ const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey => const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey => k.service === "mistral-ai"; const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws"; +const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp"; /** Stats aggregated across all keys for a given service. */ type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs"; @@ -52,7 +55,11 @@ type ModelAggregates = { pozzed?: number; awsLogged?: number; awsSonnet?: number; + awsSonnet35?: number; awsHaiku?: number; + gcpSonnet?: number; + gcpSonnet35?: number; + gcpHaiku?: number; queued: number; queueTime: string; tokens: number; @@ -87,6 +94,12 @@ type AnthropicInfo = BaseFamilyInfo & { type AwsInfo = BaseFamilyInfo & { privacy?: string; sonnetKeys?: number; + sonnet35Keys?: number; + haikuKeys?: number; +}; +type GcpInfo = BaseFamilyInfo & { + sonnetKeys?: number; + sonnet35Keys?: number; haikuKeys?: number; }; @@ -101,6 +114,7 @@ export type ServiceInfo = { "google-ai"?: string; "mistral-ai"?: string; aws?: string; + gcp?: string; azure?: string; "openai-image"?: string; "azure-image"?: string; @@ -114,6 +128,7 @@ export type ServiceInfo = { } & { [f in OpenAIModelFamily]?: OpenAIInfo } & { [f in AnthropicModelFamily]?: AnthropicInfo; } & { [f in AwsBedrockModelFamily]?: AwsInfo } + & { [f in GcpModelFamily]?: GcpInfo } & { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; } & { [f in 
GoogleAIModelFamily]?: BaseFamilyInfo } & { [f in MistralAIModelFamily]?: BaseFamilyInfo }; @@ -151,6 +166,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record } = { aws: { aws: `%BASE%/aws/claude`, }, + gcp: { + gcp: `%BASE%/gcp/claude`, + }, azure: { azure: `%BASE%/azure/openai`, "azure-image": `%BASE%/azure/openai`, @@ -305,6 +323,7 @@ function addKeyToAggregates(k: KeyPoolKey) { k.service === "mistral-ai" ? 1 : 0 ); increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0); + increment(serviceStats, "gcp__keys", k.service === "gcp" ? 1 : 0); increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0); let sumTokens = 0; @@ -396,6 +415,7 @@ function addKeyToAggregates(k: KeyPoolKey) { increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); }); increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0); + increment(modelStats, `aws-claude__awsSonnet35`, k.sonnet35Enabled ? 1 : 0); increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0); // Ignore revoked keys for aws logging stats, but include keys where the @@ -405,6 +425,21 @@ function addKeyToAggregates(k: KeyPoolKey) { increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0); break; } + case "gcp": { + if (!keyIsGcpKey(k)) throw new Error("Invalid key type"); + k.modelFamilies.forEach((f) => { + const tokens = k[`${f}Tokens`]; + sumTokens += tokens; + sumCost += getTokenCostUsd(f, tokens); + increment(modelStats, `${f}__tokens`, tokens); + increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); + increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); + }); + increment(modelStats, `gcp-claude__gcpSonnet`, k.sonnetEnabled ? 1 : 0); + increment(modelStats, `gcp-claude__gcpSonnet35`, k.sonnet35Enabled ? 1 : 0); + increment(modelStats, `gcp-claude__gcpHaiku`, k.haikuEnabled ? 
1 : 0); + break; + } default: assertNever(k.service); } @@ -416,7 +451,7 @@ function addKeyToAggregates(k: KeyPoolKey) { function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { const tokens = modelStats.get(`${family}__tokens`) || 0; const cost = getTokenCostUsd(family, tokens); - let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = { + let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = { usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`, activeKeys: modelStats.get(`${family}__active`) || 0, revokedKeys: modelStats.get(`${family}__revoked`) || 0, @@ -446,6 +481,7 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { case "aws": if (family === "aws-claude") { info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0; + info.sonnet35Keys = modelStats.get(`${family}__awsSonnet35`) || 0; info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0; const logged = modelStats.get(`${family}__awsLogged`) || 0; if (logged > 0) { @@ -455,6 +491,13 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { } } break; + case "gcp": + if (family === "gcp-claude") { + info.sonnetKeys = modelStats.get(`${family}__gcpSonnet`) || 0; + info.sonnet35Keys = modelStats.get(`${family}__gcpSonnet35`) || 0; + info.haikuKeys = modelStats.get(`${family}__gcpHaiku`) || 0; + } + break; } } diff --git a/src/shared/key-management/anthropic/checker.ts b/src/shared/key-management/anthropic/checker.ts index 5c8c968..1727178 100644 --- a/src/shared/key-management/anthropic/checker.ts +++ b/src/shared/key-management/anthropic/checker.ts @@ -122,7 +122,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase { { key: key.hash, error: error.message }, "Network error while checking key; trying this key again in a minute." 
); - const oneMinute = 10 * 1000; + const oneMinute = 60 * 1000; const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute); this.updateKey(key.hash, { lastChecked: next }); } diff --git a/src/shared/key-management/gcp/checker.ts b/src/shared/key-management/gcp/checker.ts new file mode 100644 index 0000000..5995065 --- /dev/null +++ b/src/shared/key-management/gcp/checker.ts @@ -0,0 +1,277 @@ +import axios, { AxiosError } from "axios"; +import crypto from "crypto"; +import { KeyCheckerBase } from "../key-checker-base"; +import type { GcpKey, GcpKeyProvider } from "./provider"; +import { GcpModelFamily } from "../../models"; + +const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds +const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes +const GCP_HOST = + process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com"; +const POST_STREAM_RAW_URL = (project: string, region: string, model: string) => + `https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`; +const TEST_MESSAGES = [ + { role: "user", content: "Hi!" }, + { role: "assistant", content: "Hello!" 
}, +]; + +type UpdateFn = typeof GcpKeyProvider.prototype.update; + +export class GcpKeyChecker extends KeyCheckerBase { + constructor(keys: GcpKey[], updateKey: UpdateFn) { + super(keys, { + service: "gcp", + keyCheckPeriod: KEY_CHECK_PERIOD, + minCheckInterval: MIN_CHECK_INTERVAL, + updateKey, + }); + } + + protected async testKeyOrFail(key: GcpKey) { + let checks: Promise[] = []; + const isInitialCheck = !key.lastChecked; + if (isInitialCheck) { + checks = [ + this.invokeModel("claude-3-haiku@20240307", key, true), + this.invokeModel("claude-3-sonnet@20240229", key, true), + this.invokeModel("claude-3-opus@20240229", key, true), + this.invokeModel("claude-3-5-sonnet@20240620", key, true), + ]; + + const [haiku, sonnet, opus, sonnet35] = + await Promise.all(checks); + + this.log.debug( + { key: key.hash, sonnet, haiku, opus, sonnet35 }, + "GCP model initial tests complete." + ); + + const families: GcpModelFamily[] = []; + if (sonnet || sonnet35 || haiku) families.push("gcp-claude"); + if (opus) families.push("gcp-claude-opus"); + + if (families.length === 0) { + this.log.warn( + { key: key.hash }, + "Key does not have access to any models; disabling." + ); + return this.updateKey(key.hash, { isDisabled: true }); + } + + this.updateKey(key.hash, { + sonnetEnabled: sonnet, + haikuEnabled: haiku, + sonnet35Enabled: sonnet35, + modelFamilies: families, + }); + } else { + if (key.haikuEnabled) { + await this.invokeModel("claude-3-haiku@20240307", key, false) + } else if (key.sonnetEnabled) { + await this.invokeModel("claude-3-sonnet@20240229", key, false) + } else if (key.sonnet35Enabled) { + await this.invokeModel("claude-3-5-sonnet@20240620", key, false) + } else { + await this.invokeModel("claude-3-opus@20240229", key, false) + } + + this.updateKey(key.hash, { lastChecked: Date.now() }); + this.log.debug( + { key: key.hash}, + "GCP key check complete." + ); + } + + this.log.info( + { + key: key.hash, + families: key.modelFamilies, + }, + "Checked key."
+ ); + } + + protected handleAxiosError(key: GcpKey, error: AxiosError) { + if (error.response && GcpKeyChecker.errorIsGcpError(error)) { + const { status, data } = error.response; + if (status === 400 || status === 401 || status === 403) { + this.log.warn( + { key: key.hash, error: data }, + "Key is invalid or revoked. Disabling key." + ); + this.updateKey(key.hash, { isDisabled: true, isRevoked: true }); + } else if (status === 429) { + this.log.warn( + { key: key.hash, error: data }, + "Key is rate limited. Rechecking in a minute." + ); + const next = Date.now() - (KEY_CHECK_PERIOD - 60 * 1000); + this.updateKey(key.hash, { lastChecked: next }); + } else { + this.log.error( + { key: key.hash, status, error: data }, + "Encountered unexpected error status while checking key. This may indicate a change in the API; please report this." + ); + this.updateKey(key.hash, { lastChecked: Date.now() }); + } + return; + } + const { response, cause } = error; + const { headers, status, data } = response ?? {}; + this.log.error( + { key: key.hash, status, headers, data, cause, error: error.message }, + "Network error while checking key; trying this key again in a minute." + ); + const oneMinute = 60 * 1000; + const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute); + this.updateKey(key.hash, { lastChecked: next }); + } + + /** + * Attempt to invoke the given model with the given key. Returns true if the + * key has access to the model, false if it does not. Throws an error if the + * key is disabled. 
+ */ + private async invokeModel(model: string, key: GcpKey, initial: boolean) { + const creds = GcpKeyChecker.getCredentialsFromKey(key); + const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey) + const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT) + if (accessToken === null) { + this.log.warn( + { key: key.hash, jwtError }, + "Unable to get the access token" + ); + return false; + } + const payload = { + max_tokens: 1, + messages: TEST_MESSAGES, + anthropic_version: "vertex-2023-10-16", + }; + const { data, status } = await axios.post( + POST_STREAM_RAW_URL(creds.projectId, creds.region, model), + payload, + { + headers: GcpKeyChecker.getRequestHeaders(accessToken), + validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300 + } + ); + this.log.debug({ key: key.hash, data }, "Response from GCP"); + + if (initial) { + return (status >= 200 && status < 300) || (status === 429 || status === 529); + } + + return true; + } + + static errorIsGcpError(error: AxiosError): error is AxiosError { + const data = error.response?.data as any; + if (Array.isArray(data)) { + return data.length > 0 && data[0]?.error?.message; + } else { + return data?.error?.message; + } + } + + static async createSignedJWT(email: string, pkey: string): Promise { + let cryptoKey = await crypto.subtle.importKey( + "pkcs8", + GcpKeyChecker.str2ab(atob(pkey)), + { + name: "RSASSA-PKCS1-v1_5", + hash: { name: "SHA-256" }, + }, + false, + ["sign"] + ); + + const authUrl = "https://www.googleapis.com/oauth2/v4/token"; + const issued = Math.floor(Date.now() / 1000); + const expires = issued + 600; + + const header = { + alg: "RS256", + typ: "JWT", + }; + + const payload = { + iss: email, + aud: authUrl, + iat: issued, + exp: expires, + scope: "https://www.googleapis.com/auth/cloud-platform", + }; + + const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header)); + const 
encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload)); + + const unsignedToken = `${encodedHeader}.${encodedPayload}`; + + const signature = await crypto.subtle.sign( + "RSASSA-PKCS1-v1_5", + cryptoKey, + GcpKeyChecker.str2ab(unsignedToken) + ); + + const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature); + return `${unsignedToken}.${encodedSignature}`; + } + + static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> { + const auth_url = "https://www.googleapis.com/oauth2/v4/token"; + const params = { + grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", + assertion: signed_jwt, + }; + + const r = await fetch(auth_url, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: Object.entries(params) + .map(([k, v]) => `${k}=${v}`) + .join("&"), + }).then((res) => res.json()); + + if (r.access_token) { + return [r.access_token, ""]; + } + + return [null, JSON.stringify(r)]; + } + + static str2ab(str: string): ArrayBuffer { + const buffer = new ArrayBuffer(str.length); + const bufferView = new Uint8Array(buffer); + for (let i = 0; i < str.length; i++) { + bufferView[i] = str.charCodeAt(i); + } + return buffer; + } + + static urlSafeBase64Encode(data: string | ArrayBuffer): string { + let base64: string; + if (typeof data === "string") { + base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16)))); + } else { + base64 = btoa(String.fromCharCode(...new Uint8Array(data))); + } + return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); + } + + static getRequestHeaders(accessToken: string) { + return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" }; + } + + static getCredentialsFromKey(key: GcpKey) { + const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":"); + if (!projectId || !clientEmail || !region || !rawPrivateKey) { + 
throw new Error("Invalid GCP key"); + } + const privateKey = rawPrivateKey + .replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '') + .trim(); + + return { projectId, clientEmail, region, privateKey }; + } +} diff --git a/src/shared/key-management/gcp/provider.ts b/src/shared/key-management/gcp/provider.ts new file mode 100644 index 0000000..8e9c9ab --- /dev/null +++ b/src/shared/key-management/gcp/provider.ts @@ -0,0 +1,242 @@ +import crypto from "crypto"; +import { Key, KeyProvider } from ".."; +import { config } from "../../../config"; +import { logger } from "../../../logger"; +import { GcpModelFamily, getGcpModelFamily } from "../../models"; +import { GcpKeyChecker } from "./checker"; +import { PaymentRequiredError } from "../../errors"; + +type GcpKeyUsage = { + [K in GcpModelFamily as `${K}Tokens`]: number; +}; + +export interface GcpKey extends Key, GcpKeyUsage { + readonly service: "gcp"; + readonly modelFamilies: GcpModelFamily[]; + /** The time at which this key was last rate limited. */ + rateLimitedAt: number; + /** The time until which this key is rate limited. */ + rateLimitedUntil: number; + sonnetEnabled: boolean; + haikuEnabled: boolean; + sonnet35Enabled: boolean; +} + +/** + * Upon being rate limited, a key will be locked out for this many milliseconds + * while we wait for other concurrent requests to finish. + */ +const RATE_LIMIT_LOCKOUT = 4000; +/** + * Upon assigning a key, we will wait this many milliseconds before allowing it + * to be used again. This is to prevent the queue from flooding a key with too + * many requests while we wait to learn whether previous ones succeeded. 
+ */ +const KEY_REUSE_DELAY = 500; + +export class GcpKeyProvider implements KeyProvider { + readonly service = "gcp"; + + private keys: GcpKey[] = []; + private checker?: GcpKeyChecker; + private log = logger.child({ module: "key-provider", service: this.service }); + + constructor() { + const keyConfig = config.gcpCredentials?.trim(); + if (!keyConfig) { + this.log.warn( + "GCP_CREDENTIALS is not set. GCP API will not be available." + ); + return; + } + let bareKeys: string[]; + bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; + for (const key of bareKeys) { + const newKey: GcpKey = { + key, + service: this.service, + modelFamilies: ["gcp-claude"], + isDisabled: false, + isRevoked: false, + promptCount: 0, + lastUsed: 0, + rateLimitedAt: 0, + rateLimitedUntil: 0, + hash: `gcp-${crypto + .createHash("sha256") + .update(key) + .digest("hex") + .slice(0, 8)}`, + lastChecked: 0, + sonnetEnabled: true, + haikuEnabled: false, + sonnet35Enabled: false, + ["gcp-claudeTokens"]: 0, + ["gcp-claude-opusTokens"]: 0, + }; + this.keys.push(newKey); + } + this.log.info({ keyCount: this.keys.length }, "Loaded GCP keys."); + } + + public init() { + if (config.checkKeys) { + this.checker = new GcpKeyChecker(this.keys, this.update.bind(this)); + this.checker.start(); + } + } + + public list() { + return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); + } + + public get(model: string) { + const neededFamily = getGcpModelFamily(model); + + // this is a horrible mess + // each of these should be separate model families, but adding model + // families is not low enough friction for the rate at which gcp claude + // model variants are added. 
+ const needsSonnet35 = + model.includes("claude-3-5-sonnet") && neededFamily === "gcp-claude"; + const needsSonnet = + !needsSonnet35 && + model.includes("sonnet") && + neededFamily === "gcp-claude"; + const needsHaiku = model.includes("haiku") && neededFamily === "gcp-claude"; + + const availableKeys = this.keys.filter((k) => { + return ( + !k.isDisabled && + (k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under gcp-claude, while opus is not + (k.haikuEnabled || !needsHaiku) && + (k.sonnet35Enabled || !needsSonnet35) && + k.modelFamilies.includes(neededFamily) + ); + }); + + this.log.debug( + { + model, + neededFamily, + needsSonnet, + needsHaiku, + needsSonnet35, + availableKeys: availableKeys.length, + totalKeys: this.keys.length, + }, + "Selecting GCP key" + ); + + if (availableKeys.length === 0) { + throw new PaymentRequiredError( + `No GCP keys available for model ${model}` + ); + } + + // (largely copied from the OpenAI provider, without trial key support) + // Select a key, from highest priority to lowest priority: + // 1. Keys which are not rate limited + // a. If all keys were rate limited recently, select the least-recently + // rate limited key. + // 2. 
Keys which have not been used in the longest time + + const now = Date.now(); + + const keysByPriority = availableKeys.sort((a, b) => { + const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; + const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; + + if (aRateLimited && !bRateLimited) return 1; + if (!aRateLimited && bRateLimited) return -1; + if (aRateLimited && bRateLimited) { + return a.rateLimitedAt - b.rateLimitedAt; + } + + return a.lastUsed - b.lastUsed; + }); + + const selectedKey = keysByPriority[0]; + selectedKey.lastUsed = now; + this.throttle(selectedKey.hash); + return { ...selectedKey }; + } + + public disable(key: GcpKey) { + const keyFromPool = this.keys.find((k) => k.hash === key.hash); + if (!keyFromPool || keyFromPool.isDisabled) return; + keyFromPool.isDisabled = true; + this.log.warn({ key: key.hash }, "Key disabled"); + } + + public update(hash: string, update: Partial) { + const keyFromPool = this.keys.find((k) => k.hash === hash)!; + Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); + } + + public available() { + return this.keys.filter((k) => !k.isDisabled).length; + } + + public incrementUsage(hash: string, model: string, tokens: number) { + const key = this.keys.find((k) => k.hash === hash); + if (!key) return; + key.promptCount++; + key[`${getGcpModelFamily(model)}Tokens`] += tokens; + } + + public getLockoutPeriod() { + // TODO: same exact behavior for three providers, should be refactored + const activeKeys = this.keys.filter((k) => !k.isDisabled); + // Don't lock out if there are no keys available or the queue will stall. + // Just let it through so the add-key middleware can throw an error. 
+ if (activeKeys.length === 0) return 0; + + const now = Date.now(); + const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); + const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; + + if (anyNotRateLimited) return 0; + + // If all keys are rate-limited, return time until the first key is ready. + return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); + } + + /** + * This is called when we receive a 429, which means there are already five + * concurrent requests running on this key. We don't have any information on + * when these requests will resolve, so all we can do is wait a bit and try + * again. We will lock the key for 4 seconds after getting a 429 before + * retrying in order to give the other requests a chance to finish. + */ + public markRateLimited(keyHash: string) { + this.log.debug({ key: keyHash }, "Key rate limited"); + const key = this.keys.find((k) => k.hash === keyHash)!; + const now = Date.now(); + key.rateLimitedAt = now; + key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT; + } + + public recheck() { + this.keys.forEach(({ hash }) => + this.update(hash, { lastChecked: 0, isDisabled: false, isRevoked: false }) + ); + this.checker?.scheduleNextCheck(); + } + + /** + * Applies a short artificial delay to the key upon dequeueing, in order to + * prevent it from being immediately assigned to another request before the + * current one can be dispatched. 
+ **/ + private throttle(hash: string) { + const now = Date.now(); + const key = this.keys.find((k) => k.hash === hash)!; + + const currentRateLimit = key.rateLimitedUntil; + const nextRateLimit = now + KEY_REUSE_DELAY; + + key.rateLimitedAt = now; + key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit); + } +} diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 5e43e57..67dfad4 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -63,4 +63,5 @@ export { AnthropicKey } from "./anthropic/provider"; export { OpenAIKey } from "./openai/provider"; export { GoogleAIKey } from "././google-ai/provider"; export { AwsBedrockKey } from "./aws/provider"; +export { GcpKey } from "./gcp/provider"; export { AzureOpenAIKey } from "./azure/provider"; diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index 2ae5604..9e041c4 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -10,6 +10,7 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; import { GoogleAIKeyProvider } from "./google-ai/provider"; import { AwsBedrockKeyProvider } from "./aws/provider"; +import { GcpKeyProvider } from "./gcp/provider"; import { AzureOpenAIKeyProvider } from "./azure/provider"; import { MistralAIKeyProvider } from "./mistral-ai/provider"; @@ -27,6 +28,7 @@ export class KeyPool { this.keyProviders.push(new GoogleAIKeyProvider()); this.keyProviders.push(new MistralAIKeyProvider()); this.keyProviders.push(new AwsBedrockKeyProvider()); + this.keyProviders.push(new GcpKeyProvider()); this.keyProviders.push(new AzureOpenAIKeyProvider()); } @@ -128,7 +130,11 @@ export class KeyPool { return "openai"; } else if (model.startsWith("claude-")) { // https://console.anthropic.com/docs/api/reference#parameters - return "anthropic"; + if 
(!model.includes('@')) { + return "anthropic"; + } else { + return "gcp"; + } } else if (model.includes("gemini")) { // https://developers.generativeai.google.com/models/language return "google-ai"; diff --git a/src/shared/models.ts b/src/shared/models.ts index e3567cd..2004374 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -5,7 +5,7 @@ import type { Request } from "express"; /** * The service that a model is hosted on. Distinct from `APIFormat` because some - * services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure). + * services have interoperable APIs (eg Anthropic/AWS/GCP, OpenAI/Azure). */ export type LLMService = | "openai" @@ -13,6 +13,7 @@ export type LLMService = | "google-ai" | "mistral-ai" | "aws" + | "gcp" | "azure"; export type OpenAIModelFamily = @@ -32,6 +33,7 @@ export type MistralAIModelFamily = // correspond to specific models. consider them rough pricing tiers. "mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large"; export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus"; +export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus"; export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`; export type ModelFamily = | OpenAIModelFamily @@ -39,6 +41,7 @@ export type ModelFamily = | GoogleAIModelFamily | MistralAIModelFamily | AwsBedrockModelFamily + | GcpModelFamily | AzureOpenAIModelFamily; export const MODEL_FAMILIES = (( @@ -61,6 +64,8 @@ export const MODEL_FAMILIES = (( "mistral-large", "aws-claude", "aws-claude-opus", + "gcp-claude", + "gcp-claude-opus", "azure-turbo", "azure-gpt4", "azure-gpt4-32k", @@ -77,6 +82,7 @@ export const LLM_SERVICES = (( "google-ai", "mistral-ai", "aws", + "gcp", "azure", ] as const); @@ -93,6 +99,8 @@ export const MODEL_FAMILY_SERVICE: { "claude-opus": "anthropic", "aws-claude": "aws", "aws-claude-opus": "aws", + "gcp-claude": "gcp", + "gcp-claude-opus": "gcp", "azure-turbo": "azure", "azure-gpt4": "azure", "azure-gpt4-32k": "azure", @@ -176,6 
+184,11 @@ export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily { return "aws-claude"; } +export function getGcpModelFamily(model: string): GcpModelFamily { + if (model.includes("opus")) return "gcp-claude-opus"; + return "gcp-claude"; +} + export function getAzureOpenAIModelFamily( model: string, defaultFamily: AzureOpenAIModelFamily = "azure-gpt4" @@ -210,10 +223,12 @@ export function getModelFamilyForRequest(req: Request): ModelFamily { const model = req.body.model ?? "gpt-3.5-turbo"; let modelFamily: ModelFamily; - // Weird special case for AWS/Azure because they serve multiple models from + // Weird special case for AWS/GCP/Azure because they serve multiple models from // different vendors, even if currently only one is supported. if (req.service === "aws") { modelFamily = getAwsBedrockModelFamily(model); + } else if (req.service === "gcp") { + modelFamily = getGcpModelFamily(model); } else if (req.service === "azure") { modelFamily = getAzureOpenAIModelFamily(model); } else { diff --git a/src/shared/stats.ts b/src/shared/stats.ts index def4b7c..f631397 100644 --- a/src/shared/stats.ts +++ b/src/shared/stats.ts @@ -30,10 +30,12 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) { cost = 0.00001; break; case "aws-claude": + case "gcp-claude": case "claude": cost = 0.000008; break; case "aws-claude-opus": + case "gcp-claude-opus": case "claude-opus": cost = 0.000015; break; diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts index 8ddf7d3..9718c90 100644 --- a/src/shared/users/user-store.ts +++ b/src/shared/users/user-store.ts @@ -13,6 +13,7 @@ import { v4 as uuid } from "uuid"; import { config, getFirebaseApp } from "../../config"; import { getAwsBedrockModelFamily, + getGcpModelFamily, getAzureOpenAIModelFamily, getClaudeModelFamily, getGoogleAIModelFamily, @@ -417,6 +418,7 @@ function getModelFamilyForQuotaUsage( // differentiate between Azure and OpenAI variants of the same model. 
if (model.includes("azure")) return getAzureOpenAIModelFamily(model); if (model.includes("anthropic.")) return getAwsBedrockModelFamily(model); + if (model.startsWith("claude-") && model.includes("@")) return getGcpModelFamily(model); switch (api) { case "openai":