Compare commits

...

5 Commits

Author SHA1 Message Date
nai-degen 3f9fd25004 exempt 'special' token type from context size limits 2024-03-19 11:14:51 -05:00
nai-degen e068edcf48 adds Anthropic key tier detection and trial key display 2024-03-18 15:20:34 -05:00
nai-degen 2098948b7a reduces Anthropic keychecker frequency 2024-03-18 15:19:41 -05:00
nai-degen 7705ee58a0 minor cleanup of error-generator.ts 2024-03-18 15:18:18 -05:00
nai-degen 7c64d9209e minor refactoring of response middleware handlers 2024-03-17 22:20:39 -05:00
18 changed files with 285 additions and 192 deletions

View File

@ -3,10 +3,10 @@ import http from "http";
import httpProxy from "http-proxy";
import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error";
import { HttpError } from "../../shared/errors";
import { assertNever } from "../../shared/utils";
import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
import { sendErrorToClient } from "./response/error-generator";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";

View File

@ -31,7 +31,10 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
}
case "anthropic-chat": {
req.outputTokens = req.body.max_tokens;
const prompt: AnthropicChatMessage[] = req.body.messages;
const prompt = {
system: req.body.system ?? "",
messages: req.body.messages,
};
result = await countTokens({ req, prompt, service });
break;
}

View File

@ -46,6 +46,11 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
}
proxyMax ||= Number.MAX_SAFE_INTEGER;
if (req.user?.type === "special") {
req.log.debug("Special user, not enforcing proxy context limit.");
proxyMax = Number.MAX_SAFE_INTEGER;
}
let modelMax: number;
if (model.match(/gpt-3.5-turbo-16k/)) {
modelMax = 16384;

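The net effect, as a minimal sketch: assuming this preprocessor later clamps each request to the lesser of the proxy-wide and per-model limits (that clamp sits outside this hunk and is an assumption here), a "special" user is bounded only by the model's own context window.

// Hypothetical downstream clamp, not shown in this hunk:
const finalMax = Math.min(proxyMax, modelMax);
// For a "special" user on gpt-3.5-turbo-16k:
// Math.min(Number.MAX_SAFE_INTEGER, 16384) === 16384, i.e. only the model limit applies.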
View File

@ -31,17 +31,24 @@ function getMessageContent({
}
```
*/
const note = obj?.proxy_note || obj?.error?.message || "";
const header = `**${title}**`;
const friendlyMessage = note ? `${message}\n\n***\n\n*${note}*` : message;
const details = JSON.parse(JSON.stringify(obj ?? {}));
let stack = "";
if (details.stack) {
stack = `\n\nInclude this trace when reporting an issue.\n\`\`\`\n${details.stack}\n\`\`\``;
delete details.stack;
const serializedObj = obj ? "```" + JSON.stringify(obj, null, 2) + "```" : "";
const { stack } = JSON.parse(JSON.stringify(obj ?? {}));
let prettyTrace = "";
if (stack && obj) {
prettyTrace = [
"Include this trace when reporting an issue.",
"```",
stack,
"```",
].join("\n");
delete obj.stack;
}
return `\n\n**${title}**\n${friendlyMessage}${
obj ? `\n\`\`\`\n${JSON.stringify(obj, null, 2)}\n\`\`\`\n${stack}` : ""
}`;
return [header, friendlyMessage, serializedObj, prettyTrace].join("\n\n");
}
type ErrorGeneratorOptions = {

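For illustration: with hypothetical inputs title = "Proxy Error", message = "Request failed", and obj = { proxy_note: "Try again later." } (no stack), the four joined parts would come out roughly as:

**Proxy Error**

Request failed

***

*Try again later.*

```{
  "proxy_note": "Try again later."
}```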
View File

@ -0,0 +1,76 @@
import util from "util";
import zlib from "zlib";
import { sendProxyError } from "../common";
import type { RawResponseBodyHandler } from "./index";
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
deflate: util.promisify(zlib.inflate),
br: util.promisify(zlib.brotliDecompress),
};
const isSupportedContentEncoding = (
contentEncoding: string
): contentEncoding is keyof typeof DECODER_MAP => {
return contentEncoding in DECODER_MAP;
};
/**
* Handles the response from the upstream service and decodes the body if
* necessary. If the response is JSON, it will be parsed and returned as an
* object. Otherwise, it will be returned as a string. Does not handle streaming
* responses.
* @throws {Error} Unsupported content-encoding or invalid application/json body
*/
export const handleBlockingResponse: RawResponseBodyHandler = async (
proxyRes,
req,
res
) => {
if (req.isStreaming) {
const err = new Error(
"handleBlockingResponse called for a streaming request."
);
req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
throw err;
}
return new Promise<string>((resolve, reject) => {
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
proxyRes.on("end", async () => {
let body = Buffer.concat(chunks);
const contentEncoding = proxyRes.headers["content-encoding"];
if (contentEncoding) {
if (isSupportedContentEncoding(contentEncoding)) {
const decoder = DECODER_MAP[contentEncoding];
// @ts-ignore - started failing after upgrading TypeScript, don't care
// as it was never a problem.
body = await decoder(body);
} else {
const error = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
req.log.warn({ contentEncoding, key: req.key?.hash }, error);
sendProxyError(req, res, 500, "Internal Server Error", {
error,
contentEncoding,
});
return reject(error);
}
}
try {
if (proxyRes.headers["content-type"]?.includes("application/json")) {
const json = JSON.parse(body.toString());
return resolve(json);
}
return resolve(body.toString());
} catch (e) {
const msg = `Proxy received response with invalid JSON: ${e.message}`;
req.log.warn({ error: e.stack, key: req.key?.hash }, msg);
sendProxyError(req, res, 500, "Internal Server Error", { error: msg });
return reject(msg);
}
});
});
};
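A quick sanity-check illustration of the DECODER_MAP dispatch above (hypothetical input, not part of the file):

// Round-trip: gzip a body, then decode it with the promisified gunzip.
const decoded = await DECODER_MAP["gzip"](zlib.gzipSync('{"ok":true}'));
// decoded.toString() === '{"ok":true}'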

View File

@ -3,20 +3,21 @@ import { pipeline, Readable, Transform } from "stream";
import StreamArray from "stream-json/streamers/StreamArray";
import { StringDecoder } from "string_decoder";
import { promisify } from "util";
import type { logger } from "../../../logger";
import { BadRequestError, RetryableError } from "../../../shared/errors";
import { APIFormat, keyPool } from "../../../shared/key-management";
import {
copySseResponseHeaders,
initializeSseStream,
} from "../../../shared/streaming";
import { enqueue } from "../../queue";
import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
import { reenqueueRequest } from "../../queue";
import type { RawResponseBodyHandler } from ".";
import { handleBlockingResponse } from "./handle-blocking-response";
import { buildSpoofedSSE, sendErrorToClient } from "./error-generator";
import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder";
import { EventAggregator } from "./streaming/event-aggregator";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { BadRequestError } from "../../../shared/errors";
const pipelineAsync = promisify(pipeline);
@ -50,7 +51,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
{ statusCode: proxyRes.statusCode, key: hash },
`Streaming request returned error status code. Falling back to non-streaming response handler.`
);
return decodeResponseBody(proxyRes, req, res);
return handleBlockingResponse(proxyRes, req, res);
}
req.log.debug({ headers: proxyRes.headers }, `Starting to proxy SSE stream.`);
@ -105,12 +106,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
} catch (err) {
if (err instanceof RetryableError) {
keyPool.markRateLimited(req.key!);
req.log.warn(
{ key: req.key!.hash, retryCount: req.retryCount },
`Re-enqueueing request due to retryable error during streaming response.`
);
req.retryCount++;
await enqueue(req);
await reenqueueRequest(req);
} else if (err instanceof BadRequestError) {
sendErrorToClient({
req,
@ -138,7 +134,17 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
res.write(`data: [DONE]\n\n`);
res.end();
}
throw err;
// At this point the response is closed. If the request resulted in any
// tokens being consumed (suggesting a mid-stream error), we will resolve
// and continue the middleware chain so tokens can be counted.
if (aggregator.hasEvents()) {
return aggregator.getFinalResponse();
} else {
// If there is nothing, then this was a completely failed prompt that
// will not have billed any tokens. Throw to stop the middleware chain.
throw err;
}
}
};

View File

@ -1,10 +1,8 @@
/* This file is fucking horrendous, sorry */
import { Request, Response } from "express";
import * as http from "http";
import util from "util";
import zlib from "zlib";
import { enqueue, trackWaitTime } from "../../queue";
import { HttpError } from "../../../shared/errors";
import { config } from "../../../config";
import { HttpError, RetryableError } from "../../../shared/errors";
import { keyPool } from "../../../shared/key-management";
import { getOpenAIModelFamily } from "../../../shared/models";
import { countTokens } from "../../../shared/tokenization";
@ -13,6 +11,7 @@ import {
incrementTokenCount,
} from "../../../shared/users/user-store";
import { assertNever } from "../../../shared/utils";
import { reenqueueRequest, trackWaitTime } from "../../queue";
import { refundLastAttempt } from "../../rate-limit";
import {
getCompletionFromBody,
@ -20,39 +19,22 @@ import {
isTextGenerationRequest,
sendProxyError,
} from "../common";
import { handleBlockingResponse } from "./handle-blocking-response";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
import { saveImage } from "./save-image";
import { config } from "../../../config";
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
deflate: util.promisify(zlib.inflate),
br: util.promisify(zlib.brotliDecompress),
};
const isSupportedContentEncoding = (
contentEncoding: string
): contentEncoding is keyof typeof DECODER_MAP => {
return contentEncoding in DECODER_MAP;
};
export class RetryableError extends Error {
constructor(message: string) {
super(message);
this.name = "RetryableError";
}
}
/**
* Either decodes or streams the entire response body and then passes it as the
* last argument to the rest of the middleware stack.
* Either decodes or streams the entire response body and then resolves with it.
* @returns The response body as a string or parsed JSON object depending on the
* response's content-type.
*/
export type RawResponseBodyHandler = (
proxyRes: http.IncomingMessage,
req: Request,
res: Response
) => Promise<string | Record<string, any>>;
export type ProxyResHandlerWithBody = (
proxyRes: http.IncomingMessage,
req: Request,
@ -76,6 +58,10 @@ export type ProxyResMiddleware = ProxyResHandlerWithBody[];
* middleware from executing as it consumes the stream and forwards events to
* the client. Once the stream is closed, the finalized body will be attached
* to res.body and the remaining middleware will execute.
*
* @param apiMiddleware - Custom middleware to execute after the common response
* handlers. These *only* execute for non-streaming responses, so should be used
* to transform non-streaming responses into the desired format.
*/
export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
return async (
@ -83,30 +69,27 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
req: Request,
res: Response
) => {
const initialHandler = req.isStreaming
const initialHandler: RawResponseBodyHandler = req.isStreaming
? handleStreamedResponse
: decodeResponseBody;
: handleBlockingResponse;
let lastMiddleware = initialHandler.name;
try {
const body = await initialHandler(proxyRes, req, res);
const middlewareStack: ProxyResMiddleware = [];
if (req.isStreaming) {
// `handleStreamedResponse` writes to the response and ends it, so
// we can only execute middleware that doesn't write to the response.
// Handlers for streaming requests must never write to the response.
middlewareStack.push(
trackRateLimit,
trackKeyRateLimit,
countResponseTokens,
incrementUsage,
logPrompt
);
} else {
middlewareStack.push(
trackRateLimit,
addProxyInfo,
trackKeyRateLimit,
injectProxyInfo,
handleUpstreamErrors,
countResponseTokens,
incrementUsage,
@ -154,72 +137,6 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
};
};
async function reenqueueRequest(req: Request) {
req.log.info(
{ key: req.key?.hash, retryCount: req.retryCount },
`Re-enqueueing request due to retryable error`
);
req.retryCount++;
await enqueue(req);
}
/**
* Handles the response from the upstream service and decodes the body if
* necessary. If the response is JSON, it will be parsed and returned as an
* object. Otherwise, it will be returned as a string.
* @throws {Error} Unsupported content-encoding or invalid application/json body
*/
export const decodeResponseBody: RawResponseBodyHandler = async (
proxyRes,
req,
res
) => {
if (req.isStreaming) {
const err = new Error("decodeResponseBody called for a streaming request.");
req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
throw err;
}
return new Promise<string>((resolve, reject) => {
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
proxyRes.on("end", async () => {
let body = Buffer.concat(chunks);
const contentEncoding = proxyRes.headers["content-encoding"];
if (contentEncoding) {
if (isSupportedContentEncoding(contentEncoding)) {
const decoder = DECODER_MAP[contentEncoding];
// @ts-ignore - started failing after upgrading TypeScript, don't care
// as it was never a problem.
body = await decoder(body);
} else {
const error = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
req.log.warn({ contentEncoding, key: req.key?.hash }, error);
sendProxyError(req, res, 500, "Internal Server Error", {
error,
contentEncoding,
});
return reject(error);
}
}
try {
if (proxyRes.headers["content-type"]?.includes("application/json")) {
const json = JSON.parse(body.toString());
return resolve(json);
}
return resolve(body.toString());
} catch (e) {
const msg = `Proxy received response with invalid JSON: ${e.message}`;
req.log.warn({ error: e.stack, key: req.key?.hash }, msg);
sendProxyError(req, res, 500, "Internal Server Error", { error: msg });
return reject(msg);
}
});
});
};
type ProxiedErrorPayload = {
error?: Record<string, any>;
message?: string;
@ -242,15 +159,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
) => {
const statusCode = proxyRes.statusCode || 500;
const statusMessage = proxyRes.statusMessage || "Internal Server Error";
if (statusCode < 400) {
return;
}
let errorPayload: ProxiedErrorPayload;
const tryAgainMessage = keyPool.available(req.body?.model)
? `There may be more keys available for this model; try again in a few seconds.`
: "There are no more keys available for this model.";
if (statusCode < 400) return;
try {
assertJsonResponse(body);
@ -303,7 +214,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (errorPayload.error?.code === "billing_hard_limit_reached") {
// For some reason, some models return this 400 error instead of the
// same 429 billing error that other models return.
await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
await handleOpenAIRateLimitError(req, errorPayload);
} else {
errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
}
@ -318,18 +229,18 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (statusCode === 401) {
// Key is invalid or was revoked
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
} else if (statusCode === 403) {
if (service === "anthropic") {
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
return;
}
switch (errorType) {
case "UnrecognizedClientException":
// Key is invalid.
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
break;
case "AccessDeniedException":
const isModelAccessError =
@ -349,7 +260,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (statusCode === 429) {
switch (service) {
case "openai":
await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
await handleOpenAIRateLimitError(req, errorPayload);
break;
case "anthropic":
await handleAnthropicRateLimitError(req, errorPayload);
@ -459,7 +370,7 @@ async function handleAnthropicBadRequestError(
"Anthropic key has been disabled."
);
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been disabled. ${error?.message}`;
errorPayload.proxy_note = `Assigned key has been disabled. (${error?.message})`;
return;
}
@ -499,7 +410,6 @@ async function handleAwsRateLimitError(
async function handleOpenAIRateLimitError(
req: Request,
tryAgainMessage: string,
errorPayload: ProxiedErrorPayload
): Promise<Record<string, any>> {
const type = errorPayload.error?.type;
@ -508,17 +418,17 @@ async function handleOpenAIRateLimitError(
case "invalid_request_error": // this is the billing_hard_limit_reached error seen in some cases
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned key's quota has been exceeded. Please try again.`;
break;
case "access_terminated":
// Account banned (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. Please try again.`;
break;
case "billing_not_active":
// Key valid but account billing is delinquent
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. ${tryAgainMessage}`;
errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. Please try again.`;
break;
case "requests":
case "tokens":
@ -684,7 +594,7 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
}
};
const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
const trackKeyRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
keyPool.updateRateLimits(req.key!, proxyRes.headers);
};
@ -714,7 +624,7 @@ const copyHttpHeaders: ProxyResHandlerWithBody = async (
* or transformed.
* Only used for non-streaming requests.
*/
const addProxyInfo: ProxyResHandlerWithBody = async (
const injectProxyInfo: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,

View File

@ -67,6 +67,10 @@ export class EventAggregator {
assertNever(this.format);
}
}
hasEvents() {
return this.events.length > 0;
}
}
function eventIsOpenAIEvent(

View File

@ -2,9 +2,8 @@ import pino from "pino";
import { Transform, TransformOptions } from "stream";
import { Message } from "@smithy/eventstream-codec";
import { APIFormat } from "../../../../shared/key-management";
import { RetryableError } from "../index";
import { buildSpoofedSSE } from "../error-generator";
import { BadRequestError } from "../../../../shared/errors";
import { BadRequestError, RetryableError } from "../../../../shared/errors";
type SSEStreamAdapterOptions = TransformOptions & {
contentType?: string;

View File

@ -70,7 +70,7 @@ export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
}
const handleModelRequest: RequestHandler = (_req, res) => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60){
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return res.status(200).json(modelsCache);
}
const result = generateModelList();

View File

@ -12,7 +12,7 @@
*/
import crypto from "crypto";
import type { Handler, Request } from "express";
import { Handler, Request } from "express";
import { BadRequestError, TooManyRequestsError } from "../shared/errors";
import { keyPool } from "../shared/key-management";
import {
@ -67,7 +67,7 @@ const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
export async function enqueue(req: Request) {
async function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
let isGuest = req.user?.token === undefined;
@ -136,6 +136,15 @@ export async function enqueue(req: Request) {
}
}
export async function reenqueueRequest(req: Request) {
req.log.info(
{ key: req.key?.hash, retryCount: req.retryCount },
`Re-enqueueing request due to retryable error`
);
req.retryCount++;
await enqueue(req);
}
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue
.filter((req) => getModelFamilyForRequest(req) === partition)

View File

@ -80,6 +80,7 @@ type OpenAIInfo = BaseFamilyInfo & {
overQuotaKeys?: number;
};
type AnthropicInfo = BaseFamilyInfo & {
trialKeys?: number;
prefilledKeys?: number;
overQuotaKeys?: number;
};
@ -349,6 +350,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__trial`, k.tier === "free" ? 1 : 0);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
@ -437,6 +439,7 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
break;
case "anthropic":
info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
info.trialKeys = modelStats.get(`${family}__trial`) || 0;
info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0;
break;
case "aws":

View File

@ -34,3 +34,10 @@ export class TooManyRequestsError extends HttpError {
super(429, message);
}
}
export class RetryableError extends Error {
constructor(message: string) {
super(message);
this.name = "RetryableError";
}
}

View File

@ -1,9 +1,9 @@
import axios, { AxiosError } from "axios";
import axios, { AxiosError, AxiosResponse } from "axios";
import { KeyCheckerBase } from "../key-checker-base";
import type { AnthropicKey, AnthropicKeyProvider } from "./provider";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const KEY_CHECK_PERIOD = 1000 * 60 * 60 * 6; // 6 hours
const POST_MESSAGES_URL = "https://api.anthropic.com/v1/messages";
const TEST_MODEL = "claude-3-sonnet-20240229";
const SYSTEM = "Obey all instructions from the user.";
@ -52,10 +52,13 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
}
protected async testKeyOrFail(key: AnthropicKey) {
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
const updates = { isPozzed: pozzed };
const [{ pozzed, tier }] = await Promise.all([this.testLiveness(key)]);
const updates = { isPozzed: pozzed, tier };
this.updateKey(key.hash, updates);
this.log.info({ key: key.hash, models: key.modelFamilies }, "Checked key.");
this.log.info(
{ key: key.hash, tier, models: key.modelFamilies },
"Checked key."
);
}
protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
@ -124,7 +127,9 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
this.updateKey(key.hash, { lastChecked: next });
}
private async testLiveness(key: AnthropicKey): Promise<{ pozzed: boolean }> {
private async testLiveness(
key: AnthropicKey
): Promise<{ pozzed: boolean; tier: AnthropicKey["tier"] }> {
const payload = {
model: TEST_MODEL,
max_tokens: 40,
@ -133,24 +138,27 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
system: SYSTEM,
messages: DETECTION_PROMPT,
};
const { data } = await axios.post<MessageResponse>(
const { data, headers } = await axios.post<MessageResponse>(
POST_MESSAGES_URL,
payload,
{ headers: AnthropicKeyChecker.getHeaders(key) }
{ headers: AnthropicKeyChecker.getRequestHeaders(key) }
);
this.log.debug({ data }, "Response from Anthropic");
const tier = AnthropicKeyChecker.detectTier(headers);
const completion = data.content.map((part) => part.text).join("");
if (POZZ_PROMPT.some((re) => re.test(completion))) {
this.log.info({ key: key.hash, response: completion }, "Key is pozzed.");
return { pozzed: true };
return { pozzed: true, tier };
} else if (COPYRIGHT_PROMPT.some((re) => re.test(completion))) {
this.log.info(
{ key: key.hash, response: completion },
"Key has copyright CYA prompt."
);
return { pozzed: true };
return { pozzed: true, tier };
} else {
return { pozzed: false };
return { pozzed: false, tier };
}
}
@ -161,7 +169,19 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
return data?.error?.type;
}
static getHeaders(key: AnthropicKey) {
static getRequestHeaders(key: AnthropicKey) {
return { "X-API-Key": key.key, "anthropic-version": "2023-06-01" };
}
static detectTier(headers: AxiosResponse["headers"]) {
const tokensLimit = headers["anthropic-ratelimit-tokens-limit"];
const intTokensLimit = parseInt(tokensLimit, 10);
if (!tokensLimit || isNaN(intTokensLimit)) return "unknown";
if (intTokensLimit <= 25000) return "free";
if (intTokensLimit <= 50000) return "build_1";
if (intTokensLimit <= 100000) return "build_2";
if (intTokensLimit <= 200000) return "build_3";
if (intTokensLimit <= 400000) return "build_4";
return "scale";
}
}
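For illustration, hypothetical header values (not taken from the diff) and the tier the cutoffs above would yield:

// 80,000 sits between the 50k and 100k cutoffs => "build_2"
AnthropicKeyChecker.detectTier({ "anthropic-ratelimit-tokens-limit": "80000" });
// Header missing or unparseable => "unknown"
AnthropicKeyChecker.detectTier({});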

View File

@ -4,7 +4,7 @@ import { config } from "../../../config";
import { logger } from "../../../logger";
import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
import { AnthropicKeyChecker } from "./checker";
import { HttpError, PaymentRequiredError } from "../../errors";
import { PaymentRequiredError } from "../../errors";
export type AnthropicKeyUpdate = Omit<
Partial<AnthropicKey>,
@ -45,13 +45,39 @@ export interface AnthropicKey extends Key, AnthropicKeyUsage {
*/
isPozzed: boolean;
isOverQuota: boolean;
/**
* Key billing tier (https://docs.anthropic.com/claude/reference/rate-limits)
**/
tier: typeof TIER_PRIORITY[number];
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
* Selection priority for Anthropic keys. Aims to maximize throughput by
* saturating concurrency-limited keys first, then trying keys with increasingly
* strict rate limits. Free keys have very limited throughput and are used last.
*/
const RATE_LIMIT_LOCKOUT = 2000;
const TIER_PRIORITY = [
"unknown",
"scale",
"build_4",
"build_3",
"build_2",
"build_1",
"free",
] as const;
/**
* Upon being rate limited, a Scale-tier key will be locked out for this many
* milliseconds while we wait for other concurrent requests to finish.
*/
const SCALE_RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon being rate limited, a Build-tier key will be locked out for this many
* milliseconds while we wait for the per-minute rate limit to reset. Because
* the reset provided in the headers specifies the time for the full quota to
* become available, the key may become available before that time.
*/
const BUILD_RATE_LIMIT_LOCKOUT = 10000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
@ -98,6 +124,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
lastChecked: 0,
claudeTokens: 0,
"claude-opusTokens": 0,
tier: "unknown",
};
this.keys.push(newKey);
}
@ -123,25 +150,27 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
throw new PaymentRequiredError("No Anthropic keys available.");
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 2. Keys which are not pozzed
// 3. Keys which have not been used in the longest time
// 1. Keys which are not rate limit locked
// 2. Keys with the highest tier
// 3. Keys which are not pozzed
// 4. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const aLockoutPeriod = getKeyLockout(a);
const bLockoutPeriod = getKeyLockout(b);
const aRateLimited = now - a.rateLimitedAt < aLockoutPeriod;
const bRateLimited = now - b.rateLimitedAt < bLockoutPeriod;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
const aTierIndex = TIER_PRIORITY.indexOf(a.tier);
const bTierIndex = TIER_PRIORITY.indexOf(b.tier);
// Lower TIER_PRIORITY index means higher priority; try better tiers first.
if (aTierIndex !== bTierIndex) return aTierIndex - bTierIndex;
if (a.isPozzed && !b.isPozzed) return 1;
if (!a.isPozzed && b.isPozzed) return -1;
@ -207,7 +236,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
key.rateLimitedUntil = now + SCALE_RATE_LIMIT_LOCKOUT;
}
public recheck() {
@ -239,3 +268,9 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}
function getKeyLockout(key: AnthropicKey) {
return ["scale", "unknown"].includes(key.tier)
? SCALE_RATE_LIMIT_LOCKOUT
: BUILD_RATE_LIMIT_LOCKOUT;
}
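A worked example of the lockouts above, using illustrative keys rather than ones from the diff: a scale key rate-limited 1,000 ms ago is still inside its 2,000 ms lockout and sorts behind an idle free key, but once the lockout lapses its better TIER_PRIORITY rank (index 1 vs. free's index 6) makes it the first pick again.

// Assumed usage of the helper above, with partial keys for illustration:
getKeyLockout({ tier: "scale" } as AnthropicKey); // 2000 (SCALE_RATE_LIMIT_LOCKOUT)
getKeyLockout({ tier: "build_2" } as AnthropicKey); // 10000 (BUILD_RATE_LIMIT_LOCKOUT)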

View File

@ -13,15 +13,15 @@ type KeyCheckerOptions<TKey extends Key = Key> = {
export abstract class KeyCheckerBase<TKey extends Key> {
protected readonly service: string;
protected readonly RECURRING_CHECKS_ENABLED: boolean;
protected readonly recurringChecksEnabled: boolean;
/** Minimum time in between any two key checks. */
protected readonly MIN_CHECK_INTERVAL: number;
protected readonly minCheckInterval: number;
/**
* Minimum time in between checks for a given key. Because we can no longer
* read quota usage, there is little reason to check a single key more often
* than this.
*/
protected readonly KEY_CHECK_PERIOD: number;
protected readonly keyCheckPeriod: number;
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
@ -29,14 +29,13 @@ export abstract class KeyCheckerBase<TKey extends Key> {
protected lastCheck = 0;
protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
const { service, keyCheckPeriod, minCheckInterval } = opts;
this.keys = keys;
this.KEY_CHECK_PERIOD = keyCheckPeriod;
this.MIN_CHECK_INTERVAL = minCheckInterval;
this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
this.keyCheckPeriod = opts.keyCheckPeriod;
this.minCheckInterval = opts.minCheckInterval;
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
this.updateKey = opts.updateKey;
this.service = service;
this.log = logger.child({ module: "key-checker", service });
this.service = opts.service;
this.log = logger.child({ module: "key-checker", service: opts.service });
}
public start() {
@ -102,7 +101,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
return;
}
if (!this.RECURRING_CHECKS_ENABLED) {
if (!this.recurringChecksEnabled) {
checkLog.info(
"Initial checks complete and recurring checks are disabled for this service. Stopping."
);
@ -117,8 +116,8 @@ export abstract class KeyCheckerBase<TKey extends Key> {
// Don't check any individual key too often.
// Don't check anything at all at a rate faster than once per 3 seconds.
const nextCheck = Math.max(
oldestKey.lastChecked + this.KEY_CHECK_PERIOD,
this.lastCheck + this.MIN_CHECK_INTERVAL
oldestKey.lastChecked + this.keyCheckPeriod,
this.lastCheck + this.minCheckInterval
);
const delay = nextCheck - Date.now();

View File

@ -19,7 +19,9 @@ export function init() {
return true;
}
export async function getTokenCount(prompt: string | AnthropicChatMessage[]) {
export async function getTokenCount(
prompt: string | { system: string; messages: AnthropicChatMessage[] }
) {
if (typeof prompt !== "string") {
return getTokenCountForMessages(prompt);
}
@ -34,9 +36,17 @@ export async function getTokenCount(prompt: string | AnthropicChatMessage[]) {
};
}
async function getTokenCountForMessages(messages: AnthropicChatMessage[]) {
async function getTokenCountForMessages({
system,
messages,
}: {
system: string;
messages: AnthropicChatMessage[];
}) {
let numTokens = 0;
numTokens += (await getTokenCount(system)).token_count;
for (const message of messages) {
const { content, role } = message;
numTokens += role === "user" ? userRoleCount : assistantRoleCount;

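A hedged usage sketch of the new tokenizer signature (hypothetical values; the { token_count } return shape is inferred from the field read above):

// The system prompt now contributes to the total instead of being ignored.
const { token_count } = await getTokenCount({
  system: "You are a helpful assistant.",
  messages: [{ role: "user", content: "Hello!" }],
});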
View File

@ -35,7 +35,7 @@ type OpenAIChatTokenCountRequest = {
};
type AnthropicChatTokenCountRequest = {
prompt: AnthropicChatMessage[];
prompt: { system: string; messages: AnthropicChatMessage[] };
completion?: never;
service: "anthropic-chat";
};