From 655703e68042013ee2597639580436ce53a53b59 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Sat, 16 Dec 2023 20:30:20 -0600 Subject: [PATCH] refactors infopage --- scripts/oai-reverse-proxy.http | 34 +- src/admin/routes.ts | 5 +- src/config.ts | 31 +- src/info-page.ts | 545 ++---------------- src/proxy/middleware/common.ts | 9 +- .../onproxyreq/block-zoomer-origins.ts | 6 +- .../request/onproxyreq/check-model-family.ts | 5 +- .../request/preprocessors/set-api-format.ts | 7 +- src/server.ts | 8 +- src/service-info.ts | 417 ++++++++++++++ src/{types => shared}/custom.d.ts | 8 +- src/shared/key-management/index.ts | 9 +- src/shared/key-management/key-pool.ts | 31 +- src/shared/models.ts | 35 +- src/shared/stats.ts | 6 + tsconfig.json | 2 +- 16 files changed, 584 insertions(+), 574 deletions(-) create mode 100644 src/service-info.ts rename src/{types => shared}/custom.d.ts (85%) diff --git a/scripts/oai-reverse-proxy.http b/scripts/oai-reverse-proxy.http index f3c8662..381fb2b 100644 --- a/scripts/oai-reverse-proxy.http +++ b/scripts/oai-reverse-proxy.http @@ -81,7 +81,7 @@ Authorization: Bearer {{proxy-key}} Content-Type: application/json { - "model": "gpt-3.5-turbo", + "model": "gpt-4-1106-preview", "max_tokens": 20, "stream": true, "temperature": 1, @@ -231,8 +231,36 @@ Content-Type: application/json } ### -# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation -POST {{proxy-host}}/proxy/google-palm/v1/chat/completions +# @name Proxy / Azure OpenAI -- Native Chat Completions +POST {{proxy-host}}/proxy/azure/openai/chat/completions +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "gpt-4", + "max_tokens": 20, + "stream": true, + "temperature": 1, + "seed": 2, + "messages": [ + { + "role": "user", + "content": "Hi what is the name of the fourth president of the united states?" + }, + { + "role": "assistant", + "content": "That would be George Washington." + }, + { + "role": "user", + "content": "That's not right." 
+    }
+  ]
+}
+
+###
+# @name Proxy / Google AI -- OpenAI-to-Google AI API Translation
+POST {{proxy-host}}/proxy/google-ai/v1/chat/completions
 Authorization: Bearer {{proxy-key}}
 Content-Type: application/json
 
diff --git a/src/admin/routes.ts b/src/admin/routes.ts
index 2ae71d0..fefc420 100644
--- a/src/admin/routes.ts
+++ b/src/admin/routes.ts
@@ -4,7 +4,8 @@ import { HttpError } from "../shared/errors";
 import { injectLocals } from "../shared/inject-locals";
 import { withSession } from "../shared/with-session";
 import { injectCsrfToken, checkCsrfToken } from "../shared/inject-csrf";
-import { buildInfoPageHtml } from "../info-page";
+import { renderPage } from "../info-page";
+import { buildInfo } from "../service-info";
 import { loginRouter } from "./login";
 import { usersApiRouter as apiRouter } from "./api/users";
 import { usersWebRouter as webRouter } from "./web/manage";
@@ -26,7 +27,7 @@ adminRouter.use("/", loginRouter);
 adminRouter.use("/manage", authorize({ via: "cookie" }), webRouter);
 adminRouter.use("/service-info", authorize({ via: "cookie" }), (req, res) => {
   return res.send(
-    buildInfoPageHtml(req.protocol + "://" + req.get("host"), true)
+    renderPage(buildInfo(req.protocol + "://" + req.get("host"), true))
   );
 });
diff --git a/src/config.ts b/src/config.ts
index 5a63700..b4f7f28 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -4,6 +4,7 @@ import path from "path";
 import pino from "pino";
 import type { ModelFamily } from "./shared/models";
 import { MODEL_FAMILIES } from "./shared/models";
+
 dotenv.config();
 
 const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
@@ -365,7 +366,7 @@ export const SENSITIVE_KEYS: (keyof Config)[] = ["googleSheetsSpreadsheetId"];
  * Config keys that are not displayed on the info page at all, generally because
  * they are not relevant to the user or can be inferred from other config.
  */
-export const OMITTED_KEYS: (keyof Config)[] = [
+export const OMITTED_KEYS = [
   "port",
   "logLevel",
   "openaiKey",
@@ -391,34 +392,46 @@ export const OMITTED_KEYS: (keyof Config)[] = [
   "staticServiceInfo",
   "checkKeys",
   "allowedModelFamilies",
-];
+] satisfies (keyof Config)[];
+type OmitKeys = (typeof OMITTED_KEYS)[number];
+
+type Printable<T> = {
+  [P in keyof T as Exclude<P, OmitKeys>]: T[P] extends object
+    ? Printable<T[P]>
+    : string;
+};
+type PublicConfig = Printable<Config>;
 
 const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;
 
-export function listConfig(obj: Config = config): Record<string, string> {
-  const result: Record<string, string> = {};
+export function listConfig(obj: Config = config) {
+  const result: Record<string, unknown> = {};
   for (const key of getKeys(obj)) {
     const value = obj[key]?.toString() || "";
 
-    const shouldOmit =
-      OMITTED_KEYS.includes(key) || value === "" || value === "undefined";
     const shouldMask = SENSITIVE_KEYS.includes(key);
+    const shouldOmit =
+      OMITTED_KEYS.includes(key as OmitKeys) ||
+      value === "" ||
+      value === "undefined";
 
     if (shouldOmit) {
       continue;
     }
 
+    const validKey = key as keyof Printable<Config>;
+
     if (value && shouldMask) {
-      result[key] = "********";
+      result[validKey] = "********";
     } else {
-      result[key] = value;
+      result[validKey] = value;
     }
 
     if (typeof obj[key] === "object" && !Array.isArray(obj[key])) {
       result[key] = listConfig(obj[key] as unknown as Config);
     }
   }
 
-  return result;
+  return result as PublicConfig;
 }
 
 /**
diff --git a/src/info-page.ts b/src/info-page.ts
index ed7b6c7..05799ea 100644
--- a/src/info-page.ts
+++ b/src/info-page.ts
@@ -1,74 +1,36 @@
-/** This whole module really sucks */
+/** This whole module kinda sucks */
 import fs from "fs";
 import { Request, Response } from "express";
 import showdown from "showdown";
-import { config, listConfig } from "./config";
-import {
-  AnthropicKey,
-  AwsBedrockKey,
-  AzureOpenAIKey,
-  GoogleAIKey,
-  keyPool,
-  OpenAIKey,
-} from "./shared/key-management";
-import {
-  AzureOpenAIModelFamily,
-  ModelFamily,
-  OpenAIModelFamily,
-} from "./shared/models";
-import { getUniqueIps } from "./proxy/rate-limit";
-import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
-import { getTokenCostUsd, prettyTokens } from "./shared/stats";
-import { assertNever } from "./shared/utils";
+import { config } from "./config";
+import { buildInfo, ServiceInfo } from "./service-info";
 import { getLastNImages } from "./shared/file-storage/image-history";
+import { keyPool } from "./shared/key-management";
+import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models";
 
 const INFO_PAGE_TTL = 2000;
+const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
+  "turbo": "GPT-3.5 Turbo",
+  "gpt4": "GPT-4",
+  "gpt4-32k": "GPT-4 32k",
+  "gpt4-turbo": "GPT-4 Turbo",
+  "dall-e": "DALL-E",
+  "claude": "Claude",
+  "gemini-pro": "Gemini Pro",
+  "aws-claude": "AWS Claude",
+  "azure-turbo": "Azure GPT-3.5 Turbo",
+  "azure-gpt4": "Azure GPT-4",
+  "azure-gpt4-32k": "Azure GPT-4 32k",
+  "azure-gpt4-turbo": "Azure GPT-4 Turbo",
+};
+
+const converter = new showdown.Converter();
+const customGreeting = fs.existsSync("greeting.md")
+  ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
+  : "";
 
 let infoPageHtml: string | undefined;
 let infoPageLastUpdated = 0;
 
-type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
-const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
-  k.service === "openai";
-const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
-  k.service === "azure";
-const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
-  k.service === "anthropic";
-const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
-  k.service === "google-ai";
-const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
-
-type ModelAggregates = {
-  active: number;
-  trial?: number;
-  revoked?: number;
-  overQuota?: number;
-  pozzed?: number;
-  awsLogged?: number;
-  queued: number;
-  queueTime: string;
-  tokens: number;
-};
-type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
-type ServiceAggregates = {
-  status?: string;
-  openaiKeys?: number;
-  openaiOrgs?: number;
-  anthropicKeys?: number;
-  googleAIKeys?: number;
-  awsKeys?: number;
-  azureKeys?: number;
-  proompts: number;
-  tokens: number;
-  tokenCost: number;
-  openAiUncheckedKeys?: number;
-  anthropicUncheckedKeys?: number;
-} & {
-  [modelFamily in ModelFamily]?: ModelAggregates;
-};
-
-const modelStats = new Map<ModelAggregateKey, number>();
-const serviceStats = new Map<keyof ServiceAggregates, number>();
-
 export const handleInfoPage = (req: Request, res: Response) => {
   if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
     return res.send(infoPageHtml);
   }
@@ -79,93 +41,16 @@ export const handleInfoPage = (req: Request, res: Response) => {
     ? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
     : req.protocol + "://" + req.get("host");
 
-  infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
+  const info = buildInfo(baseUrl + "/proxy");
+  infoPageHtml = renderPage(info);
   infoPageLastUpdated = Date.now();
 
   res.send(infoPageHtml);
 };
 
-function getCostString(cost: number) {
-  if (!config.showTokenCosts) return "";
-  return ` ($${cost.toFixed(2)})`;
-}
-
-export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
-  const keys = keyPool.list();
-  const hideFullInfo = config.staticServiceInfo && !asAdmin;
-
-  modelStats.clear();
-  serviceStats.clear();
-  keys.forEach(addKeyToAggregates);
-
-  const openaiKeys = serviceStats.get("openaiKeys") || 0;
-  const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
-  const googleAIKeys = serviceStats.get("googleAIKeys") || 0;
-  const awsKeys = serviceStats.get("awsKeys") || 0;
-  const azureKeys = serviceStats.get("azureKeys") || 0;
-  const proompts = serviceStats.get("proompts") || 0;
-  const tokens = serviceStats.get("tokens") || 0;
-  const tokenCost = serviceStats.get("tokenCost") || 0;
-
-  const allowDalle = config.allowedModelFamilies.includes("dall-e");
-
-  const endpoints = {
-    ...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
-    ...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
-    ...(openaiKeys && allowDalle
-      ? { ["openai-image"]: baseUrl + "/openai-image" }
-      : {}),
-    ...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
-    ...(googleAIKeys ? { "google-ai": baseUrl + "/google-ai" } : {}),
-    ...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
-    ...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
-  };
-
-  const stats = {
-    proompts,
-    tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
-    ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
-  };
-
-  const keyInfo = {
-    openaiKeys,
-    anthropicKeys,
-    googleAIKeys,
-    awsKeys,
-    azureKeys,
-  };
-  for (const key of Object.keys(keyInfo)) {
-    if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
-  }
-
-  const providerInfo = {
-    ...(openaiKeys ? getOpenAIInfo() : {}),
-    ...(anthropicKeys ? getAnthropicInfo() : {}),
-    ...(googleAIKeys ? getGoogleAIInfo() : {}),
-    ...(awsKeys ? getAwsInfo() : {}),
-    ...(azureKeys ? getAzureInfo() : {}),
-  };
-
-  if (hideFullInfo) {
-    for (const provider of Object.keys(providerInfo)) {
-      delete (providerInfo as any)[provider].proomptersInQueue;
-      delete (providerInfo as any)[provider].estimatedQueueTime;
-      delete (providerInfo as any)[provider].usage;
-    }
-  }
-
-  const info = {
-    uptime: Math.floor(process.uptime()),
-    endpoints,
-    ...(hideFullInfo ? {} : stats),
-    ...keyInfo,
-    ...providerInfo,
-    config: listConfig(),
-    build: process.env.BUILD_INFO || "dev",
-  };
-
+export function renderPage(info: ServiceInfo) {
   const title = getServerTitle();
-  const headerHtml = buildInfoPageHeader(new showdown.Converter(), title);
+  const headerHtml = buildInfoPageHeader(info);
 
   return `<!DOCTYPE html>
 <html lang="en">
   <head>
     <title>${title}</title>
   </head>
   <body style="font-family: sans-serif; background-color: #f0f0f0; padding: 1em;">
     ${headerHtml}
     <hr />
     <h2>Service Info</h2>
     <pre>${JSON.stringify(info, null, 2)}</pre>
     ${getSelfServiceLinks()}
   </body>
 </html>`;
 }
 
-function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
-  const orgIds = new Set(
-    keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
-  );
-  return orgIds.size;
-}
-
-function increment<T extends string>(
-  map: Map<T, number>,
-  key: T,
-  delta = 1
-) {
-  map.set(key, (map.get(key) || 0) + delta);
-}
-
-function addKeyToAggregates(k: KeyPoolKey) {
-  increment(serviceStats, "proompts", k.promptCount);
-  increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0);
-  increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
-  increment(serviceStats, "googleAIKeys", k.service === "google-ai" ? 1 : 0);
-  increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
-  increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
-
-  let sumTokens = 0;
-  let sumCost = 0;
-
-  switch (k.service) {
-    case "openai":
-      if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
-      increment(
-        serviceStats,
-        "openAiUncheckedKeys",
-        Boolean(k.lastChecked) ? 0 : 1
-      );
-
-      k.modelFamilies.forEach((f) => {
-        const tokens = k[`${f}Tokens`];
-        sumTokens += tokens;
-        sumCost += getTokenCostUsd(f, tokens);
-        increment(modelStats, `${f}__tokens`, tokens);
-        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
-        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
-        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
-        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
-      });
-      break;
-    case "azure":
-      if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
-      k.modelFamilies.forEach((f) => {
-        const tokens = k[`${f}Tokens`];
-        sumTokens += tokens;
-        sumCost += getTokenCostUsd(f, tokens);
-        increment(modelStats, `${f}__tokens`, tokens);
-        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
-        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
-      });
-      break;
-    case "anthropic": {
-      if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
-      const family = "claude";
-      sumTokens += k.claudeTokens;
-      sumCost += getTokenCostUsd(family, k.claudeTokens);
-      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
-      increment(modelStats, `${family}__revoked`, k.isRevoked ?
1 : 0); - increment(modelStats, `${family}__tokens`, k.claudeTokens); - increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0); - increment( - serviceStats, - "anthropicUncheckedKeys", - Boolean(k.lastChecked) ? 0 : 1 - ); - break; - } - case "google-ai": { - if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type"); - const family = "gemini-pro"; - sumTokens += k["gemini-proTokens"]; - sumCost += getTokenCostUsd(family, k["gemini-proTokens"]); - increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); - increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]); - break; - } - case "aws": { - if (!keyIsAwsKey(k)) throw new Error("Invalid key type"); - const family = "aws-claude"; - sumTokens += k["aws-claudeTokens"]; - sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]); - increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); - increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]); - - // Ignore revoked keys for aws logging stats, but include keys where the - // logging status is unknown. - const countAsLogged = - k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled"; - increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0); - - break; - } - default: - assertNever(k.service); - } - - increment(serviceStats, "tokens", sumTokens); - increment(serviceStats, "tokenCost", sumCost); -} - -function getOpenAIInfo() { - const info: { status?: string; openaiKeys?: number; openaiOrgs?: number } & { - [modelFamily in OpenAIModelFamily]?: { - usage?: string; - activeKeys: number; - trialKeys?: number; - revokedKeys?: number; - overQuotaKeys?: number; - proomptersInQueue?: number; - estimatedQueueTime?: string; - }; - } = {}; - - const keys = keyPool.list().filter(keyIsOpenAIKey); - const enabledFamilies = new Set(config.allowedModelFamilies); - const accessibleFamilies = keys - .flatMap((k) => k.modelFamilies) - .filter((f) => enabledFamilies.has(f)) - .concat("turbo"); - const familySet = new Set(accessibleFamilies); - - if (config.checkKeys) { - const unchecked = serviceStats.get("openAiUncheckedKeys") || 0; - if (unchecked > 0) { - info.status = `Checking ${unchecked} keys...`; - } - info.openaiKeys = keys.length; - info.openaiOrgs = getUniqueOpenAIOrgs(keys); - - familySet.forEach((f) => { - const tokens = modelStats.get(`${f}__tokens`) || 0; - const cost = getTokenCostUsd(f, tokens); - - info[f] = { - usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`, - activeKeys: modelStats.get(`${f}__active`) || 0, - trialKeys: modelStats.get(`${f}__trial`) || 0, - revokedKeys: modelStats.get(`${f}__revoked`) || 0, - overQuotaKeys: modelStats.get(`${f}__overQuota`) || 0, - }; - - // Don't show trial/revoked keys for non-turbo families. - // Generally those stats only make sense for the lowest-tier model. 
-      if (f !== "turbo") {
-        delete info[f]!.trialKeys;
-        delete info[f]!.revokedKeys;
-      }
-    });
-  } else {
-    info.status = "Key checking is disabled.";
-    info.turbo = { activeKeys: keys.filter((k) => !k.isDisabled).length };
-    info.gpt4 = {
-      activeKeys: keys.filter(
-        (k) => !k.isDisabled && k.modelFamilies.includes("gpt4")
-      ).length,
-    };
-  }
-
-  familySet.forEach((f) => {
-    if (enabledFamilies.has(f)) {
-      if (!info[f]) info[f] = { activeKeys: 0 }; // may occur if checkKeys is disabled
-      const { estimatedQueueTime, proomptersInQueue } = getQueueInformation(f);
-      info[f]!.proomptersInQueue = proomptersInQueue;
-      info[f]!.estimatedQueueTime = estimatedQueueTime;
-    } else {
-      (info[f]! as any).status = "GPT-3.5-Turbo is disabled on this proxy.";
-    }
-  });
-
-  return info;
-}
-
-function getAnthropicInfo() {
-  const claudeInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("claude__active") || 0,
-    pozzed: modelStats.get("claude__pozzed") || 0,
-    revoked: modelStats.get("claude__revoked") || 0,
-  };
-
-  const queue = getQueueInformation("claude");
-  claudeInfo.queued = queue.proomptersInQueue;
-  claudeInfo.queueTime = queue.estimatedQueueTime;
-
-  const tokens = modelStats.get("claude__tokens") || 0;
-  const cost = getTokenCostUsd("claude", tokens);
-
-  const unchecked =
-    (config.checkKeys && serviceStats.get("anthropicUncheckedKeys")) || 0;
-
-  return {
-    claude: {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      ...(unchecked > 0 ? { status: `Checking ${unchecked} keys...` } : {}),
-      activeKeys: claudeInfo.active,
-      revokedKeys: claudeInfo.revoked,
-      ...(config.checkKeys ? { pozzedKeys: claudeInfo.pozzed } : {}),
-      proomptersInQueue: claudeInfo.queued,
-      estimatedQueueTime: claudeInfo.queueTime,
-    },
-  };
-}
-
-function getGoogleAIInfo() {
-  const googleAIInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("gemini-pro__active") || 0,
-    revoked: modelStats.get("gemini-pro__revoked") || 0,
-  };
-
-  const queue = getQueueInformation("gemini-pro");
-  googleAIInfo.queued = queue.proomptersInQueue;
-  googleAIInfo.queueTime = queue.estimatedQueueTime;
-
-  const tokens = modelStats.get("gemini-pro__tokens") || 0;
-  const cost = getTokenCostUsd("gemini-pro", tokens);
-
-  return {
-    gemini: {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      activeKeys: googleAIInfo.active,
-      revokedKeys: googleAIInfo.revoked,
-      proomptersInQueue: googleAIInfo.queued,
-      estimatedQueueTime: googleAIInfo.queueTime,
-    },
-  };
-}
-
-function getAwsInfo() {
-  const awsInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("aws-claude__active") || 0,
-    revoked: modelStats.get("aws-claude__revoked") || 0,
-  };
-
-  const queue = getQueueInformation("aws-claude");
-  awsInfo.queued = queue.proomptersInQueue;
-  awsInfo.queueTime = queue.estimatedQueueTime;
-
-  const tokens = modelStats.get("aws-claude__tokens") || 0;
-  const cost = getTokenCostUsd("aws-claude", tokens);
-
-  const logged = modelStats.get("aws-claude__awsLogged") || 0;
-  const logMsg = config.allowAwsLogging
-    ? `${logged} active keys are potentially logged.`
-    : `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
-
-  return {
-    "aws-claude": {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      activeKeys: awsInfo.active,
-      revokedKeys: awsInfo.revoked,
-      proomptersInQueue: awsInfo.queued,
-      estimatedQueueTime: awsInfo.queueTime,
-      ...(logged > 0 ?
{ privacy: logMsg } : {}), - }, - }; -} - -function getAzureInfo() { - const azureFamilies = [ - "azure-turbo", - "azure-gpt4", - "azure-gpt4-turbo", - "azure-gpt4-32k", - ] as const; - - const azureInfo: { - [modelFamily in AzureOpenAIModelFamily]?: { - usage?: string; - activeKeys: number; - revokedKeys?: number; - proomptersInQueue?: number; - estimatedQueueTime?: string; - }; - } = {}; - for (const family of azureFamilies) { - const familyAllowed = config.allowedModelFamilies.includes(family); - const activeKeys = modelStats.get(`${family}__active`) || 0; - - if (!familyAllowed || activeKeys === 0) continue; - - azureInfo[family] = { - activeKeys, - revokedKeys: modelStats.get(`${family}__revoked`) || 0, - }; - - const queue = getQueueInformation(family); - azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue; - azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime; - - const tokens = modelStats.get(`${family}__tokens`) || 0; - const cost = getTokenCostUsd(family, tokens); - azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString( - cost - )}`; - } - - return azureInfo; -} - -const customGreeting = fs.existsSync("greeting.md") - ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}` - : ""; - /** * If the server operator provides a `greeting.md` file, it will be included in * the rendered info page. **/ -function buildInfoPageHeader(converter: showdown.Converter, title: string) { +function buildInfoPageHeader(info: ServiceInfo) { + const title = getServerTitle(); // TODO: use some templating engine instead of this mess - let infoBody = ` -# ${title}`; + let infoBody = `# ${title}`; if (config.promptLogging) { infoBody += `\n## Prompt Logging Enabled This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps. @@ -516,45 +91,18 @@ This proxy keeps full logs of all prompts and AI responses. 
Prompt logs are anon } const waits: string[] = []; - infoBody += `\n## Estimated Wait Times`; - if (config.openaiKey) { - // TODO: un-fuck this - const keys = keyPool.list().filter((k) => k.service === "openai"); + for (const modelFamily of config.allowedModelFamilies) { + const service = MODEL_FAMILY_SERVICE[modelFamily]; - const turboWait = getQueueInformation("turbo").estimatedQueueTime; - waits.push(`**Turbo:** ${turboWait}`); + const hasKeys = keyPool.list().some((k) => { + return k.service === service && k.modelFamilies.includes(modelFamily); + }); - const gpt4Wait = getQueueInformation("gpt4").estimatedQueueTime; - const hasGpt4 = keys.some((k) => k.modelFamilies.includes("gpt4")); - const allowedGpt4 = config.allowedModelFamilies.includes("gpt4"); - if (hasGpt4 && allowedGpt4) { - waits.push(`**GPT-4:** ${gpt4Wait}`); + const wait = info[modelFamily]?.estimatedQueueTime; + if (hasKeys && wait) { + waits.push(`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`); } - - const gpt432kWait = getQueueInformation("gpt4-32k").estimatedQueueTime; - const hasGpt432k = keys.some((k) => k.modelFamilies.includes("gpt4-32k")); - const allowedGpt432k = config.allowedModelFamilies.includes("gpt4-32k"); - if (hasGpt432k && allowedGpt432k) { - waits.push(`**GPT-4-32k:** ${gpt432kWait}`); - } - - const dalleWait = getQueueInformation("dall-e").estimatedQueueTime; - const hasDalle = keys.some((k) => k.modelFamilies.includes("dall-e")); - const allowedDalle = config.allowedModelFamilies.includes("dall-e"); - if (hasDalle && allowedDalle) { - waits.push(`**DALL-E:** ${dalleWait}`); - } - } - - if (config.anthropicKey) { - const claudeWait = getQueueInformation("claude").estimatedQueueTime; - waits.push(`**Claude:** ${claudeWait}`); - } - - if (config.awsCredentials) { - const awsClaudeWait = getQueueInformation("aws-claude").estimatedQueueTime; - waits.push(`**Claude (AWS):** ${awsClaudeWait}`); } infoBody += "\n\n" + waits.join(" / "); @@ -571,21 +119,6 @@ function getSelfServiceLinks() { return ``; } -/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */ -function getQueueInformation(partition: ModelFamily) { - const waitMs = getEstimatedWaitTime(partition); - const waitTime = - waitMs < 60000 - ? `${Math.round(waitMs / 1000)}sec` - : `${Math.round(waitMs / 60000)}min, ${Math.round( - (waitMs % 60000) / 1000 - )}sec`; - return { - proomptersInQueue: getQueueLength(partition), - estimatedQueueTime: waitMs > 2000 ? waitTime : "no wait", - }; -} - function getServerTitle() { // Use manually set title if available if (process.env.SERVER_TITLE) { diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index 9355ded..c677e46 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -129,7 +129,7 @@ function classifyError(err: Error): { userMessage, type: "proxy_validation_error", }; - case "ForbiddenError": + case "ZoomerForbiddenError": // Mimics a ban notice from OpenAI, thrown when blockZoomerOrigins blocks // a request. return { @@ -139,6 +139,13 @@ function classifyError(err: Error): { type: "organization_account_disabled", code: "policy_violation", }; + case "ForbiddenError": + return { + statusCode: 403, + statusMessage: "Forbidden", + userMessage: `Request is not allowed. 
(${err.message})`, + type: "proxy_forbidden", + }; case "QuotaExceededError": return { statusCode: 429, diff --git a/src/proxy/middleware/request/onproxyreq/block-zoomer-origins.ts b/src/proxy/middleware/request/onproxyreq/block-zoomer-origins.ts index 0c8360d..25ff37c 100644 --- a/src/proxy/middleware/request/onproxyreq/block-zoomer-origins.ts +++ b/src/proxy/middleware/request/onproxyreq/block-zoomer-origins.ts @@ -2,10 +2,10 @@ import { HPMRequestCallback } from "../index"; const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(","); -class ForbiddenError extends Error { +class ZoomerForbiddenError extends Error { constructor(message: string) { super(message); - this.name = "ForbiddenError"; + this.name = "ZoomerForbiddenError"; } } @@ -22,7 +22,7 @@ export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => { return; } - throw new ForbiddenError( + throw new ZoomerForbiddenError( `Your access was terminated due to violation of our policies, please check your email for more information. If you believe this is in error and would like to appeal, please contact us through our help center at help.openai.com.` ); } diff --git a/src/proxy/middleware/request/onproxyreq/check-model-family.ts b/src/proxy/middleware/request/onproxyreq/check-model-family.ts index 1460bee..e764be6 100644 --- a/src/proxy/middleware/request/onproxyreq/check-model-family.ts +++ b/src/proxy/middleware/request/onproxyreq/check-model-family.ts @@ -1,13 +1,14 @@ import { HPMRequestCallback } from "../index"; import { config } from "../../../../config"; +import { ForbiddenError } from "../../../../shared/errors"; import { getModelFamilyForRequest } from "../../../../shared/models"; /** * Ensures the selected model family is enabled by the proxy configuration. 
**/ -export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => { +export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => { const family = getModelFamilyForRequest(req); if (!config.allowedModelFamilies.includes(family)) { - throw new Error(`Model family ${family} is not permitted on this proxy`); + throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`); } }; diff --git a/src/proxy/middleware/request/preprocessors/set-api-format.ts b/src/proxy/middleware/request/preprocessors/set-api-format.ts index 5b2d2c7..11a0298 100644 --- a/src/proxy/middleware/request/preprocessors/set-api-format.ts +++ b/src/proxy/middleware/request/preprocessors/set-api-format.ts @@ -1,13 +1,14 @@ import { Request } from "express"; -import { APIFormat, LLMService } from "../../../../shared/key-management"; +import { APIFormat } from "../../../../shared/key-management"; +import { LLMService } from "../../../../shared/models"; import { RequestPreprocessor } from "../index"; export const setApiFormat = (api: { inApi: Request["inboundApi"]; outApi: APIFormat; - service: LLMService, + service: LLMService; }): RequestPreprocessor => { - return function configureRequestApiFormat (req) { + return function configureRequestApiFormat(req) { req.inboundApi = api.inApi; req.outboundApi = api.outApi; req.service = api.service; diff --git a/src/server.ts b/src/server.ts index 2b9e708..9f0161d 100644 --- a/src/server.ts +++ b/src/server.ts @@ -12,7 +12,8 @@ import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir"; import { keyPool } from "./shared/key-management"; import { adminRouter } from "./admin/routes"; import { proxyRouter } from "./proxy/routes"; -import { handleInfoPage } from "./info-page"; +import { handleInfoPage, renderPage } from "./info-page"; +import { buildInfo } from "./service-info"; import { logQueue } from "./shared/prompt-logging"; import { start as startRequestQueue } from "./proxy/queue"; import { init as initUserStore } from "./shared/users/user-store"; @@ -67,13 +68,14 @@ app.get("/health", (_req, res) => res.sendStatus(200)); app.use(cors()); app.use(checkOrigin); -// routes app.get("/", handleInfoPage); +app.get("/status", (req, res) => { + res.json(buildInfo(req.protocol + "://" + req.get("host"), false)); +}); app.use("/admin", adminRouter); app.use("/proxy", proxyRouter); app.use("/user", userRouter); -// 500 and 404 app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => { if (err.status) { res.status(err.status).json({ error: err.message }); diff --git a/src/service-info.ts b/src/service-info.ts new file mode 100644 index 0000000..add8a9d --- /dev/null +++ b/src/service-info.ts @@ -0,0 +1,417 @@ +/** Calculates and returns stats about the service. 
*/
+import { config, listConfig } from "./config";
+import {
+  AnthropicKey,
+  AwsBedrockKey,
+  AzureOpenAIKey,
+  GoogleAIKey,
+  keyPool,
+  OpenAIKey,
+} from "./shared/key-management";
+import {
+  AnthropicModelFamily,
+  assertIsKnownModelFamily,
+  AwsBedrockModelFamily,
+  AzureOpenAIModelFamily,
+  GoogleAIModelFamily,
+  LLM_SERVICES,
+  LLMService,
+  MODEL_FAMILY_SERVICE,
+  ModelFamily,
+  OpenAIModelFamily,
+} from "./shared/models";
+import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
+import { getUniqueIps } from "./proxy/rate-limit";
+import { assertNever } from "./shared/utils";
+import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
+
+const CACHE_TTL = 2000;
+
+type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
+const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
+  k.service === "openai";
+const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
+  k.service === "azure";
+const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
+  k.service === "anthropic";
+const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
+  k.service === "google-ai";
+const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
+
+/** Stats aggregated across all keys for a given service. */
+type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
+/** Stats aggregated across all keys for a given model family. */
+type ModelAggregates = {
+  active: number;
+  trial?: number;
+  revoked?: number;
+  overQuota?: number;
+  pozzed?: number;
+  awsLogged?: number;
+  queued: number;
+  queueTime: string;
+  tokens: number;
+};
+/** All possible combinations of model family and aggregate type. */
+type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
+
+type AllStats = {
+  proompts: number;
+  tokens: number;
+  tokenCost: number;
+} & { [modelFamily in ModelFamily]?: ModelAggregates } & {
+  [service in LLMService as `${service}__${ServiceAggregate}`]?: number;
+};
+
+type BaseFamilyInfo = {
+  usage?: string;
+  activeKeys: number;
+  revokedKeys?: number;
+  proomptersInQueue?: number;
+  estimatedQueueTime?: string;
+};
+type OpenAIInfo = BaseFamilyInfo & {
+  trialKeys?: number;
+  overQuotaKeys?: number;
+};
+type AnthropicInfo = BaseFamilyInfo & { pozzedKeys?: number };
+type AwsInfo = BaseFamilyInfo & { privacy?: string };
+
+// prettier-ignore
+export type ServiceInfo = {
+  uptime: number;
+  endpoints: {
+    openai?: string;
+    openai2?: string;
+    "openai-image"?: string;
+    anthropic?: string;
+    "google-ai"?: string;
+    aws?: string;
+    azure?: string;
+  };
+  proompts?: number;
+  tookens?: string;
+  proomptersNow?: number;
+  status?: string;
+  config: ReturnType<typeof listConfig>;
+  build: string;
+} & { [f in OpenAIModelFamily]?: OpenAIInfo }
+  & { [f in AnthropicModelFamily]?: AnthropicInfo; }
+  & { [f in AwsBedrockModelFamily]?: AwsInfo }
+  & { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
+  & { [f in GoogleAIModelFamily]?: BaseFamilyInfo };
+
+// https://stackoverflow.com/a/66661477
+// type DeepKeyOf<T> = (
+//   [T] extends [never]
+//     ? ""
+//     : T extends object
+//     ? {
+//         [K in Exclude<keyof T, symbol>]: `${K}${DotPrefix<DeepKeyOf<T[K]>>}`;
+//       }[Exclude<keyof T, symbol>]
+//     : ""
+// ) extends infer D
+//   ? Extract<D, string>
+//   : never;
+// type DotPrefix<T extends string> = T extends "" ? "" : `.${T}`;
+// type ServiceInfoPath = `{${DeepKeyOf<ServiceInfo>}}`;
+
+const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
+  openai: {
+    openai: `%BASE%/openai`,
+    openai2: `%BASE%/openai/turbo-instruct`,
+    "openai-image": `%BASE%/openai-image`,
+  },
+  anthropic: {
+    anthropic: `%BASE%/anthropic`,
+  },
+  "google-ai": {
+    "google-ai": `%BASE%/google-ai`,
+  },
+  aws: {
+    aws: `%BASE%/aws/claude`,
+  },
+  azure: {
+    azure: `%BASE%/azure/openai`,
+  },
+};
+
+const modelStats = new Map<ModelAggregateKey, number>();
+const serviceStats = new Map<keyof AllStats, number>();
+
+let cachedInfo: ServiceInfo | undefined;
+let cacheTime = 0;
+
+export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
+  if (cacheTime + CACHE_TTL > Date.now()) return cachedInfo!;
+
+  const keys = keyPool.list();
+  const accessibleFamilies = new Set(
+    keys
+      .flatMap((k) => k.modelFamilies)
+      .filter((f) => config.allowedModelFamilies.includes(f))
+      .concat("turbo")
+  );
+
+  serviceStats.clear();
+  keys.forEach(addKeyToAggregates);
+
+  const endpoints = getEndpoints(baseUrl, accessibleFamilies);
+  const trafficStats = getTrafficStats();
+  const { serviceInfo, modelFamilyInfo } =
+    getServiceModelStats(accessibleFamilies);
+  const status = getStatus();
+
+  if (config.staticServiceInfo && !forAdmin) {
+    delete trafficStats.proompts;
+    delete trafficStats.tookens;
+    delete trafficStats.proomptersNow;
+    for (const family of Object.keys(modelFamilyInfo)) {
+      assertIsKnownModelFamily(family);
+      delete modelFamilyInfo[family]?.proomptersInQueue;
+      delete modelFamilyInfo[family]?.estimatedQueueTime;
+      delete modelFamilyInfo[family]?.usage;
+    }
+  }
+
+  return (cachedInfo = {
+    uptime: Math.floor(process.uptime()),
+    endpoints,
+    ...trafficStats,
+    ...serviceInfo,
+    status,
+    ...modelFamilyInfo,
+    config: listConfig(),
+    build: process.env.BUILD_INFO || "dev",
+  });
+}
+
+function getStatus() {
+  if (!config.checkKeys) return "Key checking is disabled.";
+
+  let unchecked = 0;
+  for (const service of LLM_SERVICES) {
+    unchecked += serviceStats.get(`${service}__uncheckedKeys`) || 0;
+  }
+
+  return unchecked ? `Checking ${unchecked} keys...` : undefined;
+}
+
+function getEndpoints(baseUrl: string, accessibleFamilies: Set<ModelFamily>) {
+  const endpoints: Record<string, string> = {};
+  for (const service of LLM_SERVICES) {
+    for (const [name, url] of Object.entries(SERVICE_ENDPOINTS[service])) {
+      endpoints[name] = url.replace("%BASE%", baseUrl);
+    }
+
+    if (service === "openai" && !accessibleFamilies.has("dall-e")) {
+      delete endpoints["openai-image"];
+    }
+  }
+  return endpoints;
+}
+
+type TrafficStats = Pick<ServiceInfo, "proompts" | "tookens" | "proomptersNow">;
+
+function getTrafficStats(): TrafficStats {
+  const tokens = serviceStats.get("tokens") || 0;
+  const tokenCost = serviceStats.get("tokenCost") || 0;
+  return {
+    proompts: serviceStats.get("proompts") || 0,
+    tookens: `${prettyTokens(tokens)}${getCostSuffix(tokenCost)}`,
+    ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
+  };
+}
+
+function getServiceModelStats(accessibleFamilies: Set<ModelFamily>) {
+  const serviceInfo: {
+    [s in LLMService as `${s}${"Keys" | "Orgs"}`]?: number;
+  } = {};
+  const modelFamilyInfo: { [f in ModelFamily]?: BaseFamilyInfo } = {};
+
+  for (const service of LLM_SERVICES) {
+    const hasKeys = serviceStats.get(`${service}__keys`) || 0;
+    if (!hasKeys) continue;
+
+    serviceInfo[`${service}Keys`] = hasKeys;
+    accessibleFamilies.forEach((f) => {
+      if (MODEL_FAMILY_SERVICE[f] === service) {
+        modelFamilyInfo[f] = getInfoForFamily(f);
+      }
+    });
+
+    if (service === "openai" && config.checkKeys) {
+      serviceInfo.openaiOrgs = getUniqueOpenAIOrgs(keyPool.list());
+    }
+  }
+  return { serviceInfo, modelFamilyInfo };
+}
+
+function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
+  const orgIds = new Set(
+    keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
+  );
+  return orgIds.size;
+}
+
+function increment<T extends string>(
+  map: Map<T, number>,
+  key: T,
+  delta = 1
+) {
+  map.set(key, (map.get(key) || 0) + delta);
+}
+
+function addKeyToAggregates(k: KeyPoolKey) {
+  increment(serviceStats, "proompts", k.promptCount);
+  increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
+  increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
+  increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
+  increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
+  increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
+
+  let sumTokens = 0;
+  let sumCost = 0;
+
+  switch (k.service) {
+    case "openai":
+      if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
+      increment(
+        serviceStats,
+        "openai__uncheckedKeys",
+        Boolean(k.lastChecked) ? 0 : 1
+      );
+
+      k.modelFamilies.forEach((f) => {
+        const tokens = k[`${f}Tokens`];
+        sumTokens += tokens;
+        sumCost += getTokenCostUsd(f, tokens);
+        increment(modelStats, `${f}__tokens`, tokens);
+        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
+        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
+        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
+        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
+      });
+      break;
+    case "azure":
+      if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
+      k.modelFamilies.forEach((f) => {
+        const tokens = k[`${f}Tokens`];
+        sumTokens += tokens;
+        sumCost += getTokenCostUsd(f, tokens);
+        increment(modelStats, `${f}__tokens`, tokens);
+        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
+        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
+      });
+      break;
+    case "anthropic": {
+      if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
+      const family = "claude";
+      sumTokens += k.claudeTokens;
+      sumCost += getTokenCostUsd(family, k.claudeTokens);
+      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
+      increment(modelStats, `${family}__revoked`, k.isRevoked ?
1 : 0); + increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]); + break; + } + case "aws": { + if (!keyIsAwsKey(k)) throw new Error("Invalid key type"); + const family = "aws-claude"; + sumTokens += k["aws-claudeTokens"]; + sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]); + increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); + increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); + increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]); + + // Ignore revoked keys for aws logging stats, but include keys where the + // logging status is unknown. + const countAsLogged = + k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled"; + increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0); + + break; + } + default: + assertNever(k.service); + } + + increment(serviceStats, "tokens", sumTokens); + increment(serviceStats, "tokenCost", sumCost); +} + +function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { + const tokens = modelStats.get(`${family}__tokens`) || 0; + const cost = getTokenCostUsd(family, tokens); + let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = { + usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`, + activeKeys: modelStats.get(`${family}__active`) || 0, + revokedKeys: modelStats.get(`${family}__revoked`) || 0, + }; + + // Add service-specific stats to the info object. + if (config.checkKeys) { + const service = MODEL_FAMILY_SERVICE[family]; + switch (service) { + case "openai": + info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0; + info.trialKeys = modelStats.get(`${family}__trial`) || 0; + + // Delete trial/revoked keys for non-turbo families. + // Trials are turbo 99% of the time, and if a key is invalid we don't + // know what models it might have had assigned to it. + if (family !== "turbo") { + delete info.trialKeys; + delete info.revokedKeys; + } + break; + case "anthropic": + info.pozzedKeys = modelStats.get(`${family}__pozzed`) || 0; + break; + case "aws": + const logged = modelStats.get(`${family}__awsLogged`) || 0; + const logMsg = config.allowAwsLogging + ? `${logged} active keys are potentially logged.` + : `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`; + info.privacy = logMsg; + break; + } + } + + // Add queue stats to the info object. + const queue = getQueueInformation(family); + info.proomptersInQueue = queue.proomptersInQueue; + info.estimatedQueueTime = queue.estimatedQueueTime; + + return info; +} + +/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */ +function getQueueInformation(partition: ModelFamily) { + const waitMs = getEstimatedWaitTime(partition); + const waitTime = + waitMs < 60000 + ? `${Math.round(waitMs / 1000)}sec` + : `${Math.round(waitMs / 60000)}min, ${Math.round( + (waitMs % 60000) / 1000 + )}sec`; + return { + proomptersInQueue: getQueueLength(partition), + estimatedQueueTime: waitMs > 2000 ? 
waitTime : "no wait", + }; +} diff --git a/src/types/custom.d.ts b/src/shared/custom.d.ts similarity index 85% rename from src/types/custom.d.ts rename to src/shared/custom.d.ts index 6876ad7..1f98986 100644 --- a/src/types/custom.d.ts +++ b/src/shared/custom.d.ts @@ -1,8 +1,10 @@ +// noinspection JSUnusedGlobalSymbols,ES6UnusedImports + import type { HttpRequest } from "@smithy/types"; import { Express } from "express-serve-static-core"; -import { APIFormat, Key, LLMService } from "../shared/key-management"; -import { User } from "../shared/users/schema"; -import { ModelFamily } from "../shared/models"; +import { APIFormat, Key } from "./key-management"; +import { User } from "./users/schema"; +import { LLMService, ModelFamily } from "./models"; declare global { namespace Express { diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 2cdc373..c7d1141 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -1,10 +1,10 @@ +import type { LLMService, ModelFamily } from "../models"; import { OpenAIModel } from "./openai/provider"; import { AnthropicModel } from "./anthropic/provider"; import { GoogleAIModel } from "./google-ai/provider"; import { AwsBedrockModel } from "./aws/provider"; import { AzureOpenAIModel } from "./azure/provider"; import { KeyPool } from "./key-pool"; -import type { ModelFamily } from "../models"; /** The request and response format used by a model's API. */ export type APIFormat = @@ -13,13 +13,6 @@ export type APIFormat = | "google-ai" | "openai-text" | "openai-image"; -/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */ -export type LLMService = - | "openai" - | "anthropic" - | "google-ai" - | "aws" - | "azure"; export type Model = | OpenAIModel | AnthropicModel diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index 3f7d6c5..7726857 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -4,13 +4,12 @@ import os from "os"; import schedule from "node-schedule"; import { config } from "../../config"; import { logger } from "../../logger"; -import { Key, Model, KeyProvider, LLMService } from "./index"; +import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models"; +import { Key, Model, KeyProvider } from "./index"; import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; import { GoogleAIKeyProvider } from "./google-ai/provider"; import { AwsBedrockKeyProvider } from "./aws/provider"; -import { ModelFamily } from "../models"; -import { assertNever } from "../utils"; import { AzureOpenAIKeyProvider } from "./azure/provider"; type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate; @@ -82,7 +81,7 @@ export class KeyPool { } public getLockoutPeriod(family: ModelFamily): number { - const service = this.getServiceForModelFamily(family); + const service = MODEL_FAMILY_SERVICE[family]; return this.getKeyProvider(service).getLockoutPeriod(family); } @@ -132,30 +131,6 @@ export class KeyPool { throw new Error(`Unknown service for model '${model}'`); } - private getServiceForModelFamily(modelFamily: ModelFamily): LLMService { - switch (modelFamily) { - case "gpt4": - case "gpt4-32k": - case "gpt4-turbo": - case "turbo": - case "dall-e": - return "openai"; - case "claude": - return "anthropic"; - case "gemini-pro": 
-        return "google-ai";
-      case "aws-claude":
-        return "aws";
-      case "azure-turbo":
-      case "azure-gpt4":
-      case "azure-gpt4-32k":
-      case "azure-gpt4-turbo":
-        return "azure";
-      default:
-        assertNever(modelFamily);
-    }
-  }
-
   private getKeyProvider(service: LLMService): KeyProvider {
     return this.keyProviders.find((provider) => provider.service === service)!;
   }
diff --git a/src/shared/models.ts b/src/shared/models.ts
index a612848..75db61d 100644
--- a/src/shared/models.ts
+++ b/src/shared/models.ts
@@ -1,8 +1,14 @@
-// Don't import anything here, this is imported by config.ts
+// Don't import any other project files here as this is one of the first modules
+// loaded and it will cause circular imports.
 import pino from "pino";
 import type { Request } from "express";
-import { assertNever } from "./utils";
+
+/**
+ * The service that a model is hosted on. Distinct from `APIFormat` because some
+ * services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
+ */
+export type LLMService = "openai" | "anthropic" | "google-ai" | "aws" | "azure";
 
 export type OpenAIModelFamily =
   | "turbo"
@@ -41,6 +47,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   "azure-gpt4-turbo",
 ] as const);
 
+export const LLM_SERVICES = (<A extends readonly LLMService[]>(
+  arr: A & ([LLMService] extends [A[number]] ? unknown : never)
+) => arr)(["openai", "anthropic", "google-ai", "aws", "azure"] as const);
+
 export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
   "^gpt-4-1106(-preview)?$": "gpt4-turbo",
   "^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
@@ -53,6 +63,23 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
   "^dall-e-\\d{1}$": "dall-e",
 };
 
+export const MODEL_FAMILY_SERVICE: {
+  [f in ModelFamily]: LLMService;
+} = {
+  turbo: "openai",
+  gpt4: "openai",
+  "gpt4-turbo": "openai",
+  "gpt4-32k": "openai",
+  "dall-e": "openai",
+  claude: "anthropic",
+  "aws-claude": "aws",
+  "azure-turbo": "azure",
+  "azure-gpt4": "azure",
+  "azure-gpt4-32k": "azure",
+  "azure-gpt4-turbo": "azure",
+  "gemini-pro": "google-ai",
+};
+
 pino({ level: "debug" }).child({ module: "startup" });
 
 export function getOpenAIModelFamily(
@@ -138,3 +165,7 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
 
   return (req.modelFamily = modelFamily);
 }
+
+function assertNever(x: never): never {
+  throw new Error(`Called assertNever with argument ${x}.`);
+}
diff --git a/src/shared/stats.ts b/src/shared/stats.ts
index f0b05e5..eedf32d 100644
--- a/src/shared/stats.ts
+++ b/src/shared/stats.ts
@@ -1,3 +1,4 @@
+import { config } from "../config";
 import { ModelFamily } from "./models";
 
 // technically slightly underestimates, because completion tokens cost more
@@ -40,3 +41,8 @@ export function prettyTokens(tokens: number): string {
     return (tokens / 1000000000).toFixed(3) + "b";
   }
 }
+
+export function getCostSuffix(cost: number) {
+  if (!config.showTokenCosts) return "";
+  return ` ($${cost.toFixed(2)})`;
+}
diff --git a/tsconfig.json b/tsconfig.json
index 13a4926..a1762f4 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -15,5 +15,5 @@
   },
   "include": ["src"],
   "exclude": ["node_modules"],
-  "files": ["src/types/custom.d.ts"]
+  "files": ["src/shared/custom.d.ts"]
 }
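
Note (editor addition, not part of the patch): the src/server.ts hunk above also registers a public /status route that returns the new ServiceInfo object as JSON via buildInfo(). A minimal smoke test in the style of scripts/oai-reverse-proxy.http, reusing the {{proxy-host}} variable that file already defines:

###
# @name Proxy / Service Info -- aggregated service stats as JSON
GET {{proxy-host}}/status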