Implement support for Anthropic keys and Claude API (khanon/oai-reverse-proxy!15)

This commit is contained in:
khanon 2023-05-29 17:08:08 +00:00
parent 03aaa6daad
commit 2d93463247
23 changed files with 1530 additions and 656 deletions

View File

@ -10,8 +10,10 @@ export type DequeueMode = "fair" | "random" | "none";
type Config = {
/** The port the proxy server will listen on. */
port: number;
/** OpenAI API key, either a single key or a comma-delimeted list of keys. */
/** Comma-delimited list of OpenAI API keys. */
openaiKey?: string;
/** Comma-delimited list of Anthropic API keys. */
anthropicKey?: string;
/**
* The proxy key to require for requests. Only applicable if the user
* management mode is set to 'proxy_key', and required if so.
@ -118,6 +120,7 @@ type Config = {
export const config: Config = {
port: getEnvWithDefault("PORT", 7860),
openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
@ -221,6 +224,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"port",
"logLevel",
"openaiKey",
"anthropicKey",
"proxyKey",
"adminKey",
"checkKeys",
@ -265,7 +269,7 @@ function getEnvWithDefault<T>(name: string, defaultValue: T): T {
return defaultValue;
}
try {
if (name === "OPENAI_KEY") {
if (name === "OPENAI_KEY" || name === "ANTHROPIC_KEY") {
return value as unknown as T;
}
return JSON.parse(value) as T;

View File

@ -28,39 +28,51 @@ function cacheInfoPageHtml(host: string) {
const keys = keyPool.list();
let keyInfo: Record<string, any> = { all: keys.length };
const openAIKeys = keys.filter((k) => k.service === "openai");
const anthropicKeys = keys.filter((k) => k.service === "anthropic");
let anthropicInfo: Record<string, any> = {
all: anthropicKeys.length,
active: anthropicKeys.filter((k) => !k.isDisabled).length,
};
let openAIInfo: Record<string, any> = {
all: openAIKeys.length,
active: openAIKeys.filter((k) => !k.isDisabled).length,
};
if (keyPool.anyUnchecked()) {
const uncheckedKeys = keys.filter((k) => !k.lastChecked);
keyInfo = {
...keyInfo,
openAIInfo = {
...openAIInfo,
active: keys.filter((k) => !k.isDisabled).length,
status: `Still checking ${uncheckedKeys.length} keys...`,
};
} else if (config.checkKeys) {
const trialKeys = keys.filter((k) => k.isTrial);
const turboKeys = keys.filter((k) => !k.isGpt4 && !k.isDisabled);
const gpt4Keys = keys.filter((k) => k.isGpt4 && !k.isDisabled);
const trialKeys = openAIKeys.filter((k) => k.isTrial);
const turboKeys = openAIKeys.filter((k) => !k.isGpt4 && !k.isDisabled);
const gpt4Keys = openAIKeys.filter((k) => k.isGpt4 && !k.isDisabled);
const quota: Record<string, string> = { turbo: "", gpt4: "" };
const hasGpt4 = keys.some((k) => k.isGpt4);
const hasGpt4 = openAIKeys.some((k) => k.isGpt4);
const turboQuota = keyPool.remainingQuota("openai") * 100;
const gpt4Quota = keyPool.remainingQuota("openai", { gpt4: true }) * 100;
if (config.quotaDisplayMode === "full") {
quota.turbo = `${keyPool.usageInUsd()} (${Math.round(
keyPool.remainingQuota() * 100
)}% remaining)`;
quota.gpt4 = `${keyPool.usageInUsd(true)} (${Math.round(
keyPool.remainingQuota(true) * 100
)}% remaining)`;
const turboUsage = keyPool.usageInUsd("openai");
const gpt4Usage = keyPool.usageInUsd("openai", { gpt4: true });
quota.turbo = `${turboUsage} (${Math.round(turboQuota)}% remaining)`;
quota.gpt4 = `${gpt4Usage} (${Math.round(gpt4Quota)}% remaining)`;
} else {
quota.turbo = `${Math.round(keyPool.remainingQuota() * 100)}%`;
quota.gpt4 = `${Math.round(keyPool.remainingQuota(true) * 100)}%`;
quota.turbo = `${Math.round(turboQuota)}%`;
quota.gpt4 = `${Math.round(gpt4Quota * 100)}%`;
}
if (!hasGpt4) {
delete quota.gpt4;
}
keyInfo = {
...keyInfo,
openAIInfo = {
...openAIInfo,
trial: trialKeys.length,
active: {
turbo: turboKeys.length,
@ -70,6 +82,11 @@ function cacheInfoPageHtml(host: string) {
};
}
keyInfo = {
...(openAIKeys.length ? { openai: openAIInfo } : {}),
...(anthropicKeys.length ? { anthropic: anthropicInfo } : {}),
};
const info = {
uptime: process.uptime(),
endpoints: {

View File

@ -0,0 +1,188 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../config";
import { logger } from "../../logger";
export const ANTHROPIC_SUPPORTED_MODELS = [
"claude-instant-v1",
"claude-instant-v1-100k",
"claude-v1",
"claude-v1-100k",
] as const;
export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
export interface AnthropicKey extends Key {
readonly service: "anthropic";
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/**
* We don't get rate limit headers from Anthropic so if we get a 429, we just
* lock out the key for 10 seconds.
*/
const RATE_LIMIT_LOCKOUT = 10000;
export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
readonly service = "anthropic";
private keys: AnthropicKey[] = [];
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.anthropicKey?.trim();
if (!keyConfig) {
this.log.warn(
"ANTHROPIC_KEY is not set. Anthropic API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: AnthropicKey = {
key,
service: this.service,
isGpt4: false,
isTrial: false,
isDisabled: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `ant-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys.");
}
public init() {
// Nothing to do as Anthropic's API doesn't provide any usage information so
// there is no key checker implementation and no need to start it.
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: AnthropicModel) {
// Currently, all Anthropic keys have access to all models. This will almost
// certainly change when they move out of beta later this year.
const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) {
throw new Error("No Anthropic keys available.");
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 2. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
selectedKey.rateLimitedAt = now;
// Intended to throttle the queue processor as otherwise it will just
// flood the API with requests and we want to wait a sec to see if we're
// going to get a rate limit error on this key.
selectedKey.rateLimitedUntil = now + 1000;
return { ...selectedKey };
}
public disable(key: AnthropicKey) {
const keyFromPool = this.keys.find((k) => k.key === key.key);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
// No key checker for Anthropic
public anyUnchecked() {
return false;
}
public incrementPrompt(hash?: string) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
}
public getLockoutPeriod(_model: AnthropicModel) {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
const timeUntilFirstReady = Math.min(
...activeKeys.map((k) => k.rateLimitedUntil - now)
);
return timeUntilFirstReady;
}
/**
* This is called when we receive a 429, which means there are already five
* concurrent requests running on this key. We don't have any information on
* when these requests will resolve so all we can do is wait a bit and try
* again.
* We will lock the key for 10 seconds, which should let a few of the other
* generations finish. This is an arbitrary number but the goal is to balance
* between not hammering the API with requests and not locking out a key that
* is actually available.
* TODO; Try to assign requests to slots on each key so we have an idea of how
* long each slot has been running and can make a more informed decision on
* how long to lock the key.
*/
public markRateLimited(keyHash: string) {
this.log.warn({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public remainingQuota() {
const activeKeys = this.keys.filter((k) => !k.isDisabled).length;
const allKeys = this.keys.length;
if (activeKeys === 0) return 0;
return Math.round((activeKeys / allKeys) * 100) / 100;
}
public usageInUsd() {
return "$0.00 / ∞";
}
}

View File

@ -1,5 +1,65 @@
import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider";
import {
ANTHROPIC_SUPPORTED_MODELS,
AnthropicModel,
} from "./anthropic/provider";
import { KeyPool } from "./key-pool";
export type { Key, Model } from "./key-pool";
export type AIService = "openai" | "anthropic";
export type Model = OpenAIModel | AnthropicModel;
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
readonly key: string;
/** The service that this key is for. */
service: AIService;
/** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
isTrial: boolean;
/** Whether this key has been provisioned for GPT-4. */
isGpt4: boolean;
/** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
isDisabled: boolean;
/** The number of prompts that have been sent with this key. */
promptCount: number;
/** The time at which this key was last used. */
lastUsed: number;
/** The time at which this key was last checked. */
lastChecked: number;
/** Hash of the key, for logging and to find the key in the pool. */
hash: string;
}
/*
KeyPool and KeyProvider's similarities are a relic of the old design where
there was only a single KeyPool for OpenAI keys. Now that there are multiple
supported services, the service-specific functionality has been moved to
KeyProvider and KeyPool is just a wrapper around multiple KeyProviders,
delegating to the appropriate one based on the model requested.
Existing code will continue to call methods on KeyPool, which routes them to
the appropriate KeyProvider or returns data aggregated across all KeyProviders
for service-agnostic functionality.
*/
export interface KeyProvider<T extends Key = Key> {
readonly service: AIService;
init(): void;
get(model: Model): T;
list(): Omit<T, "key">[];
disable(key: T): void;
available(): number;
anyUnchecked(): boolean;
incrementPrompt(hash: string): void;
getLockoutPeriod(model: Model): number;
remainingQuota(options?: Record<string, unknown>): number;
usageInUsd(options?: Record<string, unknown>): string;
markRateLimited(hash: string): void;
}
export const keyPool = new KeyPool();
export { SUPPORTED_MODELS } from "./key-pool";
export const SUPPORTED_MODELS = [
...OPENAI_SUPPORTED_MODELS,
...ANTHROPIC_SUPPORTED_MODELS,
] as const;
export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
export { OPENAI_SUPPORTED_MODELS, ANTHROPIC_SUPPORTED_MODELS };

View File

@ -1,378 +1,102 @@
/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides
round-robin access to keys. Keys are stored in the OPENAI_KEY environment
variable as a comma-separated list of keys. */
import crypto from "crypto";
import fs from "fs";
import http from "http";
import path from "path";
import { config } from "../config";
import { logger } from "../logger";
import { KeyChecker } from "./key-checker";
// TODO: Made too many assumptions about OpenAI being the only provider and now
// this doesn't really work for Anthropic. Create a Provider interface and
// implement Pool, Checker, and Models for each provider.
export type Model = OpenAIModel | AnthropicModel;
export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4";
export type AnthropicModel = "claude-v1" | "claude-instant-v1";
export const SUPPORTED_MODELS: readonly Model[] = [
"gpt-3.5-turbo",
"gpt-4",
"claude-v1",
"claude-instant-v1",
] as const;
export type Key = {
/** The OpenAI API key itself. */
key: string;
/** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
isTrial: boolean;
/** Whether this key has been provisioned for GPT-4. */
isGpt4: boolean;
/** Whether this key is currently disabled. We set this if we get a 429 or 401 response from OpenAI. */
isDisabled: boolean;
/** Threshold at which a warning email will be sent by OpenAI. */
softLimit: number;
/** Threshold at which the key will be disabled because it has reached the user-defined limit. */
hardLimit: number;
/** The maximum quota allocated to this key by OpenAI. */
systemHardLimit: number;
/** The current usage of this key. */
usage: number;
/** The number of prompts that have been sent with this key. */
promptCount: number;
/** The time at which this key was last used. */
lastUsed: number;
/** The time at which this key was last checked. */
lastChecked: number;
/** Key hash for displaying usage in the dashboard. */
hash: string;
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/**
* Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
* number.
* Formatted as a `\d+(m|s)` string denoting the time until the limit resets.
* Specifically, it seems to indicate the time until the key's quota will be
* fully restored; the key may be usable before this time as the limit is a
* rolling window.
*
* Requests which return a 429 do not count against the quota.
*
* Requests which fail for other reasons (e.g. 401) count against the quota.
*/
rateLimitRequestsReset: number;
/**
* Last known X-RateLimit-Tokens-Reset header from OpenAI, converted to a
* number.
* Appears to follow the same format as `rateLimitRequestsReset`.
*
* Requests which fail do not count against the quota as they do not consume
* tokens.
*/
rateLimitTokensReset: number;
};
export type KeyUpdate = Omit<
Partial<Key>,
"key" | "hash" | "lastUsed" | "lastChecked" | "promptCount"
>;
import type * as http from "http";
import { AnthropicKeyProvider } from "./anthropic/provider";
import { Key, AIService, Model, KeyProvider } from "./index";
import { OpenAIKeyProvider } from "./openai/provider";
export class KeyPool {
private keys: Key[] = [];
private checker?: KeyChecker;
private log = logger.child({ module: "key-pool" });
private keyProviders: KeyProvider[] = [];
constructor() {
const keyString = config.openaiKey;
if (!keyString?.trim()) {
throw new Error("OPENAI_KEY environment variable is not set");
}
let bareKeys: string[];
bareKeys = keyString.split(",").map((k) => k.trim());
bareKeys = [...new Set(bareKeys)];
for (const k of bareKeys) {
const newKey = {
key: k,
isGpt4: false,
isTrial: false,
isDisabled: false,
softLimit: 0,
hardLimit: 0,
systemHardLimit: 0,
usage: 0,
lastUsed: 0,
lastChecked: 0,
promptCount: 0,
hash: crypto.createHash("sha256").update(k).digest("hex").slice(0, 8),
rateLimitedAt: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded keys");
this.keyProviders.push(new OpenAIKeyProvider());
this.keyProviders.push(new AnthropicKeyProvider());
}
public init() {
if (config.checkKeys) {
this.checker = new KeyChecker(this.keys, this.update.bind(this));
this.checker.start();
}
}
/**
* Returns a list of all keys, with the key field removed.
* Don't mutate returned keys, use a KeyPool method instead.
**/
public list() {
return this.keys.map((key) => {
return Object.freeze({
...key,
key: undefined,
});
});
}
public get(model: Model) {
const needGpt4 = model.startsWith("gpt-4");
const availableKeys = this.keys.filter(
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
this.keyProviders.forEach((provider) => provider.init());
const availableKeys = this.available();
if (availableKeys === 0) {
throw new Error(
"No keys loaded. Ensure either OPENAI_KEY or ANTHROPIC_KEY is set."
);
if (availableKeys.length === 0) {
let message = "No keys available. Please add more keys.";
if (needGpt4) {
message =
"No GPT-4 keys available. Please add more keys or select a non-GPT-4 model.";
}
throw new Error(message);
}
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. We can assume any rate limits over a minute ago are expired
// b. If all keys were rate limited in the last minute, select the
// least recently rate limited key
// 2. Keys which are trials
// 3. Keys which have not been used in the longest time
const now = Date.now();
const rateLimitThreshold = 60 * 1000;
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold;
const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
if (a.isTrial && !b.isTrial) return -1;
if (!a.isTrial && b.isTrial) return 1;
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = Date.now();
// When a key is selected, we rate-limit it for a brief period of time to
// prevent the queue processor from immediately flooding it with requests
// while the initial request is still being processed (which is when we will
// get new rate limit headers).
// Instead, we will let a request through every second until the key
// becomes fully saturated and locked out again.
selectedKey.rateLimitedAt = Date.now();
selectedKey.rateLimitRequestsReset = 1000;
return { ...selectedKey };
}
/** Called by the key checker to update key information. */
public update(keyHash: string, update: KeyUpdate) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
Object.assign(keyFromPool, { ...update, lastChecked: Date.now() });
// this.writeKeyStatus();
}
public disable(key: Key) {
const keyFromPool = this.keys.find((k) => k.key === key.key)!;
if (keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
// If it's disabled just set the usage to the hard limit so it doesn't
// mess with the aggregate usage.
keyFromPool.usage = keyFromPool.hardLimit;
this.log.warn({ key: key.hash }, "Key disabled");
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public anyUnchecked() {
return config.checkKeys && this.keys.some((key) => !key.lastChecked);
}
/**
* Given a model, returns the period until a key will be available to service
* the request, or returns 0 if a key is ready immediately.
*/
public getLockoutPeriod(model: Model = "gpt-4"): number {
const needGpt4 = model.startsWith("gpt-4");
const activeKeys = this.keys.filter(
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
);
if (activeKeys.length === 0) {
// If there are no active keys for this model we can't fulfill requests.
// We'll return 0 to let the request through and return an error,
// otherwise the request will be stuck in the queue forever.
return 0;
}
// A key is rate-limited if its `rateLimitedAt` plus the greater of its
// `rateLimitRequestsReset` and `rateLimitTokensReset` is after the
// current time.
// If there are any keys that are not rate-limited, we can fulfill requests.
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((key) => {
const resetTime = Math.max(
key.rateLimitRequestsReset,
key.rateLimitTokensReset
);
return now < key.rateLimitedAt + resetTime;
}).length;
const anyNotRateLimited = rateLimitedKeys < activeKeys.length;
if (anyNotRateLimited) {
return 0;
}
// If all keys are rate-limited, return the time until the first key is
// ready.
const timeUntilFirstReady = Math.min(
...activeKeys.map((key) => {
const resetTime = Math.max(
key.rateLimitRequestsReset,
key.rateLimitTokensReset
);
return key.rateLimitedAt + resetTime - now;
})
);
return timeUntilFirstReady;
}
public markRateLimited(keyHash: string) {
this.log.warn({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
key.rateLimitedAt = Date.now();
}
public incrementPrompt(keyHash?: string) {
if (!keyHash) return;
const key = this.keys.find((k) => k.hash === keyHash)!;
key.promptCount++;
}
public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
const key = this.keys.find((k) => k.hash === keyHash)!;
const requestsReset = headers["x-ratelimit-reset-requests"];
const tokensReset = headers["x-ratelimit-reset-tokens"];
// Sometimes OpenAI only sends one of the two rate limit headers, it's
// unclear why.
if (requestsReset && typeof requestsReset === "string") {
this.log.info(
{ key: key.hash, requestsReset },
`Updating rate limit requests reset time`
);
key.rateLimitRequestsReset = getResetDurationMillis(requestsReset);
}
if (tokensReset && typeof tokensReset === "string") {
this.log.info(
{ key: key.hash, tokensReset },
`Updating rate limit tokens reset time`
);
key.rateLimitTokensReset = getResetDurationMillis(tokensReset);
}
if (!requestsReset && !tokensReset) {
this.log.warn(
{ key: key.hash },
`No rate limit headers in OpenAI response; skipping update`
);
return;
}
}
/** Returns the remaining aggregate quota for all keys as a percentage. */
public remainingQuota(gpt4 = false) {
const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
if (keys.length === 0) return 0;
const totalUsage = keys.reduce((acc, key) => {
// Keys can slightly exceed their quota
return acc + Math.min(key.usage, key.hardLimit);
}, 0);
const totalLimit = keys.reduce((acc, { hardLimit }) => acc + hardLimit, 0);
return 1 - totalUsage / totalLimit;
public get(model: Model): Key {
const service = this.getService(model);
return this.getKeyProvider(service).get(model);
}
/** Returns used and available usage in USD. */
public usageInUsd(gpt4 = false) {
const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
if (keys.length === 0) return "???";
public list(): Omit<Key, "key">[] {
return this.keyProviders.flatMap((provider) => provider.list());
}
const totalHardLimit = keys.reduce(
(acc, { hardLimit }) => acc + hardLimit,
public disable(key: Key): void {
const service = this.getKeyProvider(key.service);
service.disable(key);
}
// TODO: this probably needs to be scoped to a specific provider. I think the
// only code calling this is the error handler which needs to know how many
// more keys are available for the provider the user tried to use.
public available(): number {
return this.keyProviders.reduce(
(sum, provider) => sum + provider.available(),
0
);
const totalUsage = keys.reduce((acc, key) => {
// Keys can slightly exceed their quota
return acc + Math.min(key.usage, key.hardLimit);
}, 0);
return `$${totalUsage.toFixed(2)} / $${totalHardLimit.toFixed(2)}`;
}
/** Writes key status to disk. */
// public writeKeyStatus() {
// const keys = this.keys.map((key) => ({
// key: key.key,
// isGpt4: key.isGpt4,
// usage: key.usage,
// hardLimit: key.hardLimit,
// isDisabled: key.isDisabled,
// }));
// fs.writeFileSync(
// path.join(__dirname, "..", "keys.json"),
// JSON.stringify(keys, null, 2)
// );
// }
}
/**
* Converts reset string ("21.0032s" or "21ms") to a number of milliseconds.
* Result is clamped to 10s even though the API returns up to 60s, because the
* API returns the time until the entire quota is reset, even if a key may be
* able to fulfill requests before then due to partial resets.
**/
function getResetDurationMillis(resetDuration?: string): number {
const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/);
if (match) {
const [, time, , unit] = match;
const value = parseFloat(time);
const result = unit === "s" ? value * 1000 : value;
return Math.min(result, 10000);
public anyUnchecked(): boolean {
return this.keyProviders.some((provider) => provider.anyUnchecked());
}
public incrementPrompt(key: Key): void {
const provider = this.getKeyProvider(key.service);
provider.incrementPrompt(key.hash);
}
public getLockoutPeriod(model: Model): number {
const service = this.getService(model);
return this.getKeyProvider(service).getLockoutPeriod(model);
}
public markRateLimited(key: Key): void {
const provider = this.getKeyProvider(key.service);
provider.markRateLimited(key.hash);
}
public updateRateLimits(key: Key, headers: http.IncomingHttpHeaders): void {
const provider = this.getKeyProvider(key.service);
if (provider instanceof OpenAIKeyProvider) {
provider.updateRateLimits(key.hash, headers);
}
}
public remainingQuota(
service: AIService,
options?: Record<string, unknown>
): number {
return this.getKeyProvider(service).remainingQuota(options);
}
public usageInUsd(
service: AIService,
options?: Record<string, unknown>
): string {
return this.getKeyProvider(service).usageInUsd(options);
}
private getService(model: Model): AIService {
if (model.startsWith("gpt")) {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
return "openai";
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
}
throw new Error(`Unknown service for model '${model}'`);
}
private getKeyProvider(service: AIService): KeyProvider {
return this.keyProviders.find((provider) => provider.service === service)!;
}
return 0;
}

View File

@ -1,7 +1,7 @@
import axios, { AxiosError } from "axios";
import { Configuration, OpenAIApi } from "openai";
import { logger } from "../logger";
import type { Key, KeyPool } from "./key-pool";
import { logger } from "../../logger";
import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 5 * 60 * 1000; // 5 minutes
@ -26,16 +26,16 @@ type OpenAIError = {
error: { type: string; code: string; param: unknown; message: string };
};
type UpdateFn = typeof KeyPool.prototype.update;
type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
export class KeyChecker {
private readonly keys: Key[];
private log = logger.child({ module: "key-checker" });
export class OpenAIKeyChecker {
private readonly keys: OpenAIKey[];
private log = logger.child({ module: "key-checker", service: "openai" });
private timeout?: NodeJS.Timeout;
private updateKey: UpdateFn;
private lastCheck = 0;
constructor(keys: Key[], updateKey: UpdateFn) {
constructor(keys: OpenAIKey[], updateKey: UpdateFn) {
this.keys = keys;
this.updateKey = updateKey;
}
@ -110,7 +110,7 @@ export class KeyChecker {
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
}
private async checkKey(key: Key) {
private async checkKey(key: OpenAIKey) {
// It's possible this key might have been disabled while we were waiting
// for the next check.
if (key.isDisabled) {
@ -180,7 +180,7 @@ export class KeyChecker {
}
private async getProvisionedModels(
key: Key
key: OpenAIKey
): Promise<{ turbo: boolean; gpt4: boolean }> {
const openai = new OpenAIApi(new Configuration({ apiKey: key.key }));
const models = (await openai.listModels()!).data.data;
@ -189,7 +189,7 @@ export class KeyChecker {
return { turbo, gpt4 };
}
private async getSubscription(key: Key) {
private async getSubscription(key: OpenAIKey) {
const { data } = await axios.get<GetSubscriptionResponse>(
GET_SUBSCRIPTION_URL,
{ headers: { Authorization: `Bearer ${key.key}` } }
@ -197,8 +197,8 @@ export class KeyChecker {
return data;
}
private async getUsage(key: Key) {
const querystring = KeyChecker.getUsageQuerystring(key.isTrial);
private async getUsage(key: OpenAIKey) {
const querystring = OpenAIKeyChecker.getUsageQuerystring(key.isTrial);
const url = `${GET_USAGE_URL}?${querystring}`;
const { data } = await axios.get<GetUsageResponse>(url, {
headers: { Authorization: `Bearer ${key.key}` },
@ -206,8 +206,8 @@ export class KeyChecker {
return parseFloat((data.total_usage / 100).toFixed(2));
}
private handleAxiosError(key: Key, error: AxiosError) {
if (error.response && KeyChecker.errorIsOpenAiError(error)) {
private handleAxiosError(key: OpenAIKey, error: AxiosError) {
if (error.response && OpenAIKeyChecker.errorIsOpenAiError(error)) {
const { status, data } = error.response;
if (status === 401) {
this.log.warn(
@ -239,7 +239,7 @@ export class KeyChecker {
* Trial key usage reporting is inaccurate, so we need to run an actual
* completion to test them for liveness.
*/
private async assertCanGenerate(key: Key): Promise<void> {
private async assertCanGenerate(key: OpenAIKey): Promise<void> {
const openai = new OpenAIApi(new Configuration({ apiKey: key.key }));
// This will throw an AxiosError if the key is invalid or out of quota.
await openai.createChatCompletion({

View File

@ -0,0 +1,360 @@
/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides
round-robin access to keys. Keys are stored in the OPENAI_KEY environment
variable as a comma-separated list of keys. */
import crypto from "crypto";
import fs from "fs";
import http from "http";
import path from "path";
import { KeyProvider, Key, Model } from "../index";
import { config } from "../../config";
import { logger } from "../../logger";
import { OpenAIKeyChecker } from "./checker";
export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4";
export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
"gpt-3.5-turbo",
"gpt-4",
] as const;
export interface OpenAIKey extends Key {
readonly service: "openai";
/** The current usage of this key. */
usage: number;
/** Threshold at which a warning email will be sent by OpenAI. */
softLimit: number;
/** Threshold at which the key will be disabled because it has reached the user-defined limit. */
hardLimit: number;
/** The maximum quota allocated to this key by OpenAI. */
systemHardLimit: number;
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/**
* Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
* number.
* Formatted as a `\d+(m|s)` string denoting the time until the limit resets.
* Specifically, it seems to indicate the time until the key's quota will be
* fully restored; the key may be usable before this time as the limit is a
* rolling window.
*
* Requests which return a 429 do not count against the quota.
*
* Requests which fail for other reasons (e.g. 401) count against the quota.
*/
rateLimitRequestsReset: number;
/**
* Last known X-RateLimit-Tokens-Reset header from OpenAI, converted to a
* number.
* Appears to follow the same format as `rateLimitRequestsReset`.
*
* Requests which fail do not count against the quota as they do not consume
* tokens.
*/
rateLimitTokensReset: number;
}
export type OpenAIKeyUpdate = Omit<
Partial<OpenAIKey>,
"key" | "hash" | "lastUsed" | "lastChecked" | "promptCount"
>;
export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
readonly service = "openai" as const;
private keys: OpenAIKey[] = [];
private checker?: OpenAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyString = config.openaiKey?.trim();
if (!keyString) {
this.log.warn("OPENAI_KEY is not set. OpenAI API will not be available.");
return;
}
let bareKeys: string[];
bareKeys = keyString.split(",").map((k) => k.trim());
bareKeys = [...new Set(bareKeys)];
for (const k of bareKeys) {
const newKey = {
key: k,
service: "openai" as const,
isGpt4: false,
isTrial: false,
isDisabled: false,
softLimit: 0,
hardLimit: 0,
systemHardLimit: 0,
usage: 0,
lastUsed: 0,
lastChecked: 0,
promptCount: 0,
hash: `oai-${crypto
.createHash("sha256")
.update(k)
.digest("hex")
.slice(0, 8)}`,
rateLimitedAt: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys.");
}
public init() {
if (config.checkKeys) {
this.checker = new OpenAIKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
}
}
/**
* Returns a list of all keys, with the key field removed.
* Don't mutate returned keys, use a KeyPool method instead.
**/
public list() {
return this.keys.map((key) => {
return Object.freeze({
...key,
key: undefined,
});
});
}
public get(model: Model) {
const needGpt4 = model.startsWith("gpt-4");
const availableKeys = this.keys.filter(
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
);
if (availableKeys.length === 0) {
let message = needGpt4
? "No active OpenAI keys available."
: "No GPT-4 keys available. Try selecting a non-GPT-4 model.";
throw new Error(message);
}
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. We ignore rate limits from over a minute ago
// b. If all keys were rate limited in the last minute, select the
// least recently rate limited key
// 2. Keys which are trials
// 3. Keys which have not been used in the longest time
const now = Date.now();
const rateLimitThreshold = 60 * 1000;
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold;
const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
if (a.isTrial && !b.isTrial) return -1;
if (!a.isTrial && b.isTrial) return 1;
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
// When a key is selected, we rate-limit it for a brief period of time to
// prevent the queue processor from immediately flooding it with requests
// while the initial request is still being processed (which is when we will
// get new rate limit headers).
// Instead, we will let a request through every second until the key
// becomes fully saturated and locked out again.
selectedKey.rateLimitedAt = now;
selectedKey.rateLimitRequestsReset = 1000;
return { ...selectedKey };
}
/** Called by the key checker to update key information. */
public update(keyHash: string, update: OpenAIKeyUpdate) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
Object.assign(keyFromPool, { ...update, lastChecked: Date.now() });
// this.writeKeyStatus();
}
/** Disables a key, or does nothing if the key isn't in this pool. */
public disable(key: Key) {
const keyFromPool = this.keys.find((k) => k.key === key.key);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
// If it's disabled just set the usage to the hard limit so it doesn't
// mess with the aggregate usage.
keyFromPool.usage = keyFromPool.hardLimit;
this.log.warn({ key: key.hash }, "Key disabled");
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public anyUnchecked() {
return !!config.checkKeys && this.keys.some((key) => !key.lastChecked);
}
/**
* Given a model, returns the period until a key will be available to service
* the request, or returns 0 if a key is ready immediately.
*/
public getLockoutPeriod(model: Model = "gpt-4"): number {
const needGpt4 = model.startsWith("gpt-4");
const activeKeys = this.keys.filter(
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
);
if (activeKeys.length === 0) {
// If there are no active keys for this model we can't fulfill requests.
// We'll return 0 to let the request through and return an error,
// otherwise the request will be stuck in the queue forever.
return 0;
}
// A key is rate-limited if its `rateLimitedAt` plus the greater of its
// `rateLimitRequestsReset` and `rateLimitTokensReset` is after the
// current time.
// If there are any keys that are not rate-limited, we can fulfill requests.
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((key) => {
const resetTime = Math.max(
key.rateLimitRequestsReset,
key.rateLimitTokensReset
);
return now < key.rateLimitedAt + resetTime;
}).length;
const anyNotRateLimited = rateLimitedKeys < activeKeys.length;
if (anyNotRateLimited) {
return 0;
}
// If all keys are rate-limited, return the time until the first key is
// ready.
const timeUntilFirstReady = Math.min(
...activeKeys.map((key) => {
const resetTime = Math.max(
key.rateLimitRequestsReset,
key.rateLimitTokensReset
);
return key.rateLimitedAt + resetTime - now;
})
);
return timeUntilFirstReady;
}
public markRateLimited(keyHash: string) {
this.log.warn({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
key.rateLimitedAt = Date.now();
}
public incrementPrompt(keyHash?: string) {
const key = this.keys.find((k) => k.hash === keyHash);
if (!key) return;
key.promptCount++;
}
public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
const key = this.keys.find((k) => k.hash === keyHash)!;
const requestsReset = headers["x-ratelimit-reset-requests"];
const tokensReset = headers["x-ratelimit-reset-tokens"];
// Sometimes OpenAI only sends one of the two rate limit headers, it's
// unclear why.
if (requestsReset && typeof requestsReset === "string") {
this.log.info(
{ key: key.hash, requestsReset },
`Updating rate limit requests reset time`
);
key.rateLimitRequestsReset = getResetDurationMillis(requestsReset);
}
if (tokensReset && typeof tokensReset === "string") {
this.log.info(
{ key: key.hash, tokensReset },
`Updating rate limit tokens reset time`
);
key.rateLimitTokensReset = getResetDurationMillis(tokensReset);
}
if (!requestsReset && !tokensReset) {
this.log.warn(
{ key: key.hash },
`No rate limit headers in OpenAI response; skipping update`
);
return;
}
}
/** Returns the remaining aggregate quota for all keys as a percentage. */
public remainingQuota({ gpt4 }: { gpt4: boolean } = { gpt4: false }): number {
const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
if (keys.length === 0) return 0;
const totalUsage = keys.reduce((acc, key) => {
// Keys can slightly exceed their quota
return acc + Math.min(key.usage, key.hardLimit);
}, 0);
const totalLimit = keys.reduce((acc, { hardLimit }) => acc + hardLimit, 0);
return 1 - totalUsage / totalLimit;
}
/** Returns used and available usage in USD. */
public usageInUsd({ gpt4 }: { gpt4: boolean } = { gpt4: false }): string {
const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
if (keys.length === 0) return "???";
const totalHardLimit = keys.reduce(
(acc, { hardLimit }) => acc + hardLimit,
0
);
const totalUsage = keys.reduce((acc, key) => {
// Keys can slightly exceed their quota
return acc + Math.min(key.usage, key.hardLimit);
}, 0);
return `$${totalUsage.toFixed(2)} / $${totalHardLimit.toFixed(2)}`;
}
/** Writes key status to disk. */
// public writeKeyStatus() {
// const keys = this.keys.map((key) => ({
// key: key.key,
// isGpt4: key.isGpt4,
// usage: key.usage,
// hardLimit: key.hardLimit,
// isDisabled: key.isDisabled,
// }));
// fs.writeFileSync(
// path.join(__dirname, "..", "keys.json"),
// JSON.stringify(keys, null, 2)
// );
// }
}
/**
* Converts reset string ("21.0032s" or "21ms") to a number of milliseconds.
* Result is clamped to 10s even though the API returns up to 60s, because the
* API returns the time until the entire quota is reset, even if a key may be
* able to fulfill requests before then due to partial resets.
**/
function getResetDurationMillis(resetDuration?: string): number {
const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/);
if (match) {
const [, time, , unit] = match;
const value = parseFloat(time);
const result = unit === "s" ? value * 1000 : value;
return Math.min(result, 10000);
}
return 0;
}

171
src/proxy/anthropic.ts Normal file
View File

@ -0,0 +1,171 @@
import { Request, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
import {
addKey,
finalizeBody,
languageFilter,
limitOutputTokens,
transformOutboundPayload,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
handleInternalError,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
const rewriteAnthropicRequest = (
proxyReq: http.ClientRequest,
req: Request,
res: http.ServerResponse
) => {
req.api = "anthropic";
const rewriterPipeline = [
addKey,
languageFilter,
limitOutputTokens,
transformOutboundPayload,
finalizeBody,
];
try {
for (const rewriter of rewriterPipeline) {
rewriter(proxyReq, req, res, {});
}
} catch (error) {
req.log.error(error, "Error while executing proxy rewriter");
proxyReq.destroy(error as Error);
}
};
/** Only used for non-streaming requests. */
const anthropicResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (!req.originalUrl.includes("/v1/complete")) {
req.log.info("Transforming Anthropic response to OpenAI format");
body = transformAnthropicResponse(body);
}
res.status(200).json(body);
};
/**
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAnthropicResponse(
anthropicBody: Record<string, any>
): Record<string, any> {
return {
id: "ant-" + anthropicBody.log_id,
object: "chat.completion",
created: Date.now(),
model: anthropicBody.model,
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
choices: [
{
message: {
role: "assistant",
content: anthropicBody.completion?.trim(),
},
finish_reason: anthropicBody.stop_reason,
index: 0,
},
],
};
}
const anthropicProxy = createProxyMiddleware({
target: "https://api.anthropic.com",
changeOrigin: true,
on: {
proxyReq: rewriteAnthropicRequest,
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
error: handleInternalError,
},
selfHandleResponse: true,
logger,
pathRewrite: {
// If the user sends a request to /v1/chat/completions (the OpenAI endpoint)
// we will transform the payload and rewrite the path to /v1/complete.
"^/v1/chat/completions": "/v1/complete",
},
});
const queuedAnthropicProxy = createQueueMiddleware(anthropicProxy);
const anthropicRouter = Router();
anthropicRouter.use((req, _res, next) => {
if (!req.path.startsWith("/v1/")) {
req.url = `/v1${req.url}`;
}
next();
});
anthropicRouter.get("/v1/models", (req, res) => {
res.json(buildFakeModelsResponse());
});
anthropicRouter.post("/v1/complete", queuedAnthropicProxy);
// This is the OpenAI endpoint, to let users send OpenAI-formatted requests
// to the Anthropic API. We need to rewrite them first.
anthropicRouter.post("/v1/chat/completions", queuedAnthropicProxy);
// Redirect browser requests to the homepage.
anthropicRouter.get("*", (req, res, next) => {
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
if (isBrowser) {
res.redirect("/");
} else {
next();
}
});
function buildFakeModelsResponse() {
const claudeVariants = [
"claude-v1",
"claude-v1-100k",
"claude-instant-v1",
"claude-instant-v1-100k",
"claude-v1.3",
"claude-v1.3-100k",
"claude-v1.2",
"claude-v1.0",
"claude-instant-v1.1",
"claude-instant-v1.1-100k",
"claude-instant-v1.0",
];
const models = claudeVariants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "anthropic",
permission: [],
root: "claude",
parent: null,
}));
return {
models,
};
}
export const anthropic = anthropicRouter;

View File

@ -9,7 +9,6 @@ import { logger } from "../logger";
import { ipLimiter } from "./rate-limit";
import {
addKey,
checkStreaming,
finalizeBody,
languageFilter,
limitOutputTokens,
@ -41,11 +40,11 @@ const rewriteRequest = (
}
req.api = "kobold";
req.body.stream = false;
const rewriterPipeline = [
addKey,
transformKoboldPayload,
languageFilter,
checkStreaming,
limitOutputTokens,
finalizeBody,
];

View File

@ -1,45 +1,52 @@
import { Key, Model, keyPool, SUPPORTED_MODELS } from "../../../key-management";
import { Key, keyPool } from "../../../key-management";
import type { ExpressHttpProxyReqCallback } from ".";
/** Add an OpenAI key from the pool to the request. */
/** Add a key that can service this request to the request object. */
export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
let assignedKey: Key;
// Not all clients request a particular model.
// If they request a model, just use that.
// If they don't request a model, use a GPT-4 key if there is an active one,
// otherwise use a GPT-3.5 key.
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
// TODO: Anthropic mode should prioritize Claude over Claude Instant.
// Each provider needs to define some priority order for their models.
// This should happen somewhere else but addKey is guaranteed to run first.
req.isStreaming = req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;
if (bodyHasModel(req.body)) {
assignedKey = keyPool.get(req.body.model);
// Anthropic support has a special endpoint that accepts OpenAI-formatted
// requests and translates them into Anthropic requests. On this endpoint,
// the requested model is an OpenAI one even though we're actually sending
// an Anthropic request.
// For such cases, ignore the requested model entirely.
// Real Anthropic requests come in via /proxy/anthropic/v1/complete
// The OpenAI-compatible endpoint is /proxy/anthropic/v1/chat/completions
const openaiCompatible =
req.originalUrl === "/proxy/anthropic/v1/chat/completions";
if (openaiCompatible) {
req.log.debug("Using an Anthropic key for an OpenAI-compatible request");
req.api = "openai";
// We don't assign the model here, that will happen when transforming the
// request body.
assignedKey = keyPool.get("claude-v1");
} else {
try {
assignedKey = keyPool.get("gpt-4");
} catch {
assignedKey = keyPool.get("gpt-3.5-turbo");
}
assignedKey = keyPool.get(req.body.model);
}
req.key = assignedKey;
req.log.info(
{
key: assignedKey.hash,
model: req.body?.model,
isGpt4: assignedKey.isGpt4,
fromApi: req.api,
toApi: assignedKey.service,
},
"Assigned key to request"
);
// TODO: Requests to Anthropic models use `X-API-Key`.
if (assignedKey.service === "anthropic") {
proxyReq.setHeader("X-API-Key", assignedKey.key);
} else {
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
}
};
function bodyHasModel(body: any): body is { model: Model } {
// Model names can have suffixes indicating the frozen release version but
// OpenAI and Anthropic will use the latest version if you omit the suffix.
const isSupportedModel = (model: string) =>
SUPPORTED_MODELS.some((supported) => model.startsWith(supported));
return typeof body?.model === "string" && isSupportedModel(body.model);
}

View File

@ -1,24 +0,0 @@
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
/**
* If a stream is requested, mark the request as such so the response middleware
* knows to use the alternate EventSource response handler.
* Kobold requests can't currently be streamed as they use a different event
* format than the OpenAI API and we need to rewrite the events as they come in,
* which I have not yet implemented.
*/
export const checkStreaming: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
const streamableApi = req.api !== "kobold";
if (isCompletionRequest(req) && req.body?.stream) {
if (!streamableApi) {
req.log.warn(
{ api: req.api, key: req.key?.hash },
`Streaming requested, but ${req.api} streaming is not supported.`
);
req.body.stream = false;
return;
}
req.body.stream = true;
req.isStreaming = true;
}
};

View File

@ -3,20 +3,23 @@ import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";
export { addKey } from "./add-key";
export { checkStreaming } from "./check-streaming";
export { finalizeBody } from "./finalize-body";
export { languageFilter } from "./language-filter";
export { limitCompletions } from "./limit-completions";
export { limitOutputTokens } from "./limit-output-tokens";
export { transformKoboldPayload } from "./transform-kobold-payload";
export { transformOutboundPayload } from "./transform-outbound-payload";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
/** Returns true if we're making a chat completion request. */
/** Returns true if we're making a request to a completion endpoint. */
export function isCompletionRequest(req: Request) {
return (
req.method === "POST" &&
req.path.startsWith(OPENAI_CHAT_COMPLETION_ENDPOINT)
[OPENAI_CHAT_COMPLETION_ENDPOINT, ANTHROPIC_COMPLETION_ENDPOINT].some(
(endpoint) => req.path.startsWith(endpoint)
)
);
}

View File

@ -1,6 +1,9 @@
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
/** Don't allow multiple completions to be requested to prevent abuse. */
/**
* Don't allow multiple completions to be requested to prevent abuse.
* OpenAI-only, Anthropic provides no such parameter.
**/
export const limitCompletions: ExpressHttpProxyReqCallback = (
_proxyReq,
req

View File

@ -1,29 +1,43 @@
import { Request } from "express";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
const MAX_TOKENS = config.maxOutputTokens;
/** Enforce a maximum number of tokens requested from OpenAI. */
/** Enforce a maximum number of tokens requested from the model. */
export const limitOutputTokens: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (isCompletionRequest(req) && req.body?.max_tokens) {
// convert bad or missing input to a MAX_TOKENS
if (typeof req.body.max_tokens !== "number") {
logger.warn(
`Invalid max_tokens value: ${req.body.max_tokens}. Using ${MAX_TOKENS}`
const requestedMaxTokens = getMaxTokensFromRequest(req);
let maxTokens = requestedMaxTokens;
if (typeof requestedMaxTokens !== "number") {
req.log.warn(
{ requestedMaxTokens, clampedMaxTokens: MAX_TOKENS },
"Invalid max tokens value. Using default value."
);
req.body.max_tokens = MAX_TOKENS;
maxTokens = MAX_TOKENS;
}
const originalTokens = req.body.max_tokens;
req.body.max_tokens = Math.min(req.body.max_tokens, MAX_TOKENS);
if (originalTokens !== req.body.max_tokens) {
logger.warn(
`Limiting max_tokens from ${originalTokens} to ${req.body.max_tokens}`
// TODO: this is not going to scale well, need to implement a better way
// of translating request parameters from one API to another.
maxTokens = Math.min(maxTokens, MAX_TOKENS);
if (req.key!.service === "openai") {
req.body.max_tokens = maxTokens;
} else if (req.key!.service === "anthropic") {
req.body.max_tokens_to_sample = maxTokens;
}
if (requestedMaxTokens !== maxTokens) {
req.log.warn(
`Limiting max tokens from ${requestedMaxTokens} to ${maxTokens}`
);
}
}
};
function getMaxTokensFromRequest(req: Request) {
return (req.body?.max_tokens || req.body?.max_tokens_to_sample) ?? MAX_TOKENS;
}

View File

@ -1,3 +1,8 @@
/**
* Transforms a KoboldAI payload into an OpenAI payload.
* @deprecated Kobold input format isn't supported anymore as all popular
* frontends support reverse proxies or changing their base URL.
*/
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
@ -63,6 +68,10 @@ export const transformKoboldPayload: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.api !== "kobold") {
throw new Error("transformKoboldPayload called for non-kobold request.");
}
const { body } = req;
const { prompt, max_length, rep_pen, top_p, temperature } = body;

View File

@ -0,0 +1,125 @@
import { Request } from "express";
import { z } from "zod";
import type { ExpressHttpProxyReqCallback } from ".";
// https://console.anthropic.com/docs/api/reference#-v1-complete
const AnthropicV1CompleteSchema = z.object({
model: z.string().regex(/^claude-/),
prompt: z.string(),
max_tokens_to_sample: z.number(),
stop_sequences: z.array(z.string()).optional(),
stream: z.boolean().optional().default(false),
temperature: z.number().optional().default(1),
top_k: z.number().optional().default(-1),
top_p: z.number().optional().default(-1),
metadata: z.any().optional(),
});
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatCompletionSchema = z.object({
model: z.string().regex(/^gpt/),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
name: z.string().optional(),
})
),
temperature: z.number().optional().default(1),
top_p: z.number().optional().default(1),
n: z.literal(1).optional(),
stream: z.boolean().optional().default(false),
stop: z.union([z.string(), z.array(z.string())]).optional(),
max_tokens: z.number().optional(),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
logit_bias: z.any().optional(),
user: z.string().optional(),
});
/** Transforms an incoming request body to one that matches the target API. */
export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.retryCount > 0) {
// We've already transformed the payload once, so don't do it again.
return;
}
const inboundService = req.api;
const outboundService = req.key!.service;
if (inboundService === outboundService) {
return;
}
// Not supported yet and unnecessary as everything supports OpenAI.
if (inboundService === "anthropic" && outboundService === "openai") {
throw new Error(
"Anthropic -> OpenAI request transformation not supported. Provide an OpenAI-compatible payload, or use the /claude endpoint."
);
}
if (inboundService === "openai" && outboundService === "anthropic") {
req.body = openaiToAnthropic(req.body, req);
return;
}
throw new Error(
`Unsupported transformation: ${inboundService} -> ${outboundService}`
);
};
function openaiToAnthropic(body: any, req: Request) {
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
// don't log the prompt
const { messages, ...params } = body;
req.log.error(
{ issues: result.error.issues, params },
"Invalid OpenAI-to-Anthropic request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt =
result.data.messages
.map((m) => {
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "Human";
}
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
m.content
}`;
})
.join("") + "\n\nAssistant: ";
// When translating from OpenAI to Anthropic, we obviously can't use the
// provided OpenAI model name as-is. We will instead select a Claude model,
// choosing either the 100k token model or the 9k token model depending on
// the length of the prompt. I'm not bringing in the full OpenAI tokenizer for
// this so we'll use Anthropic's guideline of ~28000 characters to about 8k
// tokens (https://console.anthropic.com/docs/prompt-design#prompt-length)
// as the cutoff, minus a little bit for safety.
// For smaller prompts we use 1.1 because it's less cucked.
// For big prompts (v1, auto-selects the latest model) is all we can use.
const model = prompt.length > 25000 ? "claude-v1-100k" : "claude-v1.1";
return {
...rest,
model,
prompt,
max_tokens_to_sample: rest.max_tokens,
stop_sequences: rest.stop,
};
}

View File

@ -1,6 +1,29 @@
import { Response } from "express";
import { Request, Response } from "express";
import * as http from "http";
import { RawResponseBodyHandler, decodeResponseBody } from ".";
import { buildFakeSseMessage } from "../../queue";
type OpenAiChatCompletionResponse = {
id: string;
object: string;
created: number;
model: string;
choices: {
message: { role: string; content: string };
finish_reason: string | null;
index: number;
}[];
};
type AnthropicCompletionResponse = {
completion: string;
stop_reason: string;
truncated: boolean;
stop: any;
model: string;
log_id: string;
exception: null;
};
/**
* Consume the SSE stream and forward events to the client. Once the stream is
@ -11,18 +34,28 @@ import { RawResponseBodyHandler, decodeResponseBody } from ".";
* in the event a streamed request results in a non-200 response, we need to
* fall back to the non-streaming response handler so that the error handler
* can inspect the error response.
*
* Currently most frontends don't support Anthropic streaming, so users can opt
* to send requests for Claude models via an endpoint that accepts OpenAI-
* compatible requests and translates the received Anthropic SSE events into
* OpenAI ones, essentially pretending to be an OpenAI streaming API.
*/
export const handleStreamedResponse: RawResponseBodyHandler = async (
proxyRes,
req,
res
) => {
// If these differ, the user is using the OpenAI-compatibile endpoint, so
// we need to translate the SSE events into OpenAI completion events for their
// frontend.
const fromApi = req.api;
const toApi = req.key!.service;
if (!req.isStreaming) {
req.log.error(
{ api: req.api, key: req.key?.hash },
`handleEventSource called for non-streaming request, which isn't valid.`
`handleStreamedResponse called for non-streaming request, which isn't valid.`
);
throw new Error("handleEventSource called for non-streaming request.");
throw new Error("handleStreamedResponse called for non-streaming request.");
}
if (proxyRes.statusCode !== 200) {
@ -53,42 +86,81 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
res.flushHeaders();
}
const chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => {
chunks.push(chunk);
res.write(chunk);
});
const fullChunks: string[] = [];
let chunkBuffer: string[] = [];
let messageBuffer = "";
let lastPosition = 0;
proxyRes.on("end", () => {
const finalBody = convertEventsToOpenAiResponse(chunks);
type ProxyResHandler<T extends unknown> = (...args: T[]) => void;
function withErrorHandling<T extends unknown>(fn: ProxyResHandler<T>) {
return (...args: T[]) => {
try {
fn(...args);
} catch (error) {
proxyRes.emit("error", error);
}
};
}
proxyRes.on(
"data",
withErrorHandling((chunk) => {
// We may receive multiple (or partial) SSE messages in a single chunk, so
// we need to buffer and emit seperate stream events for full messages so
// we can parse/transform them properly.
const str = chunk.toString();
chunkBuffer.push(str);
const newMessages = (messageBuffer + chunkBuffer.join("")).split(
/\r?\n\r?\n/ // Anthropic uses CRLF line endings (out-of-spec btw)
);
chunkBuffer = [];
messageBuffer = newMessages.pop() || "";
for (const message of newMessages) {
proxyRes.emit("full-sse-event", message);
}
})
);
proxyRes.on(
"full-sse-event",
withErrorHandling((data) => {
const { event, position } = transformEvent(
data,
fromApi,
toApi,
lastPosition
);
fullChunks.push(event);
lastPosition = position;
res.write(event + "\n\n");
})
);
proxyRes.on(
"end",
withErrorHandling(() => {
let finalBody = convertEventsToFinalResponse(fullChunks, req);
req.log.info(
{ api: req.api, key: req.key?.hash },
`Finished proxying SSE stream.`
);
res.end();
resolve(finalBody);
});
})
);
proxyRes.on("error", (err) => {
req.log.error(
{ error: err, api: req.api, key: req.key?.hash },
`Error while streaming response.`
);
// OAI's spec doesn't allow for error events and clients wouldn't know
// what to do with them anyway, so we'll just send a completion event
// with the error message.
const fakeErrorEvent = {
id: "chatcmpl-error",
object: "chat.completion.chunk",
created: Date.now(),
model: "",
choices: [
{
delta: { content: "[Proxy streaming error: " + err.message + "]" },
index: 0,
finish_reason: "error",
},
],
};
const fakeErrorEvent = buildFakeSseMessage(
"mid-stream-error",
err.message,
req
);
res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`);
res.write("data: [DONE]\n\n");
res.end();
@ -97,8 +169,57 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
});
};
/**
* Transforms SSE events from the given response API into events compatible with
* the API requested by the client.
*/
function transformEvent(
data: string,
requestApi: string,
responseApi: string,
lastPosition: number
) {
if (requestApi === responseApi) {
return { position: -1, event: data };
}
if (requestApi === "anthropic" && responseApi === "openai") {
throw new Error(`Anthropic -> OpenAI streaming not implemented.`);
}
// Anthropic sends the full completion so far with each event whereas OpenAI
// only sends the delta. To make the SSE events compatible, we remove
// everything before `lastPosition` from the completion.
if (!data.startsWith("data:")) {
return { position: lastPosition, event: data };
}
if (data.startsWith("data: [DONE]")) {
return { position: lastPosition, event: data };
}
const event = JSON.parse(data.slice("data: ".length));
const newEvent = {
id: "ant-" + event.log_id,
object: "chat.completion.chunk",
created: Date.now(),
model: event.model,
choices: [
{
index: 0,
delta: { content: event.completion?.slice(lastPosition) },
finish_reason: event.stop_reason,
},
],
};
return {
position: event.completion.length,
event: `data: ${JSON.stringify(newEvent)}`,
};
}
/** Copy headers, excluding ones we're already setting for the SSE response. */
const copyHeaders = (proxyRes: http.IncomingMessage, res: Response) => {
function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
const toOmit = [
"content-length",
"content-encoding",
@ -112,22 +233,10 @@ const copyHeaders = (proxyRes: http.IncomingMessage, res: Response) => {
res.setHeader(key, value);
}
}
};
}
type OpenAiChatCompletionResponse = {
id: string;
object: string;
created: number;
model: string;
choices: {
message: { role: string; content: string };
finish_reason: string | null;
index: number;
}[];
};
/** Converts the event stream chunks into a single completion response. */
const convertEventsToOpenAiResponse = (chunks: Buffer[]) => {
function convertEventsToFinalResponse(events: string[], req: Request) {
if (req.key!.service === "openai") {
let response: OpenAiChatCompletionResponse = {
id: "",
object: "",
@ -135,22 +244,16 @@ const convertEventsToOpenAiResponse = (chunks: Buffer[]) => {
model: "",
choices: [],
};
const events = Buffer.concat(chunks)
.toString()
.trim()
.split("\n\n")
.map((line) => line.trim());
response = events.reduce((acc, chunk, i) => {
if (!chunk.startsWith("data: ")) {
response = events.reduce((acc, event, i) => {
if (!event.startsWith("data: ")) {
return acc;
}
if (chunk === "data: [DONE]") {
if (event === "data: [DONE]") {
return acc;
}
const data = JSON.parse(chunk.slice("data: ".length));
const data = JSON.parse(event.slice("data: ".length));
if (i === 0) {
return {
id: data.id,
@ -174,4 +277,19 @@ const convertEventsToOpenAiResponse = (chunks: Buffer[]) => {
return acc;
}, response);
return response;
};
}
if (req.key!.service === "anthropic") {
/*
* Full complete responses from Anthropic are conveniently just the same as
* the final SSE event before the "DONE" event, so we can reuse that
*/
const lastEvent = events[events.length - 2].toString();
const data = JSON.parse(lastEvent.slice("data: ".length));
const response: AnthropicCompletionResponse = {
...data,
log_id: req.id,
};
return response;
}
throw new Error("If you get this, something is fucked");
}

View File

@ -1,17 +1,19 @@
/* This file is fucking horrendous, sorry */
import { Request, Response } from "express";
import * as http from "http";
import * as httpProxy from "http-proxy";
import util from "util";
import zlib from "zlib";
import * as httpProxy from "http-proxy";
import { ZodError } from "zod";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { keyPool } from "../../../key-management";
import { incrementPromptCount } from "../../auth/user-store";
import { buildFakeSseMessage, enqueue, trackWaitTime } from "../../queue";
import { isCompletionRequest } from "../request";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
import { incrementPromptCount } from "../../auth/user-store";
export const QUOTA_ROUTES = ["/v1/chat/completions"];
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
deflate: util.promisify(zlib.inflate),
@ -174,7 +176,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
} else {
const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
logger.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
writeErrorResponse(res, 500, {
writeErrorResponse(req, res, 500, {
error: errorMessage,
contentEncoding,
});
@ -191,7 +193,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
} catch (error: any) {
const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
logger.warn({ error, key: req.key?.hash }, errorMessage);
writeErrorResponse(res, 500, { error: errorMessage });
writeErrorResponse(req, res, 500, { error: errorMessage });
return reject(errorMessage);
}
});
@ -199,8 +201,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
return promise;
};
// TODO: This is too specific to OpenAI's error responses, Anthropic errors
// will need a different handler.
// TODO: This is too specific to OpenAI's error responses.
/**
* Handles non-2xx responses from the upstream service. If the proxied response
* is an error, this will respond to the client with an error payload and throw
@ -237,7 +238,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
}
} catch (parseError: any) {
const statusMessage = proxyRes.statusMessage || "Unknown error";
// Likely Bad Gateway or Gateway Timeout from OpenAI's Cloudflare proxy
// Likely Bad Gateway or Gateway Timeout from reverse proxy/load balancer
logger.warn(
{ statusCode, statusMessage, key: req.key?.hash },
parseError.message
@ -249,7 +250,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
error: parseError.message,
proxy_note: `This is likely a temporary error with the upstream service.`,
};
writeErrorResponse(res, statusCode, errorObject);
writeErrorResponse(req, res, statusCode, errorObject);
throw new Error(parseError.message);
}
@ -265,12 +266,71 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
if (statusCode === 400) {
// Bad request (likely prompt is too long)
errorPayload.proxy_note = `OpenAI rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
} else if (statusCode === 401) {
// Key is invalid or was revoked
keyPool.disable(req.key!);
errorPayload.proxy_note = `The OpenAI key is invalid or revoked. ${tryAgainMessage}`;
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
} else if (statusCode === 429) {
// OpenAI uses this for a bunch of different rate-limiting scenarios.
if (req.key!.service === "openai") {
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
} else {
handleAnthropicRateLimitError(req, errorPayload);
}
} else if (statusCode === 404) {
// Most likely model not found
if (req.key!.service === "openai") {
// TODO: this probably doesn't handle GPT-4-32k variants properly if the
// proxy has keys for both the 8k and 32k context models at the same time.
if (errorPayload.error?.code === "model_not_found") {
if (req.key!.isGpt4) {
errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`;
} else {
errorPayload.proxy_note = `No model was found for this key.`;
}
}
} else if (req.key!.service === "anthropic") {
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
}
} else {
errorPayload.proxy_note = `Unrecognized error from upstream service.`;
}
// Some OAI errors contain the organization ID, which we don't want to reveal.
if (errorPayload.error?.message) {
errorPayload.error.message = errorPayload.error.message.replace(
/org-.{24}/gm,
"org-xxxxxxxxxxxxxxxxxxx"
);
}
writeErrorResponse(req, res, statusCode, errorPayload);
throw new Error(errorPayload.error?.message);
};
function handleAnthropicRateLimitError(
req: Request,
errorPayload: Record<string, any>
) {
//{"error":{"type":"rate_limit_error","message":"Number of concurrent connections to Claude exceeds your rate limit. Please try again, or contact sales@anthropic.com to discuss your options for a rate limit increase."}}
if (errorPayload.error?.type === "rate_limit_error") {
keyPool.markRateLimited(req.key!);
if (config.queueMode !== "none") {
reenqueueRequest(req);
throw new RetryableError("Claude rate-limited request re-enqueued.");
}
errorPayload.proxy_note = `There are too many in-flight requests for this key. Try again later.`;
} else {
errorPayload.proxy_note = `Unrecognized rate limit error from Anthropic. Key may be over quota.`;
}
}
function handleOpenAIRateLimitError(
req: Request,
tryAgainMessage: string,
errorPayload: Record<string, any>
): Record<string, any> {
const type = errorPayload.error?.type;
if (type === "insufficient_quota") {
// Billing quota exceeded (key is dead, disable it)
@ -282,10 +342,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`;
} else if (type === "requests" || type === "tokens") {
// Per-minute request or token rate limit is exceeded, which we can retry
keyPool.markRateLimited(req.key!.hash);
keyPool.markRateLimited(req.key!);
if (config.queueMode !== "none") {
reenqueueRequest(req);
// TODO: I don't like using an error to control flow here
// This is confusing, but it will bubble up to the top-level response
// handler and cause the request to go back into the request queue.
throw new RetryableError("Rate-limited request re-enqueued.");
}
errorPayload.proxy_note = `Assigned key's '${type}' rate limit has been exceeded. Try again later.`;
@ -293,38 +354,19 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
// OpenAI probably overloaded
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
}
} else if (statusCode === 404) {
// Most likely model not found
// TODO: this probably doesn't handle GPT-4-32k variants properly if the
// proxy has keys for both the 8k and 32k context models at the same time.
if (errorPayload.error?.code === "model_not_found") {
if (req.key!.isGpt4) {
errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`;
} else {
errorPayload.proxy_note = `No model was found for this key.`;
}
}
} else {
errorPayload.proxy_note = `Unrecognized error from OpenAI.`;
}
// Some OAI errors contain the organization ID, which we don't want to reveal.
if (errorPayload.error?.message) {
errorPayload.error.message = errorPayload.error.message.replace(
/org-.{24}/gm,
"org-xxxxxxxxxxxxxxxxxxx"
);
}
writeErrorResponse(res, statusCode, errorPayload);
throw new Error(errorPayload.error?.message);
};
return errorPayload;
}
function writeErrorResponse(
req: Request,
res: Response,
statusCode: number,
errorPayload: Record<string, any>
) {
const errorSource = errorPayload.error?.type.startsWith("proxy")
? "proxy"
: "upstream";
// If we're mid-SSE stream, send a data event with the error payload and end
// the stream. Otherwise just send a normal error response.
if (
@ -332,8 +374,9 @@ function writeErrorResponse(
res.getHeader("content-type") === "text/event-stream"
) {
const msg = buildFakeSseMessage(
`upstream error (${statusCode})`,
JSON.stringify(errorPayload, null, 2)
`${errorSource} error (${statusCode})`,
JSON.stringify(errorPayload, null, 2),
req
);
res.write(msg);
res.write(`data: [DONE]\n\n`);
@ -344,21 +387,31 @@ function writeErrorResponse(
}
/** Handles errors in rewriter pipelines. */
export const handleInternalError: httpProxy.ErrorCallback = (
err,
_req,
res
) => {
export const handleInternalError: httpProxy.ErrorCallback = (err, req, res) => {
logger.error({ error: err }, "Error in http-proxy-middleware pipeline.");
try {
writeErrorResponse(res as Response, 500, {
const isZod = err instanceof ZodError;
if (isZod) {
writeErrorResponse(req as Request, res as Response, 400, {
error: {
type: "proxy_error",
message: err.message,
type: "proxy_validation_error",
proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`,
issues: err.issues,
stack: err.stack,
proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
message: err.message,
},
});
} else {
writeErrorResponse(req as Request, res as Response, 500, {
error: {
type: "proxy_rewriter_error",
proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`,
message: err.message,
stack: err.stack,
},
});
}
} catch (e) {
logger.error(
{ error: e },
@ -368,8 +421,8 @@ export const handleInternalError: httpProxy.ErrorCallback = (
};
const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (QUOTA_ROUTES.includes(req.path)) {
keyPool.incrementPrompt(req.key?.hash);
if (isCompletionRequest(req)) {
keyPool.incrementPrompt(req.key!);
if (req.user) {
incrementPromptCount(req.user.token);
}
@ -377,7 +430,7 @@ const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
};
const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
keyPool.updateRateLimits(req.key!.hash, proxyRes.headers);
keyPool.updateRateLimits(req.key!, proxyRes.headers);
};
const copyHttpHeaders: ProxyResHandlerWithBody = async (

View File

@ -1,4 +1,5 @@
import { config } from "../../../config";
import { AIService } from "../../../key-management";
import { logQueue } from "../../../prompt-logging";
import { isCompletionRequest } from "../request";
import { ProxyResHandlerWithBody } from ".";
@ -17,18 +18,16 @@ export const logPrompt: ProxyResHandlerWithBody = async (
throw new Error("Expected body to be an object");
}
// Only log prompts if we're making a request to a completion endpoint
if (!isCompletionRequest(req)) {
// Remove this once we're confident that we're not missing any prompts
req.log.info(
`Not logging prompt for ${req.path} because it's not a completion endpoint`
);
return;
}
const model = req.body.model;
const promptFlattened = flattenMessages(req.body.messages);
const response = getResponseForModel({ model, body: responseBody });
const response = getResponseForService({
service: req.key!.service,
body: responseBody,
});
logQueue.enqueue({
model,
@ -48,15 +47,14 @@ const flattenMessages = (messages: OaiMessage[]): string => {
return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
};
const getResponseForModel = ({
model,
const getResponseForService = ({
service,
body,
}: {
model: string;
service: AIService;
body: Record<string, any>;
}) => {
if (model.startsWith("claude")) {
// TODO: confirm if there is supposed to be a leading space
if (service === "anthropic") {
return body.completion.trim();
} else {
return body.choices[0].message.content;

View File

@ -8,10 +8,10 @@ import { ipLimiter } from "./rate-limit";
import {
addKey,
languageFilter,
checkStreaming,
finalizeBody,
limitOutputTokens,
limitCompletions,
transformOutboundPayload,
} from "./middleware/request";
import {
createOnProxyResHandler,
@ -28,9 +28,9 @@ const rewriteRequest = (
const rewriterPipeline = [
addKey,
languageFilter,
checkStreaming,
limitOutputTokens,
limitCompletions,
transformOutboundPayload,
finalizeBody,
];
@ -39,7 +39,7 @@ const rewriteRequest = (
rewriter(proxyReq, req, res, {});
}
} catch (error) {
logger.error(error, "Error while executing proxy rewriter");
req.log.error(error, "Error while executing proxy rewriter");
proxyReq.destroy(error as Error);
}
};
@ -98,7 +98,7 @@ openaiRouter.get("*", (req, res, next) => {
}
});
openaiRouter.use((req, res) => {
logger.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
req.log.warn(`Blocked openai proxy request: ${req.method} ${req.path}`);
res.status(404).json({ error: "Not found" });
});

View File

@ -17,7 +17,7 @@
import type { Handler, Request } from "express";
import { config, DequeueMode } from "../config";
import { keyPool } from "../key-management";
import { keyPool, SupportedModel } from "../key-management";
import { logger } from "../logger";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
@ -78,7 +78,7 @@ export function enqueue(req: Request) {
// If the request opted into streaming, we need to register a heartbeat
// handler to keep the connection alive while it waits in the queue. We
// deregister the handler when the request is dequeued.
if (req.body.stream) {
if (req.body.stream === "true" || req.body.stream === true) {
const res = req.res!;
if (!res.headersSent) {
initStreaming(req);
@ -91,7 +91,7 @@ export function enqueue(req: Request) {
const avgWait = Math.round(getEstimatedWaitTime() / 1000);
const currentDuration = Math.round((Date.now() - req.startTime) / 1000);
const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`;
req.res!.write(buildFakeSseMessage("heartbeat", debugMsg));
req.res!.write(buildFakeSseMessage("heartbeat", debugMsg, req));
}
}, 10000);
}
@ -118,12 +118,24 @@ export function enqueue(req: Request) {
}
}
export function dequeue(model: string): Request | undefined {
// TODO: This should be set by some middleware that checks the request body.
const modelQueue =
model === "gpt-4"
? queue.filter((req) => req.body.model?.startsWith("gpt-4"))
: queue.filter((req) => !req.body.model?.startsWith("gpt-4"));
export function dequeue(model: SupportedModel): Request | undefined {
const modelQueue = queue.filter((req) => {
const reqProvider = req.originalUrl.startsWith("/proxy/anthropic")
? "anthropic"
: "openai";
// This sucks, but the `req.body.model` on Anthropic requests via the
// OpenAI-compat endpoint isn't actually claude-*, it's a fake gpt value.
// TODO: refactor model/service detection
if (model.startsWith("claude")) {
return reqProvider === "anthropic";
}
if (model.startsWith("gpt-4")) {
return reqProvider === "openai" && req.body.model?.startsWith("gpt-4");
}
return reqProvider === "openai" && req.body.model?.startsWith("gpt-3");
});
if (modelQueue.length === 0) {
return undefined;
@ -172,6 +184,7 @@ function processQueue() {
// the others, because we only track one rate limit per key.
const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
const reqs: (Request | undefined)[] = [];
if (gpt4Lockout === 0) {
@ -180,6 +193,9 @@ function processQueue() {
if (turboLockout === 0) {
reqs.push(dequeue("gpt-3.5-turbo"));
}
if (claudeLockout === 0) {
reqs.push(dequeue("claude-v1"));
}
reqs.filter(Boolean).forEach((req) => {
if (req?.proceed) {
@ -266,7 +282,7 @@ export function createQueueMiddleware(proxyMiddleware: Handler): Handler {
type: "proxy_error",
message: err.message,
stack: err.stack,
proxy_note: `Only one request per IP can be queued at a time. If you don't have another request queued, your IP may be in use by another user.`,
proxy_note: `Only one request can be queued at a time. If you don't have another request queued, your IP or user token might be in use by another request.`,
});
}
};
@ -281,7 +297,11 @@ function killQueuedRequest(req: Request) {
try {
const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. The queue is currently ${queue.length} requests long.`;
if (res.headersSent) {
const fakeErrorEvent = buildFakeSseMessage("proxy queue error", message);
const fakeErrorEvent = buildFakeSseMessage(
"proxy queue error",
message,
req
);
res.write(fakeErrorEvent);
res.end();
} else {
@ -305,12 +325,29 @@ function initStreaming(req: Request) {
res.write(": joining queue\n\n");
}
export function buildFakeSseMessage(type: string, string: string) {
const fakeEvent = {
id: "chatcmpl-" + type,
export function buildFakeSseMessage(
type: string,
string: string,
req: Request
) {
let fakeEvent;
if (req.api === "anthropic") {
// data: {"completion": " Here is a paragraph of lorem ipsum text:\n\nLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor inc", "stop_reason": "max_tokens", "truncated": false, "stop": null, "model": "claude-instant-v1", "log_id": "???", "exception": null}
fakeEvent = {
completion: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`,
stop_reason: type,
truncated: false, // I've never seen this be true
stop: null,
model: req.body?.model,
log_id: "proxy-req-" + req.id,
};
} else {
fakeEvent = {
id: "chatcmpl-" + req.id,
object: "chat.completion.chunk",
created: Date.now(),
model: "",
model: req.body?.model,
choices: [
{
delta: { content: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n` },
@ -319,6 +356,7 @@ export function buildFakeSseMessage(type: string, string: string) {
},
],
};
}
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}

View File

@ -8,12 +8,14 @@ import * as express from "express";
import { gatekeeper } from "./auth/gatekeeper";
import { kobold } from "./kobold";
import { openai } from "./openai";
import { anthropic } from "./anthropic";
const router = express.Router();
router.use(gatekeeper);
router.use("/kobold", kobold);
router.use("/openai", openai);
router.use("/anthropic", anthropic);
// Each client handles the endpoints input by the user in slightly different
// ways, eg TavernAI ignores everything after the hostname in Kobold mode

View File

@ -1,11 +1,16 @@
import { Express } from "express-serve-static-core";
import { Key } from "../key-management/key-pool";
import { Key } from "../key-management/index";
import { User } from "../proxy/auth/user-store";
declare global {
namespace Express {
interface Request {
key?: Key;
/**
* Denotes the _inbound_ API format. This is used to determine how the
* user has submitted their request; the proxy will then translate the
* paramaters to the target API format, which is on `key.service`.
*/
api: "kobold" | "openai" | "anthropic";
user?: User;
isStreaming?: boolean;