finally DOES something about broken GCP streaming, boebeitfully
This commit is contained in:
parent
13aa55cd3d
commit
0c6ec3254f
|
@ -1,11 +1,8 @@
|
|||
import { Request } from "express";
|
||||
import crypto from "crypto";
|
||||
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";
|
||||
import { keyPool } from "../../../../shared/key-management";
|
||||
import { getAxiosInstance } from "../../../../shared/network";
|
||||
import { GcpKey, keyPool } from "../../../../shared/key-management";
|
||||
import { ProxyReqMutator } from "../index";
|
||||
|
||||
const axios = getAxiosInstance();
|
||||
import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "../../../../shared/key-management/gcp/oauth";
|
||||
import { credential } from "firebase-admin";
|
||||
|
||||
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
|
||||
|
||||
|
@ -21,9 +18,18 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
|
|||
}
|
||||
|
||||
const { model } = req.body;
|
||||
const key = keyPool.get(model, "gcp");
|
||||
const key: GcpKey = keyPool.get(model, "gcp") as GcpKey;
|
||||
manager.setKey(key);
|
||||
|
||||
if (!key.accessToken || Date.now() > key.accessTokenExpiresAt) {
|
||||
const [token, durationSec] = await refreshGcpAccessToken(key);
|
||||
keyPool.update(key, {
|
||||
accessToken: token,
|
||||
accessTokenExpiresAt: Date.now() + durationSec * 1000 * 0.95,
|
||||
} as GcpKey);
|
||||
}
|
||||
|
||||
|
||||
req.log.info({ key: key.hash, model }, "Assigned GCP key to request");
|
||||
|
||||
// TODO: This should happen in transform-outbound-payload.ts
|
||||
|
@ -43,7 +49,7 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
|
|||
.parse(req.body);
|
||||
strippedParams.anthropic_version = "vertex-2023-10-16";
|
||||
|
||||
const [accessToken, credential] = await getAccessToken(req);
|
||||
const credential = await getCredentialsFromGcpKey(key);
|
||||
|
||||
const host = GCP_HOST.replace("%REGION%", credential.region);
|
||||
// GCP doesn't use the anthropic-version header, but we set it to ensure the
|
||||
|
@ -58,151 +64,8 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
|
|||
headers: {
|
||||
["host"]: host,
|
||||
["content-type"]: "application/json",
|
||||
["authorization"]: `Bearer ${accessToken}`,
|
||||
["authorization"]: `Bearer ${key.accessToken}`,
|
||||
},
|
||||
body: JSON.stringify(strippedParams),
|
||||
});
|
||||
};
|
||||
|
||||
async function getAccessToken(
|
||||
req: Readonly<Request>
|
||||
): Promise<[string, Credential]> {
|
||||
// TODO: access token caching to reduce latency
|
||||
const credential = getCredentialParts(req);
|
||||
const signedJWT = await createSignedJWT(
|
||||
credential.clientEmail,
|
||||
credential.privateKey
|
||||
);
|
||||
const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT);
|
||||
if (accessToken === null) {
|
||||
req.log.warn(
|
||||
{ key: req.key!.hash, jwtError },
|
||||
"Unable to get the access token"
|
||||
);
|
||||
throw new Error("The access token is invalid.");
|
||||
}
|
||||
return [accessToken, credential];
|
||||
}
|
||||
|
||||
async function createSignedJWT(email: string, pkey: string): Promise<string> {
|
||||
let cryptoKey = await crypto.subtle.importKey(
|
||||
"pkcs8",
|
||||
str2ab(atob(pkey)),
|
||||
{
|
||||
name: "RSASSA-PKCS1-v1_5",
|
||||
hash: { name: "SHA-256" },
|
||||
},
|
||||
false,
|
||||
["sign"]
|
||||
);
|
||||
|
||||
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const issued = Math.floor(Date.now() / 1000);
|
||||
const expires = issued + 600;
|
||||
|
||||
const header = {
|
||||
alg: "RS256",
|
||||
typ: "JWT",
|
||||
};
|
||||
|
||||
const payload = {
|
||||
iss: email,
|
||||
aud: authUrl,
|
||||
iat: issued,
|
||||
exp: expires,
|
||||
scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||
};
|
||||
|
||||
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
|
||||
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
|
||||
|
||||
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||
|
||||
const signature = await crypto.subtle.sign(
|
||||
"RSASSA-PKCS1-v1_5",
|
||||
cryptoKey,
|
||||
str2ab(unsignedToken)
|
||||
);
|
||||
|
||||
const encodedSignature = urlSafeBase64Encode(signature);
|
||||
return `${unsignedToken}.${encodedSignature}`;
|
||||
}
|
||||
|
||||
async function exchangeJwtForAccessToken(
|
||||
signedJwt: string
|
||||
): Promise<[string | null, string]> {
|
||||
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const params = {
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
assertion: signedJwt,
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await axios.post(authUrl, params, {
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
});
|
||||
|
||||
if (response.data.access_token) {
|
||||
return [response.data.access_token, ""];
|
||||
} else {
|
||||
return [null, JSON.stringify(response.data)];
|
||||
}
|
||||
} catch (error) {
|
||||
if ("response" in error && "data" in error.response) {
|
||||
return [null, JSON.stringify(error.response.data)];
|
||||
} else {
|
||||
return [null, "An unexpected error occurred"];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function str2ab(str: string): ArrayBuffer {
|
||||
const buffer = new ArrayBuffer(str.length);
|
||||
const bufferView = new Uint8Array(buffer);
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
bufferView[i] = str.charCodeAt(i);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||
let base64: string;
|
||||
if (typeof data === "string") {
|
||||
base64 = btoa(
|
||||
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
|
||||
String.fromCharCode(parseInt("0x" + p1, 16))
|
||||
)
|
||||
);
|
||||
} else {
|
||||
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||
}
|
||||
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||
}
|
||||
|
||||
type Credential = {
|
||||
projectId: string;
|
||||
clientEmail: string;
|
||||
region: string;
|
||||
privateKey: string;
|
||||
};
|
||||
|
||||
function getCredentialParts(req: Readonly<Request>): Credential {
|
||||
const [projectId, clientEmail, region, rawPrivateKey] =
|
||||
req.key!.key.split(":");
|
||||
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
|
||||
req.log.error(
|
||||
{ key: req.key!.hash },
|
||||
"GCP_CREDENTIALS isn't correctly formatted; refer to the docs"
|
||||
);
|
||||
throw new Error("The key assigned to this request is invalid.");
|
||||
}
|
||||
|
||||
const privateKey = rawPrivateKey
|
||||
.replace(
|
||||
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
|
||||
""
|
||||
)
|
||||
.trim();
|
||||
|
||||
return { projectId, clientEmail, region, privateKey };
|
||||
}
|
||||
|
|
|
@ -125,6 +125,9 @@ function pinoLoggerPlugin(proxyServer: ProxyServer<Request>) {
|
|||
target: `${protocol}//${host}${path}`,
|
||||
status: proxyRes.statusCode,
|
||||
contentType: proxyRes.headers["content-type"],
|
||||
contentEncoding: proxyRes.headers["content-encoding"],
|
||||
contentLength: proxyRes.headers["content-length"],
|
||||
transferEncoding: proxyRes.headers["transfer-encoding"],
|
||||
},
|
||||
"Got response from upstream API."
|
||||
);
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import util from "util";
|
||||
import zlib from "zlib";
|
||||
import { PassThrough } from "stream";
|
||||
|
||||
const BUFFER_DECODER_MAP = {
|
||||
gzip: util.promisify(zlib.gunzip),
|
||||
deflate: util.promisify(zlib.inflate),
|
||||
br: util.promisify(zlib.brotliDecompress),
|
||||
text: (data: Buffer) => data,
|
||||
};
|
||||
|
||||
const STREAM_DECODER_MAP = {
|
||||
gzip: zlib.createGunzip,
|
||||
deflate: zlib.createInflate,
|
||||
br: zlib.createBrotliDecompress,
|
||||
text: () => new PassThrough(),
|
||||
};
|
||||
|
||||
type SupportedContentEncoding = keyof typeof BUFFER_DECODER_MAP;
|
||||
const isSupportedContentEncoding = (
|
||||
encoding: string
|
||||
): encoding is SupportedContentEncoding => encoding in BUFFER_DECODER_MAP;
|
||||
|
||||
export async function decompressBuffer(buf: Buffer, encoding: string = "text") {
|
||||
if (isSupportedContentEncoding(encoding)) {
|
||||
return (await BUFFER_DECODER_MAP[encoding](buf)).toString();
|
||||
}
|
||||
throw new Error(`Unsupported content-encoding: ${encoding}`);
|
||||
}
|
||||
|
||||
export function getStreamDecompressor(encoding: string = "text") {
|
||||
if (isSupportedContentEncoding(encoding)) {
|
||||
return STREAM_DECODER_MAP[encoding]();
|
||||
}
|
||||
throw new Error(`Unsupported content-encoding: ${encoding}`);
|
||||
}
|
|
@ -1,20 +1,6 @@
|
|||
import { Request, Response } from "express";
|
||||
import util from "util";
|
||||
import zlib from "zlib";
|
||||
import { sendProxyError } from "../common";
|
||||
import type { RawResponseBodyHandler } from "./index";
|
||||
|
||||
const DECODER_MAP = {
|
||||
gzip: util.promisify(zlib.gunzip),
|
||||
deflate: util.promisify(zlib.inflate),
|
||||
br: util.promisify(zlib.brotliDecompress),
|
||||
text: (data: Buffer) => data,
|
||||
};
|
||||
type SupportedContentEncoding = keyof typeof DECODER_MAP;
|
||||
|
||||
const isSupportedContentEncoding = (
|
||||
encoding: string
|
||||
): encoding is SupportedContentEncoding => encoding in DECODER_MAP;
|
||||
import { decompressBuffer } from "./compression";
|
||||
|
||||
/**
|
||||
* Handles the response from the upstream service and decodes the body if
|
||||
|
@ -40,36 +26,45 @@ export const handleBlockingResponse: RawResponseBodyHandler = async (
|
|||
let chunks: Buffer[] = [];
|
||||
proxyRes.on("data", (chunk) => chunks.push(chunk));
|
||||
proxyRes.on("end", async () => {
|
||||
const contentEncoding = proxyRes.headers["content-encoding"];
|
||||
const contentType = proxyRes.headers["content-type"];
|
||||
let body: string | Buffer = Buffer.concat(chunks);
|
||||
const rejectWithMessage = function (msg: string, err: Error) {
|
||||
const error = `${msg} (${err.message})`;
|
||||
req.log.warn({ stack: err.stack }, error);
|
||||
req.log.warn(
|
||||
{ msg: error, stack: err.stack },
|
||||
"Error in blocking response handler"
|
||||
);
|
||||
sendProxyError(req, res, 500, "Internal Server Error", { error });
|
||||
return reject(error);
|
||||
};
|
||||
|
||||
const contentEncoding = proxyRes.headers["content-encoding"] ?? "text";
|
||||
if (isSupportedContentEncoding(contentEncoding)) {
|
||||
try {
|
||||
body = (await DECODER_MAP[contentEncoding](body)).toString();
|
||||
} catch (e) {
|
||||
return rejectWithMessage(`Could not decode response body`, e);
|
||||
}
|
||||
} else {
|
||||
return rejectWithMessage(
|
||||
"API responded with unsupported content encoding",
|
||||
new Error(`Unsupported content-encoding: ${contentEncoding}`)
|
||||
);
|
||||
try {
|
||||
body = await decompressBuffer(body, contentEncoding);
|
||||
} catch (e) {
|
||||
return rejectWithMessage(`Could not decode response body`, e);
|
||||
}
|
||||
|
||||
try {
|
||||
if (proxyRes.headers["content-type"]?.includes("application/json")) {
|
||||
return resolve(JSON.parse(body));
|
||||
}
|
||||
return resolve(body);
|
||||
return resolve(tryParseAsJson(body, contentType));
|
||||
} catch (e) {
|
||||
return rejectWithMessage("API responded with invalid JSON", e);
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
function tryParseAsJson(body: string, contentType?: string) {
|
||||
// If the response is declared as JSON, it must parse or we will throw
|
||||
if (contentType?.includes("application/json")) {
|
||||
return JSON.parse(body);
|
||||
}
|
||||
// If it's not declared as JSON, some APIs we'll try to parse it as JSON
|
||||
// anyway since some APIs return the wrong content-type header in some cases.
|
||||
// If it fails to parse, we'll just return the raw body without throwing.
|
||||
try {
|
||||
return JSON.parse(body);
|
||||
} catch (e) {
|
||||
return body;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import express from "express";
|
||||
import { pipeline, Readable, Transform } from "stream";
|
||||
import StreamArray from "stream-json/streamers/StreamArray";
|
||||
import { StringDecoder } from "string_decoder";
|
||||
import { promisify } from "util";
|
||||
import type { logger } from "../../../logger";
|
||||
|
@ -18,6 +17,7 @@ import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
|
|||
import { EventAggregator } from "./streaming/event-aggregator";
|
||||
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
|
||||
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
|
||||
import { getStreamDecompressor } from "./compression";
|
||||
|
||||
const pipelineAsync = promisify(pipeline);
|
||||
|
||||
|
@ -41,21 +41,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
|||
req,
|
||||
res
|
||||
) => {
|
||||
const { hash } = req.key!;
|
||||
const { headers, statusCode } = proxyRes;
|
||||
if (!req.isStreaming) {
|
||||
throw new Error("handleStreamedResponse called for non-streaming request.");
|
||||
}
|
||||
|
||||
if (proxyRes.statusCode! > 201) {
|
||||
if (statusCode! > 201) {
|
||||
req.isStreaming = false;
|
||||
req.log.warn(
|
||||
{ statusCode: proxyRes.statusCode, key: hash },
|
||||
{ statusCode },
|
||||
`Streaming request returned error status code. Falling back to non-streaming response handler.`
|
||||
);
|
||||
return handleBlockingResponse(proxyRes, req, res);
|
||||
}
|
||||
|
||||
req.log.debug({ headers: proxyRes.headers }, `Starting to proxy SSE stream.`);
|
||||
req.log.debug({ headers }, `Starting to proxy SSE stream.`);
|
||||
|
||||
// Typically, streaming will have already been initialized by the request
|
||||
// queue to send heartbeat pings.
|
||||
|
@ -66,7 +66,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
|||
|
||||
const prefersNativeEvents = req.inboundApi === req.outboundApi;
|
||||
const streamOptions = {
|
||||
contentType: proxyRes.headers["content-type"],
|
||||
contentType: headers["content-type"],
|
||||
api: req.outboundApi,
|
||||
logger: req.log,
|
||||
};
|
||||
|
@ -78,11 +78,10 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
|||
// only have to write one aggregator (OpenAI input) for each output format.
|
||||
const aggregator = new EventAggregator(req);
|
||||
|
||||
// Decoder reads from the raw response buffer and produces a stream of
|
||||
// discrete events in some format (text/event-stream, vnd.amazon.event-stream,
|
||||
// streaming JSON, etc).
|
||||
const decompressor = getStreamDecompressor(headers["content-encoding"]);
|
||||
// Decoder reads from the response bytes to produce a stream of plaintext.
|
||||
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
|
||||
// Adapter consumes the decoded events and produces server-sent events so we
|
||||
// Adapter consumes the decoded text and produces server-sent events so we
|
||||
// have a standard event format for the client and to translate between API
|
||||
// message formats.
|
||||
const adapter = new SSEStreamAdapter(streamOptions);
|
||||
|
@ -107,7 +106,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
|||
try {
|
||||
await Promise.race([
|
||||
handleAbortedStream(req, res),
|
||||
pipelineAsync(proxyRes, decoder, adapter, transformer),
|
||||
pipelineAsync(proxyRes, decompressor, decoder, adapter, transformer),
|
||||
]);
|
||||
req.log.debug(`Finished proxying SSE stream.`);
|
||||
res.end();
|
||||
|
@ -180,8 +179,7 @@ function getDecoder(options: {
|
|||
} else if (contentType?.includes("application/json")) {
|
||||
throw new Error("JSON streaming not supported, request SSE instead");
|
||||
} else {
|
||||
// Passthrough stream, but ensures split chunks across multi-byte characters
|
||||
// are handled correctly.
|
||||
// Ensures split chunks across multi-byte characters are handled correctly.
|
||||
const stringDecoder = new StringDecoder("utf8");
|
||||
return new Transform({
|
||||
readableObjectMode: true,
|
||||
|
|
|
@ -2,7 +2,6 @@ import pino from "pino";
|
|||
import { Transform, TransformOptions } from "stream";
|
||||
import { Message } from "@smithy/eventstream-codec";
|
||||
import { APIFormat } from "../../../../shared/key-management";
|
||||
import { buildSpoofedSSE } from "../error-generator";
|
||||
import { BadRequestError, RetryableError } from "../../../../shared/errors";
|
||||
|
||||
type SSEStreamAdapterOptions = TransformOptions & {
|
||||
|
@ -108,34 +107,6 @@ export class SSEStreamAdapter extends Transform {
|
|||
}
|
||||
}
|
||||
|
||||
/** Processes an incoming array element from the Google AI JSON stream. */
|
||||
protected processGoogleObject(data: any): string | null {
|
||||
// Sometimes data has fields key and value, sometimes it's just the
|
||||
// candidates array.
|
||||
const candidates = data.value?.candidates ?? data.candidates ?? [{}];
|
||||
try {
|
||||
const hasParts = candidates[0].content?.parts?.length > 0;
|
||||
if (hasParts) {
|
||||
return `data: ${JSON.stringify(data.value ?? data)}`;
|
||||
} else {
|
||||
this.log.error({ event: data }, "Received bad Google AI event");
|
||||
return `data: ${buildSpoofedSSE({
|
||||
format: "google-ai",
|
||||
title: "Proxy stream error",
|
||||
message:
|
||||
"The proxy received malformed or unexpected data from Google AI while streaming.",
|
||||
obj: data,
|
||||
reqId: "proxy-sse-adapter-message",
|
||||
model: "",
|
||||
})}`;
|
||||
}
|
||||
} catch (error) {
|
||||
error.lastEvent = data;
|
||||
this.emit("error", error);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
_transform(data: any, _enc: string, callback: (err?: Error | null) => void) {
|
||||
try {
|
||||
if (this.isAwsStream) {
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import { AxiosError } from "axios";
|
||||
import crypto from "crypto";
|
||||
import { GcpModelFamily } from "../../models";
|
||||
import { getAxiosInstance } from "../../network";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { GcpKey, GcpKeyProvider } from "./provider";
|
||||
import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "./oauth";
|
||||
|
||||
const axios = getAxiosInstance();
|
||||
|
||||
|
@ -37,6 +37,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||
let checks: Promise<boolean>[] = [];
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
await this.maybeRefreshAccessToken(key);
|
||||
checks = [
|
||||
this.invokeModel("claude-3-haiku@20240307", key, true),
|
||||
this.invokeModel("claude-3-sonnet@20240229", key, true),
|
||||
|
@ -70,6 +71,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||
modelFamilies: families,
|
||||
});
|
||||
} else {
|
||||
await this.maybeRefreshAccessToken(key);
|
||||
if (key.haikuEnabled) {
|
||||
await this.invokeModel("claude-3-haiku@20240307", key, false);
|
||||
} else if (key.sonnetEnabled) {
|
||||
|
@ -85,10 +87,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||
}
|
||||
|
||||
this.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
families: key.modelFamilies,
|
||||
},
|
||||
{ key: key.hash, families: key.modelFamilies },
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
@ -129,26 +128,36 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||
this.updateKey(key.hash, { lastChecked: next });
|
||||
}
|
||||
|
||||
private async maybeRefreshAccessToken(key: GcpKey) {
|
||||
if (key.accessToken && key.accessTokenExpiresAt >= Date.now()) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.info({ key: key.hash }, "Refreshing GCP access token...");
|
||||
const [token, durationSec] = await refreshGcpAccessToken(key);
|
||||
this.updateKey(key.hash, {
|
||||
accessToken: token,
|
||||
accessTokenExpiresAt: Date.now() + durationSec * 1000 * 0.95,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempt to invoke the given model with the given key. Returns true if the
|
||||
* key has access to the model, false if it does not. Throws an error if the
|
||||
* key is disabled.
|
||||
*/
|
||||
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
|
||||
const creds = GcpKeyChecker.getCredentialsFromKey(key);
|
||||
const signedJWT = await GcpKeyChecker.createSignedJWT(
|
||||
creds.clientEmail,
|
||||
creds.privateKey
|
||||
);
|
||||
const [accessToken, jwtError] =
|
||||
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
|
||||
if (accessToken === null) {
|
||||
this.log.warn(
|
||||
{ key: key.hash, jwtError },
|
||||
"Unable to get the access token"
|
||||
const creds = await getCredentialsFromGcpKey(key);
|
||||
try {
|
||||
await this.maybeRefreshAccessToken(key);
|
||||
} catch (e) {
|
||||
this.log.error(
|
||||
{ key: key.hash, error: e.message },
|
||||
"Could not test key due to error while getting access token."
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
const payload = {
|
||||
max_tokens: 1,
|
||||
messages: TEST_MESSAGES,
|
||||
|
@ -158,7 +167,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
|
||||
payload,
|
||||
{
|
||||
headers: GcpKeyChecker.getRequestHeaders(accessToken),
|
||||
headers: GcpKeyChecker.getRequestHeaders(key.accessToken),
|
||||
validateStatus: initial
|
||||
? () => true
|
||||
: (status: number) => status >= 200 && status < 300,
|
||||
|
@ -184,114 +193,10 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||
}
|
||||
}
|
||||
|
||||
static async createSignedJWT(email: string, pkey: string): Promise<string> {
|
||||
let cryptoKey = await crypto.subtle.importKey(
|
||||
"pkcs8",
|
||||
GcpKeyChecker.str2ab(atob(pkey)),
|
||||
{ name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } },
|
||||
false,
|
||||
["sign"]
|
||||
);
|
||||
|
||||
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const issued = Math.floor(Date.now() / 1000);
|
||||
const expires = issued + 600;
|
||||
|
||||
const header = { alg: "RS256", typ: "JWT" };
|
||||
|
||||
const payload = {
|
||||
iss: email,
|
||||
aud: authUrl,
|
||||
iat: issued,
|
||||
exp: expires,
|
||||
scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||
};
|
||||
|
||||
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
|
||||
JSON.stringify(header)
|
||||
);
|
||||
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
|
||||
JSON.stringify(payload)
|
||||
);
|
||||
|
||||
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||
|
||||
const signature = await crypto.subtle.sign(
|
||||
"RSASSA-PKCS1-v1_5",
|
||||
cryptoKey,
|
||||
GcpKeyChecker.str2ab(unsignedToken)
|
||||
);
|
||||
|
||||
const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature);
|
||||
return `${unsignedToken}.${encodedSignature}`;
|
||||
}
|
||||
|
||||
static async exchangeJwtForAccessToken(
|
||||
signed_jwt: string
|
||||
): Promise<[string | null, string]> {
|
||||
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const params = {
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
assertion: signed_jwt,
|
||||
};
|
||||
|
||||
const r = await fetch(auth_url, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: Object.entries(params)
|
||||
.map(([k, v]) => `${k}=${v}`)
|
||||
.join("&"),
|
||||
}).then((res) => res.json());
|
||||
|
||||
if (r.access_token) {
|
||||
return [r.access_token, ""];
|
||||
}
|
||||
|
||||
return [null, JSON.stringify(r)];
|
||||
}
|
||||
|
||||
static str2ab(str: string): ArrayBuffer {
|
||||
const buffer = new ArrayBuffer(str.length);
|
||||
const bufferView = new Uint8Array(buffer);
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
bufferView[i] = str.charCodeAt(i);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||
let base64: string;
|
||||
if (typeof data === "string") {
|
||||
base64 = btoa(
|
||||
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
|
||||
String.fromCharCode(parseInt("0x" + p1, 16))
|
||||
)
|
||||
);
|
||||
} else {
|
||||
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||
}
|
||||
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||
}
|
||||
|
||||
static getRequestHeaders(accessToken: string) {
|
||||
return {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
}
|
||||
|
||||
static getCredentialsFromKey(key: GcpKey) {
|
||||
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
|
||||
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
|
||||
throw new Error("Invalid GCP key");
|
||||
}
|
||||
const privateKey = rawPrivateKey
|
||||
.replace(
|
||||
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
|
||||
""
|
||||
)
|
||||
.trim();
|
||||
|
||||
return { projectId, clientEmail, region, privateKey };
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
import crypto from "crypto";
|
||||
import type { GcpKey } from "./provider";
|
||||
import { getAxiosInstance } from "../../network";
|
||||
import { logger } from "../../../logger";
|
||||
|
||||
const axios = getAxiosInstance();
|
||||
const log = logger.child({ module: "gcp-oauth" });
|
||||
|
||||
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const scope = "https://www.googleapis.com/auth/cloud-platform";
|
||||
|
||||
type GoogleAuthResponse = {
|
||||
access_token: string;
|
||||
scope: string;
|
||||
token_type: "Bearer";
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
type GoogleAuthError = {
|
||||
error:
|
||||
| "unauthorized_client"
|
||||
| "access_denied"
|
||||
| "admin_policy_enforced"
|
||||
| "invalid_client"
|
||||
| "invalid_grant"
|
||||
| "invalid_scope"
|
||||
| "disabled_client"
|
||||
| "org_internal";
|
||||
error_description: string;
|
||||
};
|
||||
|
||||
export async function refreshGcpAccessToken(
|
||||
key: GcpKey
|
||||
): Promise<[string, number]> {
|
||||
log.info({ key: key.hash }, "Entering GCP OAuth flow...");
|
||||
const { clientEmail, privateKey } = await getCredentialsFromGcpKey(key);
|
||||
|
||||
// https://developers.google.com/identity/protocols/oauth2/service-account#authorizingrequests
|
||||
const jwt = await createSignedJWT(clientEmail, privateKey);
|
||||
log.info({ key: key.hash }, "Signed JWT, exchanging for access token...");
|
||||
const res = await axios.post<GoogleAuthResponse | GoogleAuthError>(
|
||||
authUrl,
|
||||
{
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
assertion: jwt,
|
||||
},
|
||||
{
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
validateStatus: () => true,
|
||||
}
|
||||
);
|
||||
const status = res.status;
|
||||
const headers = res.headers;
|
||||
const data = res.data;
|
||||
|
||||
if ("error" in data || status >= 400) {
|
||||
log.error(
|
||||
{ key: key.hash, status, headers, data },
|
||||
"Error from Google Identity API while getting access token."
|
||||
);
|
||||
throw new Error(
|
||||
`Google Identity API returned error: ${(data as GoogleAuthError).error}`
|
||||
);
|
||||
}
|
||||
|
||||
log.info({ key: key.hash, exp: data.expires_in }, "Got access token.");
|
||||
return [data.access_token, data.expires_in];
|
||||
}
|
||||
|
||||
export async function getCredentialsFromGcpKey(key: GcpKey) {
|
||||
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
|
||||
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
|
||||
log.error(
|
||||
{ key: key.hash },
|
||||
"Cannot parse GCP credentials. Ensure they are in the format PROJECT_ID:CLIENT_EMAIL:REGION:PRIVATE_KEY, and ensure no whitespace or newlines are in the private key."
|
||||
);
|
||||
throw new Error("Cannot parse GCP credentials.");
|
||||
}
|
||||
|
||||
if (!key.privateKey) {
|
||||
await importPrivateKey(key, rawPrivateKey);
|
||||
}
|
||||
|
||||
return { projectId, clientEmail, region, privateKey: key.privateKey! };
|
||||
}
|
||||
|
||||
async function createSignedJWT(
|
||||
email: string,
|
||||
pkey: crypto.webcrypto.CryptoKey
|
||||
) {
|
||||
const issued = Math.floor(Date.now() / 1000);
|
||||
const expires = issued + 600;
|
||||
|
||||
const header = { alg: "RS256", typ: "JWT" };
|
||||
|
||||
const payload = {
|
||||
iss: email,
|
||||
aud: authUrl,
|
||||
iat: issued,
|
||||
exp: expires,
|
||||
scope,
|
||||
};
|
||||
|
||||
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
|
||||
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
|
||||
|
||||
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||
|
||||
const signature = await crypto.subtle.sign(
|
||||
"RSASSA-PKCS1-v1_5",
|
||||
pkey,
|
||||
new TextEncoder().encode(unsignedToken)
|
||||
);
|
||||
|
||||
const encodedSignature = urlSafeBase64Encode(signature);
|
||||
return `${unsignedToken}.${encodedSignature}`;
|
||||
}
|
||||
|
||||
async function importPrivateKey(key: GcpKey, rawPrivateKey: string) {
|
||||
log.info({ key: key.hash }, "Importing GCP private key...");
|
||||
const privateKey = rawPrivateKey
|
||||
.replace(
|
||||
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
|
||||
""
|
||||
)
|
||||
.trim();
|
||||
const binaryKey = Buffer.from(privateKey, "base64");
|
||||
key.privateKey = await crypto.subtle.importKey(
|
||||
"pkcs8",
|
||||
binaryKey,
|
||||
{ name: "RSASSA-PKCS1-v1_5", hash: "SHA-256" },
|
||||
true,
|
||||
["sign"]
|
||||
);
|
||||
log.info({ key: key.hash }, "GCP private key imported.");
|
||||
}
|
||||
|
||||
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||
let base64: string;
|
||||
if (typeof data === "string") {
|
||||
base64 = btoa(
|
||||
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
|
||||
String.fromCharCode(parseInt("0x" + p1, 16))
|
||||
)
|
||||
);
|
||||
} else {
|
||||
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||
}
|
||||
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||
}
|
|
@ -17,6 +17,11 @@ export interface GcpKey extends Key, GcpKeyUsage {
|
|||
sonnetEnabled: boolean;
|
||||
haikuEnabled: boolean;
|
||||
sonnet35Enabled: boolean;
|
||||
|
||||
privateKey?: crypto.webcrypto.CryptoKey;
|
||||
/** Cached access token for GCP APIs. */
|
||||
accessToken: string;
|
||||
accessTokenExpiresAt: number;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -68,6 +73,8 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
|
|||
sonnetEnabled: true,
|
||||
haikuEnabled: false,
|
||||
sonnet35Enabled: false,
|
||||
accessToken: "",
|
||||
accessTokenExpiresAt: 0,
|
||||
["gcp-claudeTokens"]: 0,
|
||||
["gcp-claude-opusTokens"]: 0,
|
||||
};
|
||||
|
|
|
@ -8,13 +8,13 @@ import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models";
|
|||
import { Key, KeyProvider } from "./index";
|
||||
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||
import { GcpKeyProvider } from "./gcp/provider";
|
||||
import { GcpKeyProvider, GcpKey } from "./gcp/provider";
|
||||
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
||||
import { MistralAIKeyProvider } from "./mistral-ai/provider";
|
||||
|
||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate | Partial<GcpKey>;
|
||||
|
||||
export class KeyPool {
|
||||
private keyProviders: KeyProvider[] = [];
|
||||
|
|
Loading…
Reference in New Issue