finally DOES something about broken GCP streaming, boebeitfully

This commit is contained in:
nai-degen 2024-10-12 20:10:59 -05:00
parent 13aa55cd3d
commit 0c6ec3254f
10 changed files with 278 additions and 350 deletions

View File

@ -1,11 +1,8 @@
import { Request } from "express";
import crypto from "crypto";
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";
import { keyPool } from "../../../../shared/key-management";
import { getAxiosInstance } from "../../../../shared/network";
import { GcpKey, keyPool } from "../../../../shared/key-management";
import { ProxyReqMutator } from "../index";
const axios = getAxiosInstance();
import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "../../../../shared/key-management/gcp/oauth";
import { credential } from "firebase-admin";
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
@ -21,9 +18,18 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
}
const { model } = req.body;
const key = keyPool.get(model, "gcp");
const key: GcpKey = keyPool.get(model, "gcp") as GcpKey;
manager.setKey(key);
if (!key.accessToken || Date.now() > key.accessTokenExpiresAt) {
const [token, durationSec] = await refreshGcpAccessToken(key);
keyPool.update(key, {
accessToken: token,
accessTokenExpiresAt: Date.now() + durationSec * 1000 * 0.95,
} as GcpKey);
}
req.log.info({ key: key.hash, model }, "Assigned GCP key to request");
// TODO: This should happen in transform-outbound-payload.ts
@ -43,7 +49,7 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
.parse(req.body);
strippedParams.anthropic_version = "vertex-2023-10-16";
const [accessToken, credential] = await getAccessToken(req);
const credential = await getCredentialsFromGcpKey(key);
const host = GCP_HOST.replace("%REGION%", credential.region);
// GCP doesn't use the anthropic-version header, but we set it to ensure the
@ -58,151 +64,8 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
headers: {
["host"]: host,
["content-type"]: "application/json",
["authorization"]: `Bearer ${accessToken}`,
["authorization"]: `Bearer ${key.accessToken}`,
},
body: JSON.stringify(strippedParams),
});
};
async function getAccessToken(
req: Readonly<Request>
): Promise<[string, Credential]> {
// TODO: access token caching to reduce latency
const credential = getCredentialParts(req);
const signedJWT = await createSignedJWT(
credential.clientEmail,
credential.privateKey
);
const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT);
if (accessToken === null) {
req.log.warn(
{ key: req.key!.hash, jwtError },
"Unable to get the access token"
);
throw new Error("The access token is invalid.");
}
return [accessToken, credential];
}
async function createSignedJWT(email: string, pkey: string): Promise<string> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
str2ab(atob(pkey)),
{
name: "RSASSA-PKCS1-v1_5",
hash: { name: "SHA-256" },
},
false,
["sign"]
);
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = {
alg: "RS256",
typ: "JWT",
};
const payload = {
iss: email,
aud: authUrl,
iat: issued,
exp: expires,
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
const signature = await crypto.subtle.sign(
"RSASSA-PKCS1-v1_5",
cryptoKey,
str2ab(unsignedToken)
);
const encodedSignature = urlSafeBase64Encode(signature);
return `${unsignedToken}.${encodedSignature}`;
}
async function exchangeJwtForAccessToken(
signedJwt: string
): Promise<[string | null, string]> {
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: signedJwt,
};
try {
const response = await axios.post(authUrl, params, {
headers: { "Content-Type": "application/x-www-form-urlencoded" },
});
if (response.data.access_token) {
return [response.data.access_token, ""];
} else {
return [null, JSON.stringify(response.data)];
}
} catch (error) {
if ("response" in error && "data" in error.response) {
return [null, JSON.stringify(error.response.data)];
} else {
return [null, "An unexpected error occurred"];
}
}
}
function str2ab(str: string): ArrayBuffer {
const buffer = new ArrayBuffer(str.length);
const bufferView = new Uint8Array(buffer);
for (let i = 0; i < str.length; i++) {
bufferView[i] = str.charCodeAt(i);
}
return buffer;
}
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}
type Credential = {
projectId: string;
clientEmail: string;
region: string;
privateKey: string;
};
function getCredentialParts(req: Readonly<Request>): Credential {
const [projectId, clientEmail, region, rawPrivateKey] =
req.key!.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
req.log.error(
{ key: req.key!.hash },
"GCP_CREDENTIALS isn't correctly formatted; refer to the docs"
);
throw new Error("The key assigned to this request is invalid.");
}
const privateKey = rawPrivateKey
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim();
return { projectId, clientEmail, region, privateKey };
}

View File

@ -125,6 +125,9 @@ function pinoLoggerPlugin(proxyServer: ProxyServer<Request>) {
target: `${protocol}//${host}${path}`,
status: proxyRes.statusCode,
contentType: proxyRes.headers["content-type"],
contentEncoding: proxyRes.headers["content-encoding"],
contentLength: proxyRes.headers["content-length"],
transferEncoding: proxyRes.headers["transfer-encoding"],
},
"Got response from upstream API."
);

View File

@ -0,0 +1,36 @@
import util from "util";
import zlib from "zlib";
import { PassThrough } from "stream";
const BUFFER_DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
deflate: util.promisify(zlib.inflate),
br: util.promisify(zlib.brotliDecompress),
text: (data: Buffer) => data,
};
const STREAM_DECODER_MAP = {
gzip: zlib.createGunzip,
deflate: zlib.createInflate,
br: zlib.createBrotliDecompress,
text: () => new PassThrough(),
};
type SupportedContentEncoding = keyof typeof BUFFER_DECODER_MAP;
const isSupportedContentEncoding = (
encoding: string
): encoding is SupportedContentEncoding => encoding in BUFFER_DECODER_MAP;
export async function decompressBuffer(buf: Buffer, encoding: string = "text") {
if (isSupportedContentEncoding(encoding)) {
return (await BUFFER_DECODER_MAP[encoding](buf)).toString();
}
throw new Error(`Unsupported content-encoding: ${encoding}`);
}
export function getStreamDecompressor(encoding: string = "text") {
if (isSupportedContentEncoding(encoding)) {
return STREAM_DECODER_MAP[encoding]();
}
throw new Error(`Unsupported content-encoding: ${encoding}`);
}

View File

@ -1,20 +1,6 @@
import { Request, Response } from "express";
import util from "util";
import zlib from "zlib";
import { sendProxyError } from "../common";
import type { RawResponseBodyHandler } from "./index";
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
deflate: util.promisify(zlib.inflate),
br: util.promisify(zlib.brotliDecompress),
text: (data: Buffer) => data,
};
type SupportedContentEncoding = keyof typeof DECODER_MAP;
const isSupportedContentEncoding = (
encoding: string
): encoding is SupportedContentEncoding => encoding in DECODER_MAP;
import { decompressBuffer } from "./compression";
/**
* Handles the response from the upstream service and decodes the body if
@ -40,36 +26,45 @@ export const handleBlockingResponse: RawResponseBodyHandler = async (
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
proxyRes.on("end", async () => {
const contentEncoding = proxyRes.headers["content-encoding"];
const contentType = proxyRes.headers["content-type"];
let body: string | Buffer = Buffer.concat(chunks);
const rejectWithMessage = function (msg: string, err: Error) {
const error = `${msg} (${err.message})`;
req.log.warn({ stack: err.stack }, error);
req.log.warn(
{ msg: error, stack: err.stack },
"Error in blocking response handler"
);
sendProxyError(req, res, 500, "Internal Server Error", { error });
return reject(error);
};
const contentEncoding = proxyRes.headers["content-encoding"] ?? "text";
if (isSupportedContentEncoding(contentEncoding)) {
try {
body = (await DECODER_MAP[contentEncoding](body)).toString();
} catch (e) {
return rejectWithMessage(`Could not decode response body`, e);
}
} else {
return rejectWithMessage(
"API responded with unsupported content encoding",
new Error(`Unsupported content-encoding: ${contentEncoding}`)
);
try {
body = await decompressBuffer(body, contentEncoding);
} catch (e) {
return rejectWithMessage(`Could not decode response body`, e);
}
try {
if (proxyRes.headers["content-type"]?.includes("application/json")) {
return resolve(JSON.parse(body));
}
return resolve(body);
return resolve(tryParseAsJson(body, contentType));
} catch (e) {
return rejectWithMessage("API responded with invalid JSON", e);
}
});
});
};
function tryParseAsJson(body: string, contentType?: string) {
// If the response is declared as JSON, it must parse or we will throw
if (contentType?.includes("application/json")) {
return JSON.parse(body);
}
// If it's not declared as JSON, some APIs we'll try to parse it as JSON
// anyway since some APIs return the wrong content-type header in some cases.
// If it fails to parse, we'll just return the raw body without throwing.
try {
return JSON.parse(body);
} catch (e) {
return body;
}
}

View File

@ -1,6 +1,5 @@
import express from "express";
import { pipeline, Readable, Transform } from "stream";
import StreamArray from "stream-json/streamers/StreamArray";
import { StringDecoder } from "string_decoder";
import { promisify } from "util";
import type { logger } from "../../../logger";
@ -18,6 +17,7 @@ import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
import { EventAggregator } from "./streaming/event-aggregator";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { getStreamDecompressor } from "./compression";
const pipelineAsync = promisify(pipeline);
@ -41,21 +41,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
req,
res
) => {
const { hash } = req.key!;
const { headers, statusCode } = proxyRes;
if (!req.isStreaming) {
throw new Error("handleStreamedResponse called for non-streaming request.");
}
if (proxyRes.statusCode! > 201) {
if (statusCode! > 201) {
req.isStreaming = false;
req.log.warn(
{ statusCode: proxyRes.statusCode, key: hash },
{ statusCode },
`Streaming request returned error status code. Falling back to non-streaming response handler.`
);
return handleBlockingResponse(proxyRes, req, res);
}
req.log.debug({ headers: proxyRes.headers }, `Starting to proxy SSE stream.`);
req.log.debug({ headers }, `Starting to proxy SSE stream.`);
// Typically, streaming will have already been initialized by the request
// queue to send heartbeat pings.
@ -66,7 +66,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
const prefersNativeEvents = req.inboundApi === req.outboundApi;
const streamOptions = {
contentType: proxyRes.headers["content-type"],
contentType: headers["content-type"],
api: req.outboundApi,
logger: req.log,
};
@ -78,11 +78,10 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
// only have to write one aggregator (OpenAI input) for each output format.
const aggregator = new EventAggregator(req);
// Decoder reads from the raw response buffer and produces a stream of
// discrete events in some format (text/event-stream, vnd.amazon.event-stream,
// streaming JSON, etc).
const decompressor = getStreamDecompressor(headers["content-encoding"]);
// Decoder reads from the response bytes to produce a stream of plaintext.
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
// Adapter consumes the decoded events and produces server-sent events so we
// Adapter consumes the decoded text and produces server-sent events so we
// have a standard event format for the client and to translate between API
// message formats.
const adapter = new SSEStreamAdapter(streamOptions);
@ -107,7 +106,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
try {
await Promise.race([
handleAbortedStream(req, res),
pipelineAsync(proxyRes, decoder, adapter, transformer),
pipelineAsync(proxyRes, decompressor, decoder, adapter, transformer),
]);
req.log.debug(`Finished proxying SSE stream.`);
res.end();
@ -180,8 +179,7 @@ function getDecoder(options: {
} else if (contentType?.includes("application/json")) {
throw new Error("JSON streaming not supported, request SSE instead");
} else {
// Passthrough stream, but ensures split chunks across multi-byte characters
// are handled correctly.
// Ensures split chunks across multi-byte characters are handled correctly.
const stringDecoder = new StringDecoder("utf8");
return new Transform({
readableObjectMode: true,

View File

@ -2,7 +2,6 @@ import pino from "pino";
import { Transform, TransformOptions } from "stream";
import { Message } from "@smithy/eventstream-codec";
import { APIFormat } from "../../../../shared/key-management";
import { buildSpoofedSSE } from "../error-generator";
import { BadRequestError, RetryableError } from "../../../../shared/errors";
type SSEStreamAdapterOptions = TransformOptions & {
@ -108,34 +107,6 @@ export class SSEStreamAdapter extends Transform {
}
}
/** Processes an incoming array element from the Google AI JSON stream. */
protected processGoogleObject(data: any): string | null {
// Sometimes data has fields key and value, sometimes it's just the
// candidates array.
const candidates = data.value?.candidates ?? data.candidates ?? [{}];
try {
const hasParts = candidates[0].content?.parts?.length > 0;
if (hasParts) {
return `data: ${JSON.stringify(data.value ?? data)}`;
} else {
this.log.error({ event: data }, "Received bad Google AI event");
return `data: ${buildSpoofedSSE({
format: "google-ai",
title: "Proxy stream error",
message:
"The proxy received malformed or unexpected data from Google AI while streaming.",
obj: data,
reqId: "proxy-sse-adapter-message",
model: "",
})}`;
}
} catch (error) {
error.lastEvent = data;
this.emit("error", error);
}
return null;
}
_transform(data: any, _enc: string, callback: (err?: Error | null) => void) {
try {
if (this.isAwsStream) {

View File

@ -1,9 +1,9 @@
import { AxiosError } from "axios";
import crypto from "crypto";
import { GcpModelFamily } from "../../models";
import { getAxiosInstance } from "../../network";
import { KeyCheckerBase } from "../key-checker-base";
import type { GcpKey, GcpKeyProvider } from "./provider";
import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "./oauth";
const axios = getAxiosInstance();
@ -37,6 +37,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
let checks: Promise<boolean>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
await this.maybeRefreshAccessToken(key);
checks = [
this.invokeModel("claude-3-haiku@20240307", key, true),
this.invokeModel("claude-3-sonnet@20240229", key, true),
@ -70,6 +71,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
modelFamilies: families,
});
} else {
await this.maybeRefreshAccessToken(key);
if (key.haikuEnabled) {
await this.invokeModel("claude-3-haiku@20240307", key, false);
} else if (key.sonnetEnabled) {
@ -85,10 +87,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
}
this.log.info(
{
key: key.hash,
families: key.modelFamilies,
},
{ key: key.hash, families: key.modelFamilies },
"Checked key."
);
}
@ -129,26 +128,36 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
this.updateKey(key.hash, { lastChecked: next });
}
private async maybeRefreshAccessToken(key: GcpKey) {
if (key.accessToken && key.accessTokenExpiresAt >= Date.now()) {
return;
}
this.log.info({ key: key.hash }, "Refreshing GCP access token...");
const [token, durationSec] = await refreshGcpAccessToken(key);
this.updateKey(key.hash, {
accessToken: token,
accessTokenExpiresAt: Date.now() + durationSec * 1000 * 0.95,
});
}
/**
* Attempt to invoke the given model with the given key. Returns true if the
* key has access to the model, false if it does not. Throws an error if the
* key is disabled.
*/
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
const creds = GcpKeyChecker.getCredentialsFromKey(key);
const signedJWT = await GcpKeyChecker.createSignedJWT(
creds.clientEmail,
creds.privateKey
);
const [accessToken, jwtError] =
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
if (accessToken === null) {
this.log.warn(
{ key: key.hash, jwtError },
"Unable to get the access token"
const creds = await getCredentialsFromGcpKey(key);
try {
await this.maybeRefreshAccessToken(key);
} catch (e) {
this.log.error(
{ key: key.hash, error: e.message },
"Could not test key due to error while getting access token."
);
return false;
}
const payload = {
max_tokens: 1,
messages: TEST_MESSAGES,
@ -158,7 +167,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
payload,
{
headers: GcpKeyChecker.getRequestHeaders(accessToken),
headers: GcpKeyChecker.getRequestHeaders(key.accessToken),
validateStatus: initial
? () => true
: (status: number) => status >= 200 && status < 300,
@ -184,114 +193,10 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
}
}
static async createSignedJWT(email: string, pkey: string): Promise<string> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
GcpKeyChecker.str2ab(atob(pkey)),
{ name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } },
false,
["sign"]
);
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = { alg: "RS256", typ: "JWT" };
const payload = {
iss: email,
aud: authUrl,
iat: issued,
exp: expires,
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(header)
);
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(payload)
);
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
const signature = await crypto.subtle.sign(
"RSASSA-PKCS1-v1_5",
cryptoKey,
GcpKeyChecker.str2ab(unsignedToken)
);
const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature);
return `${unsignedToken}.${encodedSignature}`;
}
static async exchangeJwtForAccessToken(
signed_jwt: string
): Promise<[string | null, string]> {
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: signed_jwt,
};
const r = await fetch(auth_url, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: Object.entries(params)
.map(([k, v]) => `${k}=${v}`)
.join("&"),
}).then((res) => res.json());
if (r.access_token) {
return [r.access_token, ""];
}
return [null, JSON.stringify(r)];
}
static str2ab(str: string): ArrayBuffer {
const buffer = new ArrayBuffer(str.length);
const bufferView = new Uint8Array(buffer);
for (let i = 0; i < str.length; i++) {
bufferView[i] = str.charCodeAt(i);
}
return buffer;
}
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}
static getRequestHeaders(accessToken: string) {
return {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
};
}
static getCredentialsFromKey(key: GcpKey) {
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
throw new Error("Invalid GCP key");
}
const privateKey = rawPrivateKey
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim();
return { projectId, clientEmail, region, privateKey };
}
}

View File

@ -0,0 +1,150 @@
import crypto from "crypto";
import type { GcpKey } from "./provider";
import { getAxiosInstance } from "../../network";
import { logger } from "../../../logger";
const axios = getAxiosInstance();
const log = logger.child({ module: "gcp-oauth" });
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const scope = "https://www.googleapis.com/auth/cloud-platform";
type GoogleAuthResponse = {
access_token: string;
scope: string;
token_type: "Bearer";
expires_in: number;
};
type GoogleAuthError = {
error:
| "unauthorized_client"
| "access_denied"
| "admin_policy_enforced"
| "invalid_client"
| "invalid_grant"
| "invalid_scope"
| "disabled_client"
| "org_internal";
error_description: string;
};
export async function refreshGcpAccessToken(
key: GcpKey
): Promise<[string, number]> {
log.info({ key: key.hash }, "Entering GCP OAuth flow...");
const { clientEmail, privateKey } = await getCredentialsFromGcpKey(key);
// https://developers.google.com/identity/protocols/oauth2/service-account#authorizingrequests
const jwt = await createSignedJWT(clientEmail, privateKey);
log.info({ key: key.hash }, "Signed JWT, exchanging for access token...");
const res = await axios.post<GoogleAuthResponse | GoogleAuthError>(
authUrl,
{
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: jwt,
},
{
headers: { "Content-Type": "application/x-www-form-urlencoded" },
validateStatus: () => true,
}
);
const status = res.status;
const headers = res.headers;
const data = res.data;
if ("error" in data || status >= 400) {
log.error(
{ key: key.hash, status, headers, data },
"Error from Google Identity API while getting access token."
);
throw new Error(
`Google Identity API returned error: ${(data as GoogleAuthError).error}`
);
}
log.info({ key: key.hash, exp: data.expires_in }, "Got access token.");
return [data.access_token, data.expires_in];
}
export async function getCredentialsFromGcpKey(key: GcpKey) {
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
log.error(
{ key: key.hash },
"Cannot parse GCP credentials. Ensure they are in the format PROJECT_ID:CLIENT_EMAIL:REGION:PRIVATE_KEY, and ensure no whitespace or newlines are in the private key."
);
throw new Error("Cannot parse GCP credentials.");
}
if (!key.privateKey) {
await importPrivateKey(key, rawPrivateKey);
}
return { projectId, clientEmail, region, privateKey: key.privateKey! };
}
async function createSignedJWT(
email: string,
pkey: crypto.webcrypto.CryptoKey
) {
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = { alg: "RS256", typ: "JWT" };
const payload = {
iss: email,
aud: authUrl,
iat: issued,
exp: expires,
scope,
};
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
const signature = await crypto.subtle.sign(
"RSASSA-PKCS1-v1_5",
pkey,
new TextEncoder().encode(unsignedToken)
);
const encodedSignature = urlSafeBase64Encode(signature);
return `${unsignedToken}.${encodedSignature}`;
}
async function importPrivateKey(key: GcpKey, rawPrivateKey: string) {
log.info({ key: key.hash }, "Importing GCP private key...");
const privateKey = rawPrivateKey
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim();
const binaryKey = Buffer.from(privateKey, "base64");
key.privateKey = await crypto.subtle.importKey(
"pkcs8",
binaryKey,
{ name: "RSASSA-PKCS1-v1_5", hash: "SHA-256" },
true,
["sign"]
);
log.info({ key: key.hash }, "GCP private key imported.");
}
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}

View File

@ -17,6 +17,11 @@ export interface GcpKey extends Key, GcpKeyUsage {
sonnetEnabled: boolean;
haikuEnabled: boolean;
sonnet35Enabled: boolean;
privateKey?: crypto.webcrypto.CryptoKey;
/** Cached access token for GCP APIs. */
accessToken: string;
accessTokenExpiresAt: number;
}
/**
@ -68,6 +73,8 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
sonnetEnabled: true,
haikuEnabled: false,
sonnet35Enabled: false,
accessToken: "",
accessTokenExpiresAt: 0,
["gcp-claudeTokens"]: 0,
["gcp-claude-opusTokens"]: 0,
};

View File

@ -8,13 +8,13 @@ import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models";
import { Key, KeyProvider } from "./index";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { GcpKeyProvider } from "./gcp/provider";
import { GcpKeyProvider, GcpKey } from "./gcp/provider";
import { AzureOpenAIKeyProvider } from "./azure/provider";
import { MistralAIKeyProvider } from "./mistral-ai/provider";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate | Partial<GcpKey>;
export class KeyPool {
private keyProviders: KeyProvider[] = [];