From 6a908b09cbb2b4ccecd8fe905f863eba228018e5 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Thu, 12 Sep 2024 22:55:45 -0500 Subject: [PATCH] adds preliminary openai o1 support and some improvements to openai keychecker --- .env.example | 12 +- src/config.ts | 4 +- src/info-page.ts | 4 + src/proxy/azure.ts | 60 ++----- .../preprocessors/validate-context-size.ts | 4 + src/proxy/middleware/response/index.ts | 26 ++- src/proxy/openai-image.ts | 4 +- src/proxy/openai.ts | 157 +++++++----------- src/proxy/queue.ts | 4 +- src/shared/api-schemas/openai.ts | 7 + src/shared/key-management/azure/checker.ts | 1 + src/shared/key-management/azure/provider.ts | 4 + src/shared/key-management/openai/checker.ts | 43 ++--- src/shared/key-management/openai/provider.ts | 105 ++++-------- src/shared/key-management/prioritize-keys.ts | 4 +- src/shared/models.ts | 12 ++ src/shared/stats.ts | 12 ++ src/shared/tokenization/tokenizer.ts | 2 + 18 files changed, 197 insertions(+), 268 deletions(-) diff --git a/.env.example b/.env.example index 968ba6d..99666e4 100644 --- a/.env.example +++ b/.env.example @@ -41,13 +41,13 @@ NODE_ENV=production # Which model types users are allowed to access. # The following model families are recognized: -# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus -# | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small -# | mistral-medium | mistral-large | aws-claude | aws-claude-opus | gcp-claude -# | gcp-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k -# | azure-gpt4-turbo | azure-gpt4o | azure-dall-e +# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | o1 | dall-e | claude +# | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | +# | mistral-small | mistral-medium | mistral-large | aws-claude | +# | aws-claude-opus | gcp-claude | gcp-claude-opus | azure-turbo | azure-gpt4 +# | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-o1 | azure-dall-e -# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'. +# By default, all models are allowed except for dall-e and o1. # To allow DALL-E image generation, uncomment the line below and add 'dall-e' or # 'azure-dall-e' to the list of allowed model families. # ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,gcp-claude,gcp-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o diff --git a/src/config.ts b/src/config.ts index 2ae591e..e7651e5 100644 --- a/src/config.ts +++ b/src/config.ts @@ -790,5 +790,7 @@ function parseCsv(val: string): string[] { } function getDefaultModelFamilies(): ModelFamily[] { - return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[]; + return MODEL_FAMILIES.filter( + (f) => !f.includes("dall-e") && !f.includes("o1") + ) as ModelFamily[]; } diff --git a/src/info-page.ts b/src/info-page.ts index 580c4b6..8f82da4 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -17,6 +17,8 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { "gpt4-32k": "GPT-4 32k", "gpt4-turbo": "GPT-4 Turbo", gpt4o: "GPT-4o", + o1: "OpenAI o1", + "o1-mini": "OpenAI o1 mini", "dall-e": "DALL-E", claude: "Claude (Sonnet)", "claude-opus": "Claude (Opus)", @@ -40,6 +42,8 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { "azure-gpt4-32k": "Azure GPT-4 32k", "azure-gpt4-turbo": "Azure GPT-4 Turbo", "azure-gpt4o": "Azure GPT-4o", + "azure-o1": "Azure o1", + "azure-o1-mini": "Azure o1 mini", "azure-dall-e": "Azure DALL-E", }; diff --git a/src/proxy/azure.ts b/src/proxy/azure.ts index e8a6155..c7fe20d 100644 --- a/src/proxy/azure.ts +++ b/src/proxy/azure.ts @@ -1,14 +1,8 @@ import { RequestHandler, Router } from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; -import { keyPool } from "../shared/key-management"; -import { - AzureOpenAIModelFamily, - getAzureOpenAIModelFamily, - ModelFamily, -} from "../shared/models"; import { logger } from "../logger"; -import { KNOWN_OPENAI_MODELS } from "./openai"; +import { generateModelList } from "./openai"; import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; import { handleProxyError } from "./middleware/common"; @@ -26,48 +20,18 @@ import { let modelsCache: any = null; let modelsCacheTime = 0; -function getModelsResponse() { - if (new Date().getTime() - modelsCacheTime < 1000 * 60) { - return modelsCache; - } - - let available = new Set(); - for (const key of keyPool.list()) { - if (key.isDisabled || key.service !== "azure") continue; - key.modelFamilies.forEach((family) => - available.add(family as AzureOpenAIModelFamily) - ); - } - const allowed = new Set(config.allowedModelFamilies); - available = new Set([...available].filter((x) => allowed.has(x))); - - const models = KNOWN_OPENAI_MODELS.map((id) => ({ - id, - object: "model", - created: new Date().getTime(), - owned_by: "azure", - permission: [ - { - id: "modelperm-" + id, - object: "model_permission", - created: new Date().getTime(), - organization: "*", - group: null, - is_blocking: false, - }, - ], - root: id, - parent: null, - })).filter((model) => available.has(getAzureOpenAIModelFamily(model.id))); - - modelsCache = { object: "list", data: models }; - modelsCacheTime = new Date().getTime(); - - return modelsCache; -} - const handleModelRequest: RequestHandler = (_req, res) => { - res.status(200).json(getModelsResponse()); + if (new Date().getTime() - modelsCacheTime < 1000 * 60) { + return res.status(200).json(modelsCache); + } + + if (!config.azureCredentials) return { object: "list", data: [] }; + + const result = generateModelList("azure"); + + modelsCache = { object: "list", data: result }; + modelsCacheTime = new Date().getTime(); + res.status(200).json(modelsCache); }; const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async ( diff --git a/src/proxy/middleware/request/preprocessors/validate-context-size.ts b/src/proxy/middleware/request/preprocessors/validate-context-size.ts index 9d7cab5..fa0c579 100644 --- a/src/proxy/middleware/request/preprocessors/validate-context-size.ts +++ b/src/proxy/middleware/request/preprocessors/validate-context-size.ts @@ -68,6 +68,10 @@ export const validateContextSize: RequestPreprocessor = async (req) => { modelMax = 131072; } else if (model.match(/^gpt-4(-\d{4})?-vision(-preview)?$/)) { modelMax = 131072; + } else if (model.match(/^o1-mini(-\d{4}-\d{2}-\d{2})?$/)) { + modelMax = 128000; + } else if (model.match(/^o1(-preview)?(-\d{4}-\d{2}-\d{2})?$/)) { + modelMax = 128000; } else if (model.match(/gpt-3.5-turbo/)) { modelMax = 16384; } else if (model.match(/gpt-4-32k/)) { diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index a6cfd62..c6a6144 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -212,8 +212,12 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( delete errorPayload.message; } else if (service === "gcp") { // Try to standardize the error format for GCP - if (errorPayload.error?.code) { // GCP Error - errorPayload.error = { message: errorPayload.error.message, type: errorPayload.error.status || errorPayload.error.code }; + if (errorPayload.error?.code) { + // GCP Error + errorPayload.error = { + message: errorPayload.error.message, + type: errorPayload.error.status || errorPayload.error.code, + }; } } @@ -231,7 +235,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( // same 429 billing error that other models return. await handleOpenAIRateLimitError(req, errorPayload); } else { - errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`; + errorPayload.proxy_note = `The upstream API rejected the request. Check the error message for details.`; } break; case "anthropic": @@ -293,8 +297,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( errorPayload.proxy_note = `Received 403 error. Key may be invalid.`; } return; - case "mistral-ai": - case "gcp": + case "mistral-ai": + case "gcp": keyPool.disable(req.key!, "revoked"); errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`; return; @@ -688,15 +692,23 @@ const countResponseTokens: ProxyResHandlerWithBody = async ( const completion = getCompletionFromBody(req, body); const tokens = await countTokens({ req, completion, service }); + if (req.service === "openai" || req.service === "azure") { + // O1 consumes (a significant amount of) invisible tokens for the chain- + // of-thought reasoning. We have no way to count these other than to check + // the response body. + tokens.reasoning_tokens = + body.usage?.completion_tokens_details?.reasoning_tokens; + } + req.log.debug( - { service, tokens, prevOutputTokens: req.outputTokens }, + { service, prevOutputTokens: req.outputTokens, tokens }, `Counted tokens for completion` ); if (req.tokenizerInfo) { req.tokenizerInfo.completion_tokens = tokens; } - req.outputTokens = tokens.token_count; + req.outputTokens = tokens.token_count + (tokens.reasoning_tokens ?? 0); } catch (error) { req.log.warn( error, diff --git a/src/proxy/openai-image.ts b/src/proxy/openai-image.ts index a50160e..a1fab53 100644 --- a/src/proxy/openai-image.ts +++ b/src/proxy/openai-image.ts @@ -26,7 +26,9 @@ const handleModelRequest: RequestHandler = (_req, res) => { if (new Date().getTime() - modelListValid < 1000 * 60) { return res.status(200).json(modelListCache); } - const result = generateModelList(KNOWN_MODELS); + const result = generateModelList().filter((m: { id: string }) => + KNOWN_MODELS.includes(m.id) + ); modelListCache = { object: "list", data: result }; modelListValid = new Date().getTime(); res.status(200).json(modelListCache); diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index cd80f13..8e5d630 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -1,12 +1,8 @@ -import { RequestHandler, Router } from "express"; +import { Request, RequestHandler, Router } from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; -import { keyPool, OpenAIKey } from "../shared/key-management"; -import { - getOpenAIModelFamily, - ModelFamily, - OpenAIModelFamily, -} from "../shared/models"; +import { AzureOpenAIKey, keyPool, OpenAIKey } from "../shared/key-management"; +import { getOpenAIModelFamily } from "../shared/models"; import { logger } from "../logger"; import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; @@ -27,103 +23,66 @@ import { } from "./middleware/response"; // https://platform.openai.com/docs/models/overview -export const KNOWN_OPENAI_MODELS = [ - // GPT4o - "gpt-4o", - "gpt-4o-2024-05-13", - "gpt-4o-2024-08-06", - // GPT4o Mini - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - // GPT4o (ChatGPT) - "chatgpt-4o-latest", - // GPT4 Turbo (superceded by GPT4o) - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", // gpt4-turbo stable, with vision - "gpt-4-turbo-preview", // alias for latest turbo preview - "gpt-4-0125-preview", // gpt4-turbo preview 2 - "gpt-4-1106-preview", // gpt4-turbo preview 1 - // Launch GPT4 - "gpt-4", - "gpt-4-0613", - "gpt-4-0314", // legacy - // GPT3.5 Turbo (superceded by GPT4o Mini) - "gpt-3.5-turbo", - "gpt-3.5-turbo-0125", // latest turbo - "gpt-3.5-turbo-1106", // older turbo - // Text Completion - "gpt-3.5-turbo-instruct", - "gpt-3.5-turbo-instruct-0914", - // Embeddings - "text-embedding-ada-002", - // Known deprecated models - "gpt-4-32k", // alias for 0613 - "gpt-4-32k-0314", // EOL 2025-06-06 - "gpt-4-32k-0613", // EOL 2025-06-06 - "gpt-4-vision-preview", // EOL 2024-12-06 - "gpt-4-1106-vision-preview", // EOL 2024-12-06 - "gpt-3.5-turbo-0613", // EOL 2024-09-13 - "gpt-3.5-turbo-0301", // not on the website anymore, maybe unavailable - "gpt-3.5-turbo-16k", // alias for 0613 - "gpt-3.5-turbo-16k-0613", // EOL 2024-09-13 -]; - let modelsCache: any = null; let modelsCacheTime = 0; -export function generateModelList(models = KNOWN_OPENAI_MODELS) { - // Get available families and snapshots - let availableFamilies = new Set(); - const availableSnapshots = new Set(); - for (const key of keyPool.list()) { - if (key.isDisabled || key.service !== "openai") continue; - const asOpenAIKey = key as OpenAIKey; - asOpenAIKey.modelFamilies.forEach((f) => availableFamilies.add(f)); - asOpenAIKey.modelSnapshots.forEach((s) => availableSnapshots.add(s)); - } +export function generateModelList(service: "openai" | "azure") { + const keys = keyPool + .list() + .filter((k) => k.service === service && !k.isDisabled) as + | OpenAIKey[] + | AzureOpenAIKey[]; + if (keys.length === 0) return []; - // Remove disabled families - const allowed = new Set(config.allowedModelFamilies); - availableFamilies = new Set( - [...availableFamilies].filter((x) => allowed.has(x)) + const allowedModelFamilies = new Set(config.allowedModelFamilies); + const modelFamilies = new Set( + keys + .flatMap((k) => k.modelFamilies) + .filter((f) => allowedModelFamilies.has(f)) ); - return models - .map((id) => ({ - id, - object: "model", - created: new Date().getTime(), - owned_by: "openai", - permission: [ - { - id: "modelperm-" + id, - object: "model_permission", - created: new Date().getTime(), - organization: "*", - group: null, - is_blocking: false, - }, - ], - root: id, - parent: null, - })) - .filter((model) => { - // First check if the family is available - const hasFamily = availableFamilies.has(getOpenAIModelFamily(model.id)); - if (!hasFamily) return false; + const modelIds = new Set( + keys + .flatMap((k) => k.modelIds) + .filter((id) => { + const allowed = modelFamilies.has(getOpenAIModelFamily(id)); + const known = ["gpt", "o1", "dall-e", "text-embedding-ada-002"].some( + (prefix) => id.startsWith(prefix) + ); + const isFinetune = id.includes("ft"); + return allowed && known && !isFinetune; + }) + ); - // Then for snapshots, ensure the specific snapshot is available - const isSnapshot = model.id.match(/-\d{4}(-preview)?$/); - if (!isSnapshot) return true; - return availableSnapshots.has(model.id); - }); + return Array.from(modelIds).map((id) => ({ + id, + object: "model", + created: new Date().getTime(), + owned_by: service, + permission: [ + { + id: "modelperm-" + id, + object: "model_permission", + created: new Date().getTime(), + organization: "*", + group: null, + is_blocking: false, + }, + ], + root: id, + parent: null, + })); } const handleModelRequest: RequestHandler = (_req, res) => { if (new Date().getTime() - modelsCacheTime < 1000 * 60) { return res.status(200).json(modelsCache); } - const result = generateModelList(); + + if (!config.openaiKey) return { object: "list", data: [] }; + + const result = generateModelList("openai"); + modelsCache = { object: "list", data: result }; modelsCacheTime = new Date().getTime(); res.status(200).json(modelsCache); @@ -242,11 +201,10 @@ openaiRouter.post( openaiRouter.post( "/v1/chat/completions", ipLimiter, - createPreprocessorMiddleware({ - inApi: "openai", - outApi: "openai", - service: "openai", - }), + createPreprocessorMiddleware( + { inApi: "openai", outApi: "openai", service: "openai" }, + { afterTransform: [fixupMaxTokens] } + ), openaiProxy ); // Embeddings endpoint. @@ -257,4 +215,11 @@ openaiRouter.post( openaiEmbeddingsProxy ); +function fixupMaxTokens(req: Request) { + if (!req.body.max_completion_tokens) { + req.body.max_completion_tokens = req.body.max_tokens; + } + delete req.body.max_tokens; +} + export const openai = openaiRouter; diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index 3fbcc29..127094d 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -35,14 +35,12 @@ const log = logger.child({ module: "request-queue" }); const USER_CONCURRENCY_LIMIT = parseInt( process.env.USER_CONCURRENCY_LIMIT ?? "1" ); -/** Maximum number of queue slots for Agnai.chat requests. */ -const AGNAI_CONCURRENCY_LIMIT = USER_CONCURRENCY_LIMIT * 5; const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512"); const MAX_HEARTBEAT_SIZE = 1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024"); const HEARTBEAT_INTERVAL = 1000 * parseInt(process.env.HEARTBEAT_INTERVAL_SEC ?? "5"); -const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "50"); +const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "150"); const PAYLOAD_SCALE_FACTOR = parseFloat( process.env.PAYLOAD_SCALE_FACTOR ?? "6" ); diff --git a/src/shared/api-schemas/openai.ts b/src/shared/api-schemas/openai.ts index 58b3ceb..d4be629 100644 --- a/src/shared/api-schemas/openai.ts +++ b/src/shared/api-schemas/openai.ts @@ -54,6 +54,13 @@ export const OpenAIV1ChatCompletionSchema = z .nullish() .default(Math.min(OPENAI_OUTPUT_MAX, 16384)) .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)), + // max_completion_tokens replaces max_tokens in the OpenAI API. + // for backwards compatibility, we accept both and move the value in + // max_tokens to max_completion_tokens in proxy middleware. + max_completion_tokens: z.coerce + .number() + .int() + .optional(), frequency_penalty: z.number().optional().default(0), presence_penalty: z.number().optional().default(0), logit_bias: z.any().optional(), diff --git a/src/shared/key-management/azure/checker.ts b/src/shared/key-management/azure/checker.ts index 6beeb14..68b5980 100644 --- a/src/shared/key-management/azure/checker.ts +++ b/src/shared/key-management/azure/checker.ts @@ -137,6 +137,7 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase { } const family = getAzureOpenAIModelFamily(data.model); + this.updateKey(key.hash, { modelIds: [data.model] }); // Azure returns "gpt-4" even for GPT-4 Turbo, so we need further checks. // Otherwise we can use the model family Azure returned. diff --git a/src/shared/key-management/azure/provider.ts b/src/shared/key-management/azure/provider.ts index 28439fa..f09ca5e 100644 --- a/src/shared/key-management/azure/provider.ts +++ b/src/shared/key-management/azure/provider.ts @@ -18,6 +18,7 @@ export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage { readonly service: "azure"; readonly modelFamilies: AzureOpenAIModelFamily[]; contentFiltering: boolean; + modelIds: string[]; } /** @@ -72,7 +73,10 @@ export class AzureOpenAIKeyProvider implements KeyProvider { "azure-gpt4-32kTokens": 0, "azure-gpt4-turboTokens": 0, "azure-gpt4oTokens": 0, + "azure-o1Tokens": 0, + "azure-o1-miniTokens": 0, "azure-dall-eTokens": 0, + modelIds: [], }; this.keys.push(newKey); } diff --git a/src/shared/key-management/openai/checker.ts b/src/shared/key-management/openai/checker.ts index 481d0c4..8dac7f9 100644 --- a/src/shared/key-management/openai/checker.ts +++ b/src/shared/key-management/openai/checker.ts @@ -63,7 +63,7 @@ export class OpenAIKeyChecker extends KeyCheckerBase { key: key.hash, models: key.modelFamilies, trial: key.isTrial, - snapshots: key.modelSnapshots, + snapshots: key.modelIds, }, "Checked key." ); @@ -74,10 +74,11 @@ export class OpenAIKeyChecker extends KeyCheckerBase { ): Promise { const opts = { headers: OpenAIKeyChecker.getHeaders(key) }; const { data } = await axios.get(GET_MODELS_URL, opts); + const ids = new Set(); const families = new Set(); - const models = data.data.map(({ id }) => { + data.data.forEach(({ id }) => { + ids.add(id); families.add(getOpenAIModelFamily(id, "turbo")); - return id; }); // disable dall-e for trial keys due to very low per-day quota that tends to @@ -86,36 +87,12 @@ export class OpenAIKeyChecker extends KeyCheckerBase { families.delete("dall-e"); } - // as of 2023-11-18, many keys no longer return the dalle3 model but still - // have access to it via the api for whatever reason. - // if (families.has("dall-e") && !models.find(({ id }) => id === "dall-e-3")) { - // families.delete("dall-e"); - // } - - // as of January 2024, 0314 model snapshots are only available on keys which - // have used them in the past. these keys also seem to have 32k-0314 even - // though they don't have the base gpt-4-32k model alias listed. if a key - // has access to both 0314 models we will flag it as such and force add - // gpt4-32k to its model families. - if ( - ["gpt-4-0314", "gpt-4-32k-0314"].every((m) => models.find((n) => n === m)) - ) { - this.log.info({ key: key.hash }, "Added gpt4-32k to -0314 key."); - families.add("gpt4-32k"); - } - - // We want to update the key's model families here, but we don't want to - // update its `lastChecked` timestamp because we need to let the liveness - // check run before we can consider the key checked. - - const familiesArray = [...families]; - const keyFromPool = this.keys.find((k) => k.hash === key.hash)!; this.updateKey(key.hash, { - modelSnapshots: models.filter((m) => m.match(/-\d{4}(-preview)?$/)), - modelFamilies: familiesArray, - lastChecked: keyFromPool.lastChecked, + modelIds: Array.from(ids), + modelFamilies: Array.from(families), }); - return familiesArray; + + return key.modelFamilies; } private async maybeCreateOrganizationClones(key: OpenAIKey) { @@ -333,9 +310,11 @@ export class OpenAIKeyChecker extends KeyCheckerBase { } static getHeaders(key: OpenAIKey) { + const useOrg = !key.key.includes("svcacct"); return { Authorization: `Bearer ${key.key}`, - ...(key.organizationId && { "OpenAI-Organization": key.organizationId }), + ...(useOrg && + key.organizationId && { "OpenAI-Organization": key.organizationId }), }; } } diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 528f029..27176b1 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -3,12 +3,11 @@ import http from "http"; import { Key, KeyProvider } from "../index"; import { config } from "../../../config"; import { logger } from "../../../logger"; -import { OpenAIKeyChecker } from "./checker"; import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models"; import { PaymentRequiredError } from "../../errors"; +import { OpenAIKeyChecker } from "./checker"; +import { prioritizeKeys } from "../prioritize-keys"; -// Flattening model families instead of using a nested object for easier -// cloning. type OpenAIKeyUsage = { [K in OpenAIModelFamily as `${K}Tokens`]: number; }; @@ -48,14 +47,10 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage { * tokens. */ rateLimitTokensReset: number; - /** - * This key's maximum request rate for GPT-4, per minute. - */ - gpt4Rpm: number; /** * Model snapshots available. */ - modelSnapshots: string[]; + modelIds: string[]; } export type OpenAIKeyUpdate = Omit< @@ -117,9 +112,10 @@ export class OpenAIKeyProvider implements KeyProvider { "gpt4-32kTokens": 0, "gpt4-turboTokens": 0, gpt4oTokens: 0, + "o1Tokens": 0, + "o1-miniTokens": 0, "dall-eTokens": 0, - gpt4Rpm: 0, - modelSnapshots: [], + modelIds: [], }; this.keys.push(newKey); } @@ -140,27 +136,14 @@ export class OpenAIKeyProvider implements KeyProvider { * Don't mutate returned keys, use a KeyPool method instead. **/ public list() { - return this.keys.map((key) => { - return Object.freeze({ - ...key, - key: undefined, - }); - }); + return this.keys.map((key) => Object.freeze({ ...key, key: undefined })); } public get(requestModel: string) { let model = requestModel; - // Special case for GPT-4-32k. Some keys have access to only gpt4-32k-0314 - // but not gpt-4-32k-0613, or its alias gpt-4-32k. Because we add a model - // family if a key has any snapshot, we need to dealias gpt-4-32k here so - // we can look for the specific snapshot. - // gpt-4-32k is superceded by gpt4-turbo so this shouldn't ever change. - if (model === "gpt-4-32k") model = "gpt-4-32k-0613"; - const neededFamily = getOpenAIModelFamily(model); const excludeTrials = model === "text-embedding-ada-002"; - const needsSnapshot = model.match(/-\d{4}(-preview)?$/); const availableKeys = this.keys.filter( // Allow keys which @@ -168,58 +151,22 @@ export class OpenAIKeyProvider implements KeyProvider { !key.isDisabled && // are not disabled key.modelFamilies.includes(neededFamily) && // have access to the model family we need (!excludeTrials || !key.isTrial) && // and are not trials if we don't want them - (!needsSnapshot || key.modelSnapshots.includes(model)) // and have the specific snapshot we need + (!config.checkKeys || key.modelIds.includes(model)) // and have the specific snapshot we need ); if (availableKeys.length === 0) { throw new PaymentRequiredError( - `No keys can fulfill request for ${model}` + `No OpenAI keys available for model ${model}` ); } - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. We ignore rate limits from >30 seconds ago - // b. If all keys were rate limited in the last minute, select the - // least recently rate limited key - // 2. Keys which are trials - // 3. Keys which do *not* have access to GPT-4-32k - // 4. Keys which have not been used in the longest time - - const now = Date.now(); - const rateLimitThreshold = 30 * 1000; - - const keysByPriority = availableKeys.sort((a, b) => { - // TODO: this isn't quite right; keys are briefly artificially rate- - // limited when they are selected, so this will deprioritize keys that - // may not actually be limited, simply because they were used recently. - // This should be adjusted to use a new `rateLimitedUntil` field instead - // of `rateLimitedAt`. - const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold; - const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - // Neither key is rate limited, continue - - if (a.isTrial && !b.isTrial) return -1; - if (!a.isTrial && b.isTrial) return 1; - // Neither or both keys are trials, continue - - const aHas32k = a.modelFamilies.includes("gpt4-32k"); - const bHas32k = b.modelFamilies.includes("gpt4-32k"); - if (aHas32k && !bHas32k) return 1; - if (!aHas32k && bHas32k) return -1; - // Neither or both keys have 32k, continue - - return a.lastUsed - b.lastUsed; - }); + const keysByPriority = prioritizeKeys( + availableKeys, + (a, b) => +a.isTrial - +b.isTrial + ); const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; } @@ -273,6 +220,9 @@ export class OpenAIKeyProvider implements KeyProvider { * the request, or returns 0 if a key is ready immediately. */ public getLockoutPeriod(family: OpenAIModelFamily): number { + // TODO: this is really inefficient on servers with large key pools and we + // are calling it every 50ms, per model family. + const activeKeys = this.keys.filter( (key) => !key.isDisabled && key.modelFamilies.includes(family) ); @@ -318,11 +268,15 @@ export class OpenAIKeyProvider implements KeyProvider { public markRateLimited(keyHash: string) { this.log.debug({ key: keyHash }, "Key rate limited"); const key = this.keys.find((k) => k.hash === keyHash)!; - key.rateLimitedAt = Date.now(); - // DALL-E requests do not send headers telling us when the rate limit will - // be reset so we need to set a fallback value here. Other models will have - // this overwritten by the `updateRateLimits` method. - key.rateLimitRequestsReset = 20000; + const now = Date.now(); + key.rateLimitedAt = now; + + // Most OpenAI reqeuests will provide a `x-ratelimit-reset-requests` header + // header telling us when to try again which will be set in a call to + // `updateRateLimits`. These values below are fallbacks in case the header + // is not provided. + key.rateLimitRequestsReset = 10000; + key.rateLimitedUntil = now + key.rateLimitRequestsReset; } public incrementUsage(keyHash: string, model: string, tokens: number) { @@ -349,6 +303,13 @@ export class OpenAIKeyProvider implements KeyProvider { this.log.warn({ key: key.hash }, `No ratelimit headers; skipping update`); return; } + + const { rateLimitedAt, rateLimitRequestsReset, rateLimitTokensReset } = key; + const rateLimitedUntil = + rateLimitedAt + Math.max(rateLimitRequestsReset, rateLimitTokensReset); + if (rateLimitedUntil > Date.now()) { + key.rateLimitedUntil = rateLimitedUntil; + } } public recheck() { diff --git a/src/shared/key-management/prioritize-keys.ts b/src/shared/key-management/prioritize-keys.ts index f29bdae..f00e4b2 100644 --- a/src/shared/key-management/prioritize-keys.ts +++ b/src/shared/key-management/prioritize-keys.ts @@ -5,9 +5,9 @@ import { Key } from "./index"; * lowest priority. Keys are prioritized in the following order: * * 1. Keys which are not rate limited - * a. If all keys were rate limited recently, select the least-recently + * - If all keys were rate limited recently, select the least-recently * rate limited key. - * b. Otherwise, select the first key. + * - Otherwise, select the first key. * 2. Keys which have not been used in the longest time * 3. Keys according to the custom comparator, if provided * @param keys The list of keys to sort diff --git a/src/shared/models.ts b/src/shared/models.ts index 874c107..2a013d6 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -22,6 +22,8 @@ export type OpenAIModelFamily = | "gpt4-32k" | "gpt4-turbo" | "gpt4o" + | "o1" + | "o1-mini" | "dall-e"; export type AnthropicModelFamily = "claude" | "claude-opus"; export type GoogleAIModelFamily = @@ -54,6 +56,8 @@ export const MODEL_FAMILIES = (( "gpt4-32k", "gpt4-turbo", "gpt4o", + "o1", + "o1-mini", "dall-e", "claude", "claude-opus", @@ -78,6 +82,8 @@ export const MODEL_FAMILIES = (( "azure-gpt4-turbo", "azure-gpt4o", "azure-dall-e", + "azure-o1", + "azure-o1-mini", ] as const); export const LLM_SERVICES = (( @@ -100,6 +106,8 @@ export const MODEL_FAMILY_SERVICE: { "gpt4-turbo": "openai", "gpt4-32k": "openai", gpt4o: "openai", + "o1": "openai", + "o1-mini": "openai", "dall-e": "openai", claude: "anthropic", "claude-opus": "anthropic", @@ -117,6 +125,8 @@ export const MODEL_FAMILY_SERVICE: { "azure-gpt4-turbo": "azure", "azure-gpt4o": "azure", "azure-dall-e": "azure", + "azure-o1": "azure", + "azure-o1-mini": "azure", "gemini-flash": "google-ai", "gemini-pro": "google-ai", "gemini-ultra": "google-ai", @@ -143,6 +153,8 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = { "^gpt-3.5-turbo": "turbo", "^text-embedding-ada-002$": "turbo", "^dall-e-\\d{1}$": "dall-e", + "^o1-mini(-\\d{4}-\\d{2}-\\d{2})?$": "o1-mini", + "^o1(-preview)?(-\\d{4}-\\d{2}-\\d{2})?$": "o1", }; export function getOpenAIModelFamily( diff --git a/src/shared/stats.ts b/src/shared/stats.ts index 9d97c36..e02f967 100644 --- a/src/shared/stats.ts +++ b/src/shared/stats.ts @@ -14,6 +14,18 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) { case "gpt4-turbo": cost = 0.00001; break; + case "azure-o1": + case "o1": + // Currently we do not track output tokens separately, and O1 uses + // considerably more output tokens that other models for its hidden + // reasoning. The official O1 pricing is $15/1M input tokens and $60/1M + // output tokens so we will return a higher estimate here. + cost = 0.00002; + break + case "azure-o1-mini": + case "o1-mini": + cost = 0.000005; // $3/1M input tokens, $12/1M output tokens + break case "azure-gpt4-32k": case "gpt4-32k": cost = 0.00006; diff --git a/src/shared/tokenization/tokenizer.ts b/src/shared/tokenization/tokenizer.ts index 864f0bb..65b9db8 100644 --- a/src/shared/tokenization/tokenizer.ts +++ b/src/shared/tokenization/tokenizer.ts @@ -86,6 +86,8 @@ type TokenCountRequest = { req: Request } & ( type TokenCountResult = { token_count: number; + /** Additional tokens for reasoning, if applicable. */ + reasoning_tokens?: number; tokenizer: string; tokenization_duration_ms: number; };