applies prettier to GCP checker

This commit is contained in:
nai-degen 2024-08-29 15:15:56 -05:00
parent ee61f9be2b
commit cf615ee62c
1 changed files with 48 additions and 26 deletions

View File

@ -6,10 +6,12 @@ import { GcpModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const GCP_HOST = const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) => const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
`https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`; `https://${GCP_HOST.replace(
"%REGION%",
region
)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
const TEST_MESSAGES = [ const TEST_MESSAGES = [
{ role: "user", content: "Hi!" }, { role: "user", content: "Hi!" },
{ role: "assistant", content: "Hello!" }, { role: "assistant", content: "Hello!" },
@ -38,9 +40,8 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
this.invokeModel("claude-3-5-sonnet@20240620", key, true), this.invokeModel("claude-3-5-sonnet@20240620", key, true),
]; ];
const [sonnet, haiku, opus, sonnet35] = const [sonnet, haiku, opus, sonnet35] = await Promise.all(checks);
await Promise.all(checks);
this.log.debug( this.log.debug(
{ key: key.hash, sonnet, haiku, opus, sonnet35 }, { key: key.hash, sonnet, haiku, opus, sonnet35 },
"GCP model initial tests complete." "GCP model initial tests complete."
@ -66,20 +67,17 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
}); });
} else { } else {
if (key.haikuEnabled) { if (key.haikuEnabled) {
await this.invokeModel("claude-3-haiku@20240307", key, false) await this.invokeModel("claude-3-haiku@20240307", key, false);
} else if (key.sonnetEnabled) { } else if (key.sonnetEnabled) {
await this.invokeModel("claude-3-sonnet@20240229", key, false) await this.invokeModel("claude-3-sonnet@20240229", key, false);
} else if (key.sonnet35Enabled) { } else if (key.sonnet35Enabled) {
await this.invokeModel("claude-3-5-sonnet@20240620", key, false) await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
} else { } else {
await this.invokeModel("claude-3-opus@20240229", key, false) await this.invokeModel("claude-3-opus@20240229", key, false);
} }
this.updateKey(key.hash, { lastChecked: Date.now() }); this.updateKey(key.hash, { lastChecked: Date.now() });
this.log.debug( this.log.debug({ key: key.hash }, "GCP key check complete.");
{ key: key.hash},
"GCP key check complete."
);
} }
this.log.info( this.log.info(
@ -134,8 +132,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
*/ */
private async invokeModel(model: string, key: GcpKey, initial: boolean) { private async invokeModel(model: string, key: GcpKey, initial: boolean) {
const creds = GcpKeyChecker.getCredentialsFromKey(key); const creds = GcpKeyChecker.getCredentialsFromKey(key);
const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey) const signedJWT = await GcpKeyChecker.createSignedJWT(
const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT) creds.clientEmail,
creds.privateKey
);
const [accessToken, jwtError] =
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
if (accessToken === null) { if (accessToken === null) {
this.log.warn( this.log.warn(
{ key: key.hash, jwtError }, { key: key.hash, jwtError },
@ -151,15 +153,19 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
const { data, status } = await axios.post( const { data, status } = await axios.post(
POST_STREAM_RAW_URL(creds.projectId, creds.region, model), POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
payload, payload,
{ {
headers: GcpKeyChecker.getRequestHeaders(accessToken), headers: GcpKeyChecker.getRequestHeaders(accessToken),
validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300 validateStatus: initial
? () => true
: (status: number) => status >= 200 && status < 300,
} }
); );
this.log.debug({ key: key.hash, data }, "Response from GCP"); this.log.debug({ key: key.hash, data }, "Response from GCP");
if (initial) { if (initial) {
return (status >= 200 && status < 300) || (status === 429 || status === 529); return (
(status >= 200 && status < 300) || status === 429 || status === 529
);
} }
return true; return true;
@ -203,8 +209,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
scope: "https://www.googleapis.com/auth/cloud-platform", scope: "https://www.googleapis.com/auth/cloud-platform",
}; };
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header)); const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload)); JSON.stringify(header)
);
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(payload)
);
const unsignedToken = `${encodedHeader}.${encodedPayload}`; const unsignedToken = `${encodedHeader}.${encodedPayload}`;
@ -218,7 +228,9 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
return `${unsignedToken}.${encodedSignature}`; return `${unsignedToken}.${encodedSignature}`;
} }
static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> { static async exchangeJwtForAccessToken(
signed_jwt: string
): Promise<[string | null, string]> {
const auth_url = "https://www.googleapis.com/oauth2/v4/token"; const auth_url = "https://www.googleapis.com/oauth2/v4/token";
const params = { const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
@ -252,7 +264,11 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
static urlSafeBase64Encode(data: string | ArrayBuffer): string { static urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string; let base64: string;
if (typeof data === "string") { if (typeof data === "string") {
base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16)))); base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else { } else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data))); base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
} }
@ -260,7 +276,10 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
} }
static getRequestHeaders(accessToken: string) { static getRequestHeaders(accessToken: string) {
return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" }; return {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
};
} }
static getCredentialsFromKey(key: GcpKey) { static getCredentialsFromKey(key: GcpKey) {
@ -269,9 +288,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
throw new Error("Invalid GCP key"); throw new Error("Invalid GCP key");
} }
const privateKey = rawPrivateKey const privateKey = rawPrivateKey
.replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '') .replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim(); .trim();
return { projectId, clientEmail, region, privateKey }; return { projectId, clientEmail, region, privateKey };
} }
} }