oai-reverse-proxy/src/shared/key-management/aws/checker.ts

255 lines
9.4 KiB
TypeScript

import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
import { URL } from "url";
import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const AMZ_HOST =
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
const TEST_PROMPT = "\n\nHuman:\n\nAssistant:";
type AwsError = { error: {} };
type GetLoggingConfigResponse = {
loggingConfig: null | {
cloudWatchConfig: null | unknown;
s3Config: null | unknown;
embeddingDataDeliveryEnabled: boolean;
imageDataDeliveryEnabled: boolean;
textDataDeliveryEnabled: boolean;
};
};
type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;
export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
super(keys, {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
updateKey,
});
}
protected async testKeyOrFail(key: AwsBedrockKey) {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
const modelChecks: Promise<unknown>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
modelChecks.push(this.invokeModel("anthropic.claude-v2:1", key));
}
await Promise.all(modelChecks);
await this.checkLoggingConfiguration(key);
this.log.info(
{
key: key.hash,
models: key.modelFamilies,
logged: key.awsLoggingStatus,
},
"Checked key."
);
}
protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {
if (error.response && AwsKeyChecker.errorIsAwsError(error)) {
const errorHeader = error.response.headers["x-amzn-errortype"] as string;
const errorType = errorHeader.split(":")[0];
switch (errorType) {
case "AccessDeniedException":
// Indicates that the principal's attached policy does not allow them
// to perform the requested action.
// How we handle this depends on whether the action was one that we
// must be able to perform in order to use the key.
const path = new URL(error.config?.url!).pathname;
const data = error.response.data;
this.log.warn(
{ key: key.hash, type: errorType, path, data },
"Key can't perform a required action; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
case "UnrecognizedClientException":
// This is a 403 error that indicates the key is revoked.
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is revoked; disabling."
);
return this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
});
case "ThrottlingException":
// This is a 429 error that indicates the key is rate-limited, but
// not necessarily disabled. Retry in 10 seconds.
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is rate limited. Rechecking in 10 seconds."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
return this.updateKey(key.hash, { lastChecked: next });
case "ValidationException":
default:
// This indicates some issue that we did not account for, possibly
// a new ValidationException type. This likely means our key checker
// needs to be updated so we'll just let the key through and let it
// fail when someone tries to use it if the error is fatal.
this.log.error(
{ key: key.hash, errorType, error: error.response.data },
"Encountered unexpected error while checking key. This may indicate a change in the API; please report this."
);
return this.updateKey(key.hash, { lastChecked: Date.now() });
}
}
const { response } = error;
const { headers, status, data } = response ?? {};
this.log.error(
{ key: key.hash, status, headers, data, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
private async invokeModel(model: string, key: AwsBedrockKey) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
// This is not a valid invocation payload, but a 400 response indicates that
// the principal at least has permission to invoke the model.
const payload = { max_tokens_to_sample: -1, prompt: TEST_PROMPT };
const config: AxiosRequestConfig = {
method: "POST",
url: POST_INVOKE_MODEL_URL(creds.region, model),
data: payload,
validateStatus: (status) => status === 400,
};
config.headers = new AxiosHeaders({
"content-type": "application/json",
accept: "*/*",
});
await AwsKeyChecker.signRequestForAws(config, key);
const response = await axios.request(config);
const { data, status, headers } = response;
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
const errorMessage = data?.message;
// We're looking for a specific error type and message here
// "ValidationException"
const correctErrorType = errorType === "ValidationException";
const correctErrorMessage = errorMessage?.match(/max_tokens_to_sample/);
if (!correctErrorType || !correctErrorMessage) {
throw new AxiosError(
`Unexpected error when invoking model ${model}: ${errorMessage}`,
"AWS_ERROR",
response.config,
response.request,
response
);
}
this.log.debug(
{ key: key.hash, errorType, data, status, model },
"Liveness test complete."
);
}
private async checkLoggingConfiguration(key: AwsBedrockKey) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const config: AxiosRequestConfig = {
method: "GET",
url: GET_INVOCATION_LOGGING_CONFIG_URL(creds.region),
headers: { accept: "application/json" },
validateStatus: () => true,
};
await AwsKeyChecker.signRequestForAws(config, key);
const { data, status, headers } =
await axios.request<GetLoggingConfigResponse>(config);
let result: AwsBedrockKey["awsLoggingStatus"] = "unknown";
if (status === 200) {
const { loggingConfig } = data;
const loggingEnabled = !!loggingConfig?.textDataDeliveryEnabled;
this.log.debug(
{ key: key.hash, loggingConfig, loggingEnabled },
"AWS model invocation logging test complete."
);
result = loggingEnabled ? "enabled" : "disabled";
} else {
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
this.log.debug(
{ key: key.hash, errorType, data, status },
"Can't determine AWS model invocation logging status."
);
}
this.updateKey(key.hash, { awsLoggingStatus: result });
}
static errorIsAwsError(error: AxiosError): error is AxiosError<AwsError> {
const headers = error.response?.headers;
if (!headers) return false;
return !!headers["x-amzn-errortype"];
}
/** Given an Axios request, sign it with the given key. */
static async signRequestForAws(
axiosRequest: AxiosRequestConfig,
key: AwsBedrockKey,
awsService = "bedrock"
) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const { accessKeyId, secretAccessKey, region } = creds;
const { method, url: axUrl, headers: axHeaders, data } = axiosRequest;
const url = new URL(axUrl!);
let plainHeaders = {};
if (axHeaders instanceof AxiosHeaders) {
plainHeaders = axHeaders.toJSON();
} else if (typeof axHeaders === "object") {
plainHeaders = axHeaders;
}
const request = new HttpRequest({
method,
protocol: "https:",
hostname: url.hostname,
path: url.pathname + url.search,
headers: { Host: url.hostname, ...plainHeaders },
});
if (data) {
request.body = JSON.stringify(data);
}
const signer = new SignatureV4({
sha256: Sha256,
credentials: { accessKeyId, secretAccessKey },
region,
service: awsService,
});
const signedRequest = await signer.sign(request);
axiosRequest.headers = signedRequest.headers;
}
static getCredentialsFromKey(key: AwsBedrockKey) {
const [accessKeyId, secretAccessKey, region] = key.key.split(":");
if (!accessKeyId || !secretAccessKey || !region) {
throw new Error("Invalid AWS Bedrock key");
}
return { accessKeyId, secretAccessKey, region };
}
}