diff --git a/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts b/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts index 1eb90c8..69dcda2 100644 --- a/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts +++ b/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts @@ -1,11 +1,8 @@ -import { Request } from "express"; -import crypto from "crypto"; import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas"; -import { keyPool } from "../../../../shared/key-management"; -import { getAxiosInstance } from "../../../../shared/network"; +import { GcpKey, keyPool } from "../../../../shared/key-management"; import { ProxyReqMutator } from "../index"; - -const axios = getAxiosInstance(); +import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "../../../../shared/key-management/gcp/oauth"; +import { credential } from "firebase-admin"; const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com"; @@ -21,9 +18,18 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => { } const { model } = req.body; - const key = keyPool.get(model, "gcp"); + const key: GcpKey = keyPool.get(model, "gcp") as GcpKey; manager.setKey(key); + if (!key.accessToken || Date.now() > key.accessTokenExpiresAt) { + const [token, durationSec] = await refreshGcpAccessToken(key); + keyPool.update(key, { + accessToken: token, + accessTokenExpiresAt: Date.now() + durationSec * 1000 * 0.95, + } as GcpKey); + } + + req.log.info({ key: key.hash, model }, "Assigned GCP key to request"); // TODO: This should happen in transform-outbound-payload.ts @@ -43,7 +49,7 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => { .parse(req.body); strippedParams.anthropic_version = "vertex-2023-10-16"; - const [accessToken, credential] = await getAccessToken(req); + const credential = await getCredentialsFromGcpKey(key); const host = GCP_HOST.replace("%REGION%", credential.region); // GCP doesn't use the 
anthropic-version header, but we set it to ensure the @@ -58,151 +64,8 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => { headers: { ["host"]: host, ["content-type"]: "application/json", - ["authorization"]: `Bearer ${accessToken}`, + ["authorization"]: `Bearer ${key.accessToken}`, }, body: JSON.stringify(strippedParams), }); }; - -async function getAccessToken( - req: Readonly -): Promise<[string, Credential]> { - // TODO: access token caching to reduce latency - const credential = getCredentialParts(req); - const signedJWT = await createSignedJWT( - credential.clientEmail, - credential.privateKey - ); - const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT); - if (accessToken === null) { - req.log.warn( - { key: req.key!.hash, jwtError }, - "Unable to get the access token" - ); - throw new Error("The access token is invalid."); - } - return [accessToken, credential]; -} - -async function createSignedJWT(email: string, pkey: string): Promise { - let cryptoKey = await crypto.subtle.importKey( - "pkcs8", - str2ab(atob(pkey)), - { - name: "RSASSA-PKCS1-v1_5", - hash: { name: "SHA-256" }, - }, - false, - ["sign"] - ); - - const authUrl = "https://www.googleapis.com/oauth2/v4/token"; - const issued = Math.floor(Date.now() / 1000); - const expires = issued + 600; - - const header = { - alg: "RS256", - typ: "JWT", - }; - - const payload = { - iss: email, - aud: authUrl, - iat: issued, - exp: expires, - scope: "https://www.googleapis.com/auth/cloud-platform", - }; - - const encodedHeader = urlSafeBase64Encode(JSON.stringify(header)); - const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload)); - - const unsignedToken = `${encodedHeader}.${encodedPayload}`; - - const signature = await crypto.subtle.sign( - "RSASSA-PKCS1-v1_5", - cryptoKey, - str2ab(unsignedToken) - ); - - const encodedSignature = urlSafeBase64Encode(signature); - return `${unsignedToken}.${encodedSignature}`; -} - -async function exchangeJwtForAccessToken( 
- signedJwt: string -): Promise<[string | null, string]> { - const authUrl = "https://www.googleapis.com/oauth2/v4/token"; - const params = { - grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", - assertion: signedJwt, - }; - - try { - const response = await axios.post(authUrl, params, { - headers: { "Content-Type": "application/x-www-form-urlencoded" }, - }); - - if (response.data.access_token) { - return [response.data.access_token, ""]; - } else { - return [null, JSON.stringify(response.data)]; - } - } catch (error) { - if ("response" in error && "data" in error.response) { - return [null, JSON.stringify(error.response.data)]; - } else { - return [null, "An unexpected error occurred"]; - } - } -} - -function str2ab(str: string): ArrayBuffer { - const buffer = new ArrayBuffer(str.length); - const bufferView = new Uint8Array(buffer); - for (let i = 0; i < str.length; i++) { - bufferView[i] = str.charCodeAt(i); - } - return buffer; -} - -function urlSafeBase64Encode(data: string | ArrayBuffer): string { - let base64: string; - if (typeof data === "string") { - base64 = btoa( - encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => - String.fromCharCode(parseInt("0x" + p1, 16)) - ) - ); - } else { - base64 = btoa(String.fromCharCode(...new Uint8Array(data))); - } - return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); -} - -type Credential = { - projectId: string; - clientEmail: string; - region: string; - privateKey: string; -}; - -function getCredentialParts(req: Readonly): Credential { - const [projectId, clientEmail, region, rawPrivateKey] = - req.key!.key.split(":"); - if (!projectId || !clientEmail || !region || !rawPrivateKey) { - req.log.error( - { key: req.key!.hash }, - "GCP_CREDENTIALS isn't correctly formatted; refer to the docs" - ); - throw new Error("The key assigned to this request is invalid."); - } - - const privateKey = rawPrivateKey - .replace( - /-----BEGIN PRIVATE KEY-----|-----END PRIVATE 
KEY-----|\r|\n|\\n/g, - "" - ) - .trim(); - - return { projectId, clientEmail, region, privateKey }; -} diff --git a/src/proxy/middleware/request/proxy-middleware-factory.ts b/src/proxy/middleware/request/proxy-middleware-factory.ts index 5de2778..629200c 100644 --- a/src/proxy/middleware/request/proxy-middleware-factory.ts +++ b/src/proxy/middleware/request/proxy-middleware-factory.ts @@ -125,6 +125,9 @@ function pinoLoggerPlugin(proxyServer: ProxyServer) { target: `${protocol}//${host}${path}`, status: proxyRes.statusCode, contentType: proxyRes.headers["content-type"], + contentEncoding: proxyRes.headers["content-encoding"], + contentLength: proxyRes.headers["content-length"], + transferEncoding: proxyRes.headers["transfer-encoding"], }, "Got response from upstream API." ); diff --git a/src/proxy/middleware/response/compression.ts b/src/proxy/middleware/response/compression.ts new file mode 100644 index 0000000..7581f35 --- /dev/null +++ b/src/proxy/middleware/response/compression.ts @@ -0,0 +1,36 @@ +import util from "util"; +import zlib from "zlib"; +import { PassThrough } from "stream"; + +const BUFFER_DECODER_MAP = { + gzip: util.promisify(zlib.gunzip), + deflate: util.promisify(zlib.inflate), + br: util.promisify(zlib.brotliDecompress), + text: (data: Buffer) => data, +}; + +const STREAM_DECODER_MAP = { + gzip: zlib.createGunzip, + deflate: zlib.createInflate, + br: zlib.createBrotliDecompress, + text: () => new PassThrough(), +}; + +type SupportedContentEncoding = keyof typeof BUFFER_DECODER_MAP; +const isSupportedContentEncoding = ( + encoding: string +): encoding is SupportedContentEncoding => encoding in BUFFER_DECODER_MAP; + +export async function decompressBuffer(buf: Buffer, encoding: string = "text") { + if (isSupportedContentEncoding(encoding)) { + return (await BUFFER_DECODER_MAP[encoding](buf)).toString(); + } + throw new Error(`Unsupported content-encoding: ${encoding}`); +} + +export function getStreamDecompressor(encoding: string = "text") { + 
if (isSupportedContentEncoding(encoding)) { + return STREAM_DECODER_MAP[encoding](); + } + throw new Error(`Unsupported content-encoding: ${encoding}`); +} diff --git a/src/proxy/middleware/response/handle-blocking-response.ts b/src/proxy/middleware/response/handle-blocking-response.ts index 6253a49..8128c0e 100644 --- a/src/proxy/middleware/response/handle-blocking-response.ts +++ b/src/proxy/middleware/response/handle-blocking-response.ts @@ -1,20 +1,6 @@ -import { Request, Response } from "express"; -import util from "util"; -import zlib from "zlib"; import { sendProxyError } from "../common"; import type { RawResponseBodyHandler } from "./index"; - -const DECODER_MAP = { - gzip: util.promisify(zlib.gunzip), - deflate: util.promisify(zlib.inflate), - br: util.promisify(zlib.brotliDecompress), - text: (data: Buffer) => data, -}; -type SupportedContentEncoding = keyof typeof DECODER_MAP; - -const isSupportedContentEncoding = ( - encoding: string -): encoding is SupportedContentEncoding => encoding in DECODER_MAP; +import { decompressBuffer } from "./compression"; /** * Handles the response from the upstream service and decodes the body if @@ -40,36 +26,45 @@ export const handleBlockingResponse: RawResponseBodyHandler = async ( let chunks: Buffer[] = []; proxyRes.on("data", (chunk) => chunks.push(chunk)); proxyRes.on("end", async () => { + const contentEncoding = proxyRes.headers["content-encoding"]; + const contentType = proxyRes.headers["content-type"]; let body: string | Buffer = Buffer.concat(chunks); const rejectWithMessage = function (msg: string, err: Error) { const error = `${msg} (${err.message})`; - req.log.warn({ stack: err.stack }, error); + req.log.warn( + { msg: error, stack: err.stack }, + "Error in blocking response handler" + ); sendProxyError(req, res, 500, "Internal Server Error", { error }); return reject(error); }; - const contentEncoding = proxyRes.headers["content-encoding"] ?? 
"text"; - if (isSupportedContentEncoding(contentEncoding)) { - try { - body = (await DECODER_MAP[contentEncoding](body)).toString(); - } catch (e) { - return rejectWithMessage(`Could not decode response body`, e); - } - } else { - return rejectWithMessage( - "API responded with unsupported content encoding", - new Error(`Unsupported content-encoding: ${contentEncoding}`) - ); + try { + body = await decompressBuffer(body, contentEncoding); + } catch (e) { + return rejectWithMessage(`Could not decode response body`, e); + } try { - if (proxyRes.headers["content-type"]?.includes("application/json")) { - return resolve(JSON.parse(body)); - } - return resolve(body); + return resolve(tryParseAsJson(body, contentType)); } catch (e) { return rejectWithMessage("API responded with invalid JSON", e); } }); }); }; + +function tryParseAsJson(body: string, contentType?: string) { + // If the response is declared as JSON, it must parse or we will throw + if (contentType?.includes("application/json")) { + return JSON.parse(body); + } + // If it's not declared as JSON, we'll try to parse it as JSON + // anyway since some APIs return the wrong content-type header in some cases. + // If it fails to parse, we'll just return the raw body without throwing. 
+ try { + return JSON.parse(body); + } catch (e) { + return body; + } +} diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index 5394b98..2b790f5 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -1,6 +1,5 @@ import express from "express"; import { pipeline, Readable, Transform } from "stream"; -import StreamArray from "stream-json/streamers/StreamArray"; import { StringDecoder } from "string_decoder"; import { promisify } from "util"; import type { logger } from "../../../logger"; @@ -18,6 +17,7 @@ import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder"; import { EventAggregator } from "./streaming/event-aggregator"; import { SSEMessageTransformer } from "./streaming/sse-message-transformer"; import { SSEStreamAdapter } from "./streaming/sse-stream-adapter"; +import { getStreamDecompressor } from "./compression"; const pipelineAsync = promisify(pipeline); @@ -41,21 +41,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( req, res ) => { - const { hash } = req.key!; + const { headers, statusCode } = proxyRes; if (!req.isStreaming) { throw new Error("handleStreamedResponse called for non-streaming request."); } - if (proxyRes.statusCode! > 201) { + if (statusCode! > 201) { req.isStreaming = false; req.log.warn( - { statusCode: proxyRes.statusCode, key: hash }, + { statusCode }, `Streaming request returned error status code. Falling back to non-streaming response handler.` ); return handleBlockingResponse(proxyRes, req, res); } - req.log.debug({ headers: proxyRes.headers }, `Starting to proxy SSE stream.`); + req.log.debug({ headers }, `Starting to proxy SSE stream.`); // Typically, streaming will have already been initialized by the request // queue to send heartbeat pings. 
@@ -66,7 +66,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( const prefersNativeEvents = req.inboundApi === req.outboundApi; const streamOptions = { - contentType: proxyRes.headers["content-type"], + contentType: headers["content-type"], api: req.outboundApi, logger: req.log, }; @@ -78,11 +78,10 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( // only have to write one aggregator (OpenAI input) for each output format. const aggregator = new EventAggregator(req); - // Decoder reads from the raw response buffer and produces a stream of - // discrete events in some format (text/event-stream, vnd.amazon.event-stream, - // streaming JSON, etc). + const decompressor = getStreamDecompressor(headers["content-encoding"]); + // Decoder reads from the response bytes to produce a stream of plaintext. const decoder = getDecoder({ ...streamOptions, input: proxyRes }); - // Adapter consumes the decoded events and produces server-sent events so we + // Adapter consumes the decoded text and produces server-sent events so we // have a standard event format for the client and to translate between API // message formats. const adapter = new SSEStreamAdapter(streamOptions); @@ -107,7 +106,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( try { await Promise.race([ handleAbortedStream(req, res), - pipelineAsync(proxyRes, decoder, adapter, transformer), + pipelineAsync(proxyRes, decompressor, decoder, adapter, transformer), ]); req.log.debug(`Finished proxying SSE stream.`); res.end(); @@ -180,8 +179,7 @@ function getDecoder(options: { } else if (contentType?.includes("application/json")) { throw new Error("JSON streaming not supported, request SSE instead"); } else { - // Passthrough stream, but ensures split chunks across multi-byte characters - // are handled correctly. + // Ensures split chunks across multi-byte characters are handled correctly. 
const stringDecoder = new StringDecoder("utf8"); return new Transform({ readableObjectMode: true, diff --git a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts index ce27355..8c24458 100644 --- a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts +++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts @@ -2,7 +2,6 @@ import pino from "pino"; import { Transform, TransformOptions } from "stream"; import { Message } from "@smithy/eventstream-codec"; import { APIFormat } from "../../../../shared/key-management"; -import { buildSpoofedSSE } from "../error-generator"; import { BadRequestError, RetryableError } from "../../../../shared/errors"; type SSEStreamAdapterOptions = TransformOptions & { @@ -108,34 +107,6 @@ export class SSEStreamAdapter extends Transform { } } - /** Processes an incoming array element from the Google AI JSON stream. */ - protected processGoogleObject(data: any): string | null { - // Sometimes data has fields key and value, sometimes it's just the - // candidates array. - const candidates = data.value?.candidates ?? data.candidates ?? [{}]; - try { - const hasParts = candidates[0].content?.parts?.length > 0; - if (hasParts) { - return `data: ${JSON.stringify(data.value ?? 
data)}`; - } else { - this.log.error({ event: data }, "Received bad Google AI event"); - return `data: ${buildSpoofedSSE({ - format: "google-ai", - title: "Proxy stream error", - message: - "The proxy received malformed or unexpected data from Google AI while streaming.", - obj: data, - reqId: "proxy-sse-adapter-message", - model: "", - })}`; - } - } catch (error) { - error.lastEvent = data; - this.emit("error", error); - } - return null; - } - _transform(data: any, _enc: string, callback: (err?: Error | null) => void) { try { if (this.isAwsStream) { diff --git a/src/shared/key-management/gcp/checker.ts b/src/shared/key-management/gcp/checker.ts index a32cfa9..baf0fd4 100644 --- a/src/shared/key-management/gcp/checker.ts +++ b/src/shared/key-management/gcp/checker.ts @@ -1,9 +1,9 @@ import { AxiosError } from "axios"; -import crypto from "crypto"; import { GcpModelFamily } from "../../models"; import { getAxiosInstance } from "../../network"; import { KeyCheckerBase } from "../key-checker-base"; import type { GcpKey, GcpKeyProvider } from "./provider"; +import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "./oauth"; const axios = getAxiosInstance(); @@ -37,6 +37,7 @@ export class GcpKeyChecker extends KeyCheckerBase { let checks: Promise[] = []; const isInitialCheck = !key.lastChecked; if (isInitialCheck) { + await this.maybeRefreshAccessToken(key); checks = [ this.invokeModel("claude-3-haiku@20240307", key, true), this.invokeModel("claude-3-sonnet@20240229", key, true), @@ -70,6 +71,7 @@ export class GcpKeyChecker extends KeyCheckerBase { modelFamilies: families, }); } else { + await this.maybeRefreshAccessToken(key); if (key.haikuEnabled) { await this.invokeModel("claude-3-haiku@20240307", key, false); } else if (key.sonnetEnabled) { @@ -85,10 +87,7 @@ export class GcpKeyChecker extends KeyCheckerBase { } this.log.info( - { - key: key.hash, - families: key.modelFamilies, - }, + { key: key.hash, families: key.modelFamilies }, "Checked key." 
); } @@ -129,26 +128,36 @@ export class GcpKeyChecker extends KeyCheckerBase { this.updateKey(key.hash, { lastChecked: next }); } + private async maybeRefreshAccessToken(key: GcpKey) { + if (key.accessToken && key.accessTokenExpiresAt >= Date.now()) { + return; + } + + this.log.info({ key: key.hash }, "Refreshing GCP access token..."); + const [token, durationSec] = await refreshGcpAccessToken(key); + this.updateKey(key.hash, { + accessToken: token, + accessTokenExpiresAt: Date.now() + durationSec * 1000 * 0.95, + }); + } + /** * Attempt to invoke the given model with the given key. Returns true if the * key has access to the model, false if it does not. Throws an error if the * key is disabled. */ private async invokeModel(model: string, key: GcpKey, initial: boolean) { - const creds = GcpKeyChecker.getCredentialsFromKey(key); - const signedJWT = await GcpKeyChecker.createSignedJWT( - creds.clientEmail, - creds.privateKey - ); - const [accessToken, jwtError] = - await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT); - if (accessToken === null) { - this.log.warn( - { key: key.hash, jwtError }, - "Unable to get the access token" + const creds = await getCredentialsFromGcpKey(key); + try { + await this.maybeRefreshAccessToken(key); + } catch (e) { + this.log.error( + { key: key.hash, error: e.message }, + "Could not test key due to error while getting access token." ); return false; } + const payload = { max_tokens: 1, messages: TEST_MESSAGES, @@ -158,7 +167,7 @@ export class GcpKeyChecker extends KeyCheckerBase { POST_STREAM_RAW_URL(creds.projectId, creds.region, model), payload, { - headers: GcpKeyChecker.getRequestHeaders(accessToken), + headers: GcpKeyChecker.getRequestHeaders(key.accessToken), validateStatus: initial ? 
() => true : (status: number) => status >= 200 && status < 300, @@ -184,114 +193,10 @@ export class GcpKeyChecker extends KeyCheckerBase { } } - static async createSignedJWT(email: string, pkey: string): Promise { - let cryptoKey = await crypto.subtle.importKey( - "pkcs8", - GcpKeyChecker.str2ab(atob(pkey)), - { name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } }, - false, - ["sign"] - ); - - const authUrl = "https://www.googleapis.com/oauth2/v4/token"; - const issued = Math.floor(Date.now() / 1000); - const expires = issued + 600; - - const header = { alg: "RS256", typ: "JWT" }; - - const payload = { - iss: email, - aud: authUrl, - iat: issued, - exp: expires, - scope: "https://www.googleapis.com/auth/cloud-platform", - }; - - const encodedHeader = GcpKeyChecker.urlSafeBase64Encode( - JSON.stringify(header) - ); - const encodedPayload = GcpKeyChecker.urlSafeBase64Encode( - JSON.stringify(payload) - ); - - const unsignedToken = `${encodedHeader}.${encodedPayload}`; - - const signature = await crypto.subtle.sign( - "RSASSA-PKCS1-v1_5", - cryptoKey, - GcpKeyChecker.str2ab(unsignedToken) - ); - - const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature); - return `${unsignedToken}.${encodedSignature}`; - } - - static async exchangeJwtForAccessToken( - signed_jwt: string - ): Promise<[string | null, string]> { - const auth_url = "https://www.googleapis.com/oauth2/v4/token"; - const params = { - grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", - assertion: signed_jwt, - }; - - const r = await fetch(auth_url, { - method: "POST", - headers: { "Content-Type": "application/x-www-form-urlencoded" }, - body: Object.entries(params) - .map(([k, v]) => `${k}=${v}`) - .join("&"), - }).then((res) => res.json()); - - if (r.access_token) { - return [r.access_token, ""]; - } - - return [null, JSON.stringify(r)]; - } - - static str2ab(str: string): ArrayBuffer { - const buffer = new ArrayBuffer(str.length); - const bufferView = new Uint8Array(buffer); - for 
(let i = 0; i < str.length; i++) { - bufferView[i] = str.charCodeAt(i); - } - return buffer; - } - - static urlSafeBase64Encode(data: string | ArrayBuffer): string { - let base64: string; - if (typeof data === "string") { - base64 = btoa( - encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => - String.fromCharCode(parseInt("0x" + p1, 16)) - ) - ); - } else { - base64 = btoa(String.fromCharCode(...new Uint8Array(data))); - } - return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); - } - static getRequestHeaders(accessToken: string) { return { Authorization: `Bearer ${accessToken}`, "Content-Type": "application/json", }; } - - static getCredentialsFromKey(key: GcpKey) { - const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":"); - if (!projectId || !clientEmail || !region || !rawPrivateKey) { - throw new Error("Invalid GCP key"); - } - const privateKey = rawPrivateKey - .replace( - /-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, - "" - ) - .trim(); - - return { projectId, clientEmail, region, privateKey }; - } } diff --git a/src/shared/key-management/gcp/oauth.ts b/src/shared/key-management/gcp/oauth.ts new file mode 100644 index 0000000..e01a535 --- /dev/null +++ b/src/shared/key-management/gcp/oauth.ts @@ -0,0 +1,150 @@ +import crypto from "crypto"; +import type { GcpKey } from "./provider"; +import { getAxiosInstance } from "../../network"; +import { logger } from "../../../logger"; + +const axios = getAxiosInstance(); +const log = logger.child({ module: "gcp-oauth" }); + +const authUrl = "https://www.googleapis.com/oauth2/v4/token"; +const scope = "https://www.googleapis.com/auth/cloud-platform"; + +type GoogleAuthResponse = { + access_token: string; + scope: string; + token_type: "Bearer"; + expires_in: number; +}; + +type GoogleAuthError = { + error: + | "unauthorized_client" + | "access_denied" + | "admin_policy_enforced" + | "invalid_client" + | "invalid_grant" + | "invalid_scope" + | 
"disabled_client" + | "org_internal"; + error_description: string; +}; + +export async function refreshGcpAccessToken( + key: GcpKey +): Promise<[string, number]> { + log.info({ key: key.hash }, "Entering GCP OAuth flow..."); + const { clientEmail, privateKey } = await getCredentialsFromGcpKey(key); + + // https://developers.google.com/identity/protocols/oauth2/service-account#authorizingrequests + const jwt = await createSignedJWT(clientEmail, privateKey); + log.info({ key: key.hash }, "Signed JWT, exchanging for access token..."); + const res = await axios.post( + authUrl, + { + grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", + assertion: jwt, + }, + { + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + validateStatus: () => true, + } + ); + const status = res.status; + const headers = res.headers; + const data = res.data; + + if ("error" in data || status >= 400) { + log.error( + { key: key.hash, status, headers, data }, + "Error from Google Identity API while getting access token." + ); + throw new Error( + `Google Identity API returned error: ${(data as GoogleAuthError).error}` + ); + } + + log.info({ key: key.hash, exp: data.expires_in }, "Got access token."); + return [data.access_token, data.expires_in]; +} + +export async function getCredentialsFromGcpKey(key: GcpKey) { + const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":"); + if (!projectId || !clientEmail || !region || !rawPrivateKey) { + log.error( + { key: key.hash }, + "Cannot parse GCP credentials. Ensure they are in the format PROJECT_ID:CLIENT_EMAIL:REGION:PRIVATE_KEY, and ensure no whitespace or newlines are in the private key." + ); + throw new Error("Cannot parse GCP credentials."); + } + + if (!key.privateKey) { + await importPrivateKey(key, rawPrivateKey); + } + + return { projectId, clientEmail, region, privateKey: key.privateKey! 
}; +} + +async function createSignedJWT( + email: string, + pkey: crypto.webcrypto.CryptoKey +) { + const issued = Math.floor(Date.now() / 1000); + const expires = issued + 600; + + const header = { alg: "RS256", typ: "JWT" }; + + const payload = { + iss: email, + aud: authUrl, + iat: issued, + exp: expires, + scope, + }; + + const encodedHeader = urlSafeBase64Encode(JSON.stringify(header)); + const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload)); + + const unsignedToken = `${encodedHeader}.${encodedPayload}`; + + const signature = await crypto.subtle.sign( + "RSASSA-PKCS1-v1_5", + pkey, + new TextEncoder().encode(unsignedToken) + ); + + const encodedSignature = urlSafeBase64Encode(signature); + return `${unsignedToken}.${encodedSignature}`; +} + +async function importPrivateKey(key: GcpKey, rawPrivateKey: string) { + log.info({ key: key.hash }, "Importing GCP private key..."); + const privateKey = rawPrivateKey + .replace( + /-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, + "" + ) + .trim(); + const binaryKey = Buffer.from(privateKey, "base64"); + key.privateKey = await crypto.subtle.importKey( + "pkcs8", + binaryKey, + { name: "RSASSA-PKCS1-v1_5", hash: "SHA-256" }, + true, + ["sign"] + ); + log.info({ key: key.hash }, "GCP private key imported."); +} + +function urlSafeBase64Encode(data: string | ArrayBuffer): string { + let base64: string; + if (typeof data === "string") { + base64 = btoa( + encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => + String.fromCharCode(parseInt("0x" + p1, 16)) + ) + ); + } else { + base64 = btoa(String.fromCharCode(...new Uint8Array(data))); + } + return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, ""); +} diff --git a/src/shared/key-management/gcp/provider.ts b/src/shared/key-management/gcp/provider.ts index e3f72ef..52398c6 100644 --- a/src/shared/key-management/gcp/provider.ts +++ b/src/shared/key-management/gcp/provider.ts @@ -17,6 +17,11 @@ export interface 
GcpKey extends Key, GcpKeyUsage { sonnetEnabled: boolean; haikuEnabled: boolean; sonnet35Enabled: boolean; + + privateKey?: crypto.webcrypto.CryptoKey; + /** Cached access token for GCP APIs. */ + accessToken: string; + accessTokenExpiresAt: number; } /** @@ -68,6 +73,8 @@ export class GcpKeyProvider implements KeyProvider { sonnetEnabled: true, haikuEnabled: false, sonnet35Enabled: false, + accessToken: "", + accessTokenExpiresAt: 0, ["gcp-claudeTokens"]: 0, ["gcp-claude-opusTokens"]: 0, }; diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index 9e041c4..ad6cb76 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -8,13 +8,13 @@ import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models"; import { Key, KeyProvider } from "./index"; import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; -import { GoogleAIKeyProvider } from "./google-ai/provider"; +import { GoogleAIKeyProvider } from "./google-ai/provider"; import { AwsBedrockKeyProvider } from "./aws/provider"; -import { GcpKeyProvider } from "./gcp/provider"; +import { GcpKeyProvider, GcpKey } from "./gcp/provider"; import { AzureOpenAIKeyProvider } from "./azure/provider"; import { MistralAIKeyProvider } from "./mistral-ai/provider"; -type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate; +type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate | Partial; export class KeyPool { private keyProviders: KeyProvider[] = [];