diff --git a/src/shared/key-management/gcp/checker.ts b/src/shared/key-management/gcp/checker.ts index 5995065..67d5b26 100644 --- a/src/shared/key-management/gcp/checker.ts +++ b/src/shared/key-management/gcp/checker.ts @@ -6,10 +6,12 @@ import { GcpModelFamily } from "../../models"; const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes -const GCP_HOST = - process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com"; +const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com"; const POST_STREAM_RAW_URL = (project: string, region: string, model: string) => - `https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`; + `https://${GCP_HOST.replace( + "%REGION%", + region + )}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`; const TEST_MESSAGES = [ { role: "user", content: "Hi!" }, { role: "assistant", content: "Hello!" }, @@ -38,9 +40,8 @@ export class GcpKeyChecker extends KeyCheckerBase { this.invokeModel("claude-3-5-sonnet@20240620", key, true), ]; - const [sonnet, haiku, opus, sonnet35] = - await Promise.all(checks); - + const [sonnet, haiku, opus, sonnet35] = await Promise.all(checks); + this.log.debug( { key: key.hash, sonnet, haiku, opus, sonnet35 }, "GCP model initial tests complete." @@ -66,20 +67,17 @@ export class GcpKeyChecker extends KeyCheckerBase { }); } else { if (key.haikuEnabled) { - await this.invokeModel("claude-3-haiku@20240307", key, false) + await this.invokeModel("claude-3-haiku@20240307", key, false); } else if (key.sonnetEnabled) { - await this.invokeModel("claude-3-sonnet@20240229", key, false) + await this.invokeModel("claude-3-sonnet@20240229", key, false); } else if (key.sonnet35Enabled) { - await this.invokeModel("claude-3-5-sonnet@20240620", key, false) + await this.invokeModel("claude-3-5-sonnet@20240620", key, false); } else { - await this.invokeModel("claude-3-opus@20240229", key, false) + await this.invokeModel("claude-3-opus@20240229", key, false); } this.updateKey(key.hash, { lastChecked: Date.now() }); - this.log.debug( - { key: key.hash}, - "GCP key check complete." - ); + this.log.debug({ key: key.hash }, "GCP key check complete."); } this.log.info( @@ -134,8 +132,12 @@ export class GcpKeyChecker extends KeyCheckerBase { */ private async invokeModel(model: string, key: GcpKey, initial: boolean) { const creds = GcpKeyChecker.getCredentialsFromKey(key); - const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey) - const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT) + const signedJWT = await GcpKeyChecker.createSignedJWT( + creds.clientEmail, + creds.privateKey + ); + const [accessToken, jwtError] = + await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT); if (accessToken === null) { this.log.warn( { key: key.hash, jwtError }, @@ -151,15 +153,19 @@ export class GcpKeyChecker extends KeyCheckerBase { const { data, status } = await axios.post( POST_STREAM_RAW_URL(creds.projectId, creds.region, model), payload, - { + { headers: GcpKeyChecker.getRequestHeaders(accessToken), - validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300 + validateStatus: initial + ? () => true + : (status: number) => status >= 200 && status < 300, } ); this.log.debug({ key: key.hash, data }, "Response from GCP"); if (initial) { - return (status >= 200 && status < 300) || (status === 429 || status === 529); + return ( + (status >= 200 && status < 300) || status === 429 || status === 529 + ); } return true; @@ -203,8 +209,12 @@ export class GcpKeyChecker extends KeyCheckerBase { scope: "https://www.googleapis.com/auth/cloud-platform", }; - const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header)); - const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload)); + const encodedHeader = GcpKeyChecker.urlSafeBase64Encode( + JSON.stringify(header) + ); + const encodedPayload = GcpKeyChecker.urlSafeBase64Encode( + JSON.stringify(payload) + ); const unsignedToken = `${encodedHeader}.${encodedPayload}`; @@ -218,7 +228,9 @@ export class GcpKeyChecker extends KeyCheckerBase { return `${unsignedToken}.${encodedSignature}`; } - static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> { + static async exchangeJwtForAccessToken( + signed_jwt: string + ): Promise<[string | null, string]> { const auth_url = "https://www.googleapis.com/oauth2/v4/token"; const params = { grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", @@ -252,7 +264,11 @@ export class GcpKeyChecker extends KeyCheckerBase { static urlSafeBase64Encode(data: string | ArrayBuffer): string { let base64: string; if (typeof data === "string") { - base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16)))); + base64 = btoa( + encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => + String.fromCharCode(parseInt("0x" + p1, 16)) + ) + ); } else { base64 = btoa(String.fromCharCode(...new Uint8Array(data))); } @@ -260,7 +276,10 @@ export class GcpKeyChecker extends KeyCheckerBase { } static getRequestHeaders(accessToken: string) { - return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" }; + return { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }; } static getCredentialsFromKey(key: GcpKey) { @@ -269,9 +288,12 @@ export class GcpKeyChecker extends KeyCheckerBase { throw new Error("Invalid GCP key"); } const privateKey = rawPrivateKey - .replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '') + .replace( + /-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, + "" + ) .trim(); - + return { projectId, clientEmail, region, privateKey }; } }