188 lines
6.1 KiB
TypeScript
188 lines
6.1 KiB
TypeScript
import crypto from "crypto";
|
|
import type * as http from "http";
|
|
import os from "os";
|
|
import schedule from "node-schedule";
|
|
import { config } from "../../config";
|
|
import { logger } from "../../logger";
|
|
import { Key, Model, KeyProvider, LLMService } from "./index";
|
|
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
|
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
|
import { GooglePalmKeyProvider } from "./palm/provider";
|
|
import { AwsBedrockKeyProvider } from "./aws/provider";
|
|
import { ModelFamily } from "../models";
|
|
import { assertNever } from "../utils";
|
|
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
|
|
|
/** Union of the provider-specific update payloads accepted by `KeyPool.update`. */
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
|
|
|
/**
 * Single entry point over the per-service key providers (OpenAI, Anthropic,
 * Google PaLM, AWS Bedrock and Azure OpenAI). Each public method resolves the
 * provider responsible for a key, model or model family and delegates to it.
 */
export class KeyPool {
  /** One provider per supported service; looked up by its `service` name. */
  private readonly keyProviders: KeyProvider[] = [
    new OpenAIKeyProvider(),
    new AnthropicKeyProvider(),
    new GooglePalmKeyProvider(),
    new AwsBedrockKeyProvider(),
    new AzureOpenAIKeyProvider(),
  ];

  /** Handles of the scheduled recheck jobs; only `openai` is scheduled here. */
  private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = {
    openai: null,
  };

  /**
   * Initializes every provider and schedules the periodic key recheck.
   *
   * @throws Error if no provider reports any available key after
   *   initialization.
   */
  public init(): void {
    for (const provider of this.keyProviders) {
      provider.init();
    }
    if (this.available("all") === 0) {
      throw new Error(
        "No keys loaded. Ensure that at least one key is configured."
      );
    }
    this.scheduleRecheck();
  }

  /**
   * Returns a key for `model` from the provider that serves that model.
   *
   * @throws Error if the model cannot be mapped to a service, or if no
   *   provider is registered for that service.
   */
  public get(model: Model): Key {
    return this.getKeyProvider(this.getServiceForModel(model)).get(model);
  }

  /** Returns the concatenation of every provider's `list()` result. */
  public list(): Omit<Key, "key">[] {
    return this.keyProviders.flatMap((provider) => provider.list());
  }

  /**
   * Marks a key as disabled for a specific reason. `revoked` should be used
   * to indicate a key that can never be used again, while `quota` should be
   * used to indicate a key that is still valid but has exceeded its quota.
   */
  public disable(key: Key, reason: "quota" | "revoked"): void {
    const provider = this.getKeyProvider(key.service);
    provider.disable(key);
    provider.update(key.hash, { isRevoked: reason === "revoked" });
    // The over-quota flag is only recorded for OpenAI keys.
    if (provider instanceof OpenAIKeyProvider) {
      provider.update(key.hash, { isOverQuota: reason === "quota" });
    }
  }

  /** Applies `props` to the stored record of `key` via its provider. */
  public update(key: Key, props: AllowedPartial): void {
    this.getKeyProvider(key.service).update(key.hash, props);
  }

  /**
   * Counts available keys.
   *
   * @param model A model name to count only the keys of the service serving
   *   that model, or `"all"` (the default) to count across every provider.
   * @returns The sum of `available()` over the matching providers.
   * @throws Error if `model` is not `"all"` and cannot be mapped to a service.
   */
  public available(model: Model | "all" = "all"): number {
    // Resolve the target service once rather than once per provider.
    const service = model === "all" ? null : this.getServiceForModel(model);
    let total = 0;
    for (const provider of this.keyProviders) {
      if (service === null || provider.service === service) {
        total += provider.available();
      }
    }
    return total;
  }

  /** Forwards a token usage report for `key` to its provider. */
  public incrementUsage(key: Key, model: string, tokens: number): void {
    this.getKeyProvider(key.service).incrementUsage(key.hash, model, tokens);
  }

  /** Returns the lockout period reported by the provider serving `family`. */
  public getLockoutPeriod(family: ModelFamily): number {
    const service = this.getServiceForModelFamily(family);
    return this.getKeyProvider(service).getLockoutPeriod(family);
  }

  /** Tells the key's provider that `key` has been rate limited. */
  public markRateLimited(key: Key): void {
    this.getKeyProvider(key.service).markRateLimited(key.hash);
  }

  /**
   * Passes upstream response headers to the key's provider so it can update
   * its rate limit state. Only OpenAI keys are handled; other services are
   * ignored.
   */
  public updateRateLimits(key: Key, headers: http.IncomingHttpHeaders): void {
    const provider = this.getKeyProvider(key.service);
    if (provider instanceof OpenAIKeyProvider) {
      provider.updateRateLimits(key.hash, headers);
    }
  }

  /**
   * Asks the provider for `service` to recheck its keys. Does nothing except
   * log a message when `config.checkKeys` is falsy.
   */
  public recheck(service: LLMService): void {
    if (!config.checkKeys) {
      logger.info("Skipping key recheck because key checking is disabled");
      return;
    }

    this.getKeyProvider(service).recheck();
  }

  /**
   * Maps a model name to the service that serves it, based on the name's
   * prefix (or, for PaLM, a substring).
   *
   * @throws Error if the name matches none of the known patterns.
   */
  private getServiceForModel(model: Model): LLMService {
    if (
      model.startsWith("gpt") ||
      model.startsWith("text-embedding-ada") ||
      model.startsWith("dall-e")
    ) {
      // https://platform.openai.com/docs/models/model-endpoint-compatibility
      return "openai";
    }
    if (model.startsWith("claude-")) {
      // https://console.anthropic.com/docs/api/reference#parameters
      return "anthropic";
    }
    if (model.includes("bison")) {
      // https://developers.generativeai.google.com/models/language
      return "google-palm";
    }
    if (model.startsWith("anthropic.claude")) {
      // AWS offers models from a few providers
      // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
      return "aws";
    }
    if (model.startsWith("azure")) {
      return "azure";
    }
    throw new Error(`Unknown service for model '${model}'`);
  }

  /** Maps a model family to the service that serves it. */
  private getServiceForModelFamily(modelFamily: ModelFamily): LLMService {
    switch (modelFamily) {
      case "gpt4":
      case "gpt4-32k":
      case "gpt4-turbo":
      case "turbo":
      case "dall-e":
        return "openai";
      case "claude":
        return "anthropic";
      case "bison":
        return "google-palm";
      case "aws-claude":
        return "aws";
      case "azure-turbo":
      case "azure-gpt4":
      case "azure-gpt4-32k":
      case "azure-gpt4-turbo":
        return "azure";
      default:
        // Exhaustiveness check: a new family must be added to a case above.
        return assertNever(modelFamily);
    }
  }

  /**
   * Finds the provider registered for `service`.
   *
   * @throws Error if no provider is registered for the service, instead of
   *   handing `undefined` back to the caller.
   */
  private getKeyProvider(service: LLMService): KeyProvider {
    const provider = this.keyProviders.find(
      (candidate) => candidate.service === service
    );
    if (provider === undefined) {
      throw new Error(`No key provider registered for service '${service}'`);
    }
    return provider;
  }

  /**
   * Schedules a periodic recheck of OpenAI keys, which runs every 8 hours on
   * a schedule offset by the server's hostname.
   */
  private scheduleRecheck(): void {
    const machineHash = crypto
      .createHash("sha256")
      .update(os.hostname())
      .digest("hex");
    // The digest is far larger than Number.MAX_SAFE_INTEGER, so parseInt is
    // lossy here; the result is still a stable 0-6 value for a given hostname.
    const offset = parseInt(machineHash, 16) % 7;
    const hour = [0, 8, 16].map((h) => h + offset).join(",");
    const crontab = `0 ${hour} * * *`;

    const job = schedule.scheduleJob(crontab, () => {
      const next = job.nextInvocation();
      logger.info({ next }, "Performing periodic recheck of OpenAI keys");
      this.recheck("openai");
    });
    logger.info(
      { rule: crontab, next: job.nextInvocation() },
      "Scheduled periodic key recheck job"
    );
    this.recheckJobs.openai = job;
  }
}
|