diff --git a/.env.example b/.env.example index a1e9cd2..d111869 100644 --- a/.env.example +++ b/.env.example @@ -34,10 +34,10 @@ # Which model types users are allowed to access. # The following model families are recognized: -# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude +# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo # By default, all models are allowed except for 'dall-e'. To allow DALL-E image # generation, uncomment the line below and add 'dall-e' to the list. -# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude +# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo # URLs from which requests will be blocked. # BLOCKED_ORIGINS=reddit.com,9gag.com @@ -114,6 +114,8 @@ OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # See `docs/aws-configuration.md` for more information, there may be additional steps required to set up AWS. AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2 +# See `docs/azure-configuration.md` for more information; there may be additional steps required to set up Azure. +AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key # With proxy_key gatekeeper, the password users must provide to access the API. # PROXY_KEY=your-secret-key diff --git a/.gitignore b/.gitignore index 32c1788..e12269b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ build greeting.md node_modules +http-client.private.env.json diff --git a/docs/azure-configuration.md b/docs/azure-configuration.md new file mode 100644 index 0000000..2ee172b --- /dev/null +++ b/docs/azure-configuration.md @@ -0,0 +1,25 @@ +# Configuring the proxy for Azure + +The proxy supports Azure OpenAI Service via the `/proxy/azure/openai` endpoint. Setting it up is slightly different from regular OpenAI. + +- [Setting keys](#setting-keys) +- [Model assignment](#model-assignment) + +## Setting keys + +Use the `AZURE_CREDENTIALS` environment variable to set the Azure API keys. + +As with other APIs, you can provide multiple credentials separated by commas. Each Azure credential, however, is a set of three values separated by colons (`:`): the resource name, the deployment ID, and the API key. + +For example: +``` +AZURE_CREDENTIALS=contoso-ml:gpt4-8k:0123456789abcdef0123456789abcdef,northwind-corp:testdeployment:0123456789abcdef0123456789abcdef +``` + +## Model assignment +Each Azure deployment is assigned a model when you create it in the Microsoft Cognitive Services portal. If you want to use a different model, you'll need to create a new deployment and add a new credential to the `AZURE_CREDENTIALS` environment variable; each credential grants access to exactly one model. + +### Supported model IDs +Users can send normal OpenAI model IDs to the proxy to invoke the corresponding models, and for the most part they work the same with Azure. The exception is GPT-3.5 Turbo, whose Azure ID is "gpt-35-turbo" because Azure doesn't allow periods in model names; the proxy automatically converts the standard ID to the Azure one. + +As noted above, you can only use model IDs for which a deployment has been created and added to the proxy.
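To make the credential format above concrete, here is a minimal TypeScript sketch (illustrative only, not part of this patch — the proxy performs the equivalent split in its own `getCredentialsFromKey` helper) showing how one `AZURE_CREDENTIALS` entry breaks down and which deployment-scoped upstream endpoint it maps to; the resource and deployment names are the placeholders from the example above:

```ts
// Hypothetical helper mirroring the colon-delimited format described above.
type AzureCredential = {
  resourceName: string;
  deploymentId: string;
  apiKey: string;
};

function parseAzureCredential(raw: string): AzureCredential {
  const [resourceName, deploymentId, apiKey] = raw.split(":");
  if (!resourceName || !deploymentId || !apiKey) {
    throw new Error("Expected RESOURCE_NAME:DEPLOYMENT_ID:API_KEY");
  }
  return { resourceName, deploymentId, apiKey };
}

const { resourceName, deploymentId } = parseAzureCredential(
  "contoso-ml:gpt4-8k:0123456789abcdef0123456789abcdef"
);
// The proxy ultimately calls this deployment-scoped endpoint:
console.log(
  `https://${resourceName}.openai.azure.com/openai/deployments/${deploymentId}/chat/completions`
);
```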
diff --git a/http-client.env.json b/http-client.env.json new file mode 100644 index 0000000..9586a3f --- /dev/null +++ b/http-client.env.json @@ -0,0 +1,10 @@ +{ + "dev": { + "proxy-host": "http://localhost:7860", + "oai-key-1": "override in http-client.private.env.json", + "proxy-key": "override in http-client.private.env.json", + "azu-resource-name": "override in http-client.private.env.json", + "azu-deployment-id": "override in http-client.private.env.json", + "azu-key-1": "override in http-client.private.env.json" + } +} diff --git a/scripts/oai-reverse-proxy.http b/scripts/oai-reverse-proxy.http new file mode 100644 index 0000000..f3c8662 --- /dev/null +++ b/scripts/oai-reverse-proxy.http @@ -0,0 +1,248 @@ +# OAI Reverse Proxy + +### +# @name OpenAI -- Chat Completions +POST https://api.openai.com/v1/chat/completions +Authorization: Bearer {{oai-key-1}} +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo", + "max_tokens": 30, + "stream": false, + "messages": [ + { + "role": "user", + "content": "This is a test prompt." + } + ] +} + +### +# @name OpenAI -- Text Completions +POST https://api.openai.com/v1/completions +Authorization: Bearer {{oai-key-1}} +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo-instruct", + "max_tokens": 30, + "stream": false, + "prompt": "This is a test prompt where" +} + +### +# @name OpenAI -- Create Embedding +POST https://api.openai.com/v1/embeddings +Authorization: Bearer {{oai-key-1}} +Content-Type: application/json + +{ + "model": "text-embedding-ada-002", + "input": "This is a test embedding input." +} + +### +# @name OpenAI -- Get Organizations +GET https://api.openai.com/v1/organizations +Authorization: Bearer {{oai-key-1}} + +### +# @name OpenAI -- Get Models +GET https://api.openai.com/v1/models +Authorization: Bearer {{oai-key-1}} + +### +# @name Azure OpenAI -- Chat Completions +POST https://{{azu-resource-name}}.openai.azure.com/openai/deployments/{{azu-deployment-id}}/chat/completions?api-version=2023-09-01-preview +api-key: {{azu-key-1}} +Content-Type: application/json + +{ + "max_tokens": 1, + "stream": false, + "messages": [ + { + "role": "user", + "content": "This is a test prompt." + } + ] +} + +### +# @name Proxy / OpenAI -- Get Models +GET {{proxy-host}}/proxy/openai/v1/models +Authorization: Bearer {{proxy-key}} + +### +# @name Proxy / OpenAI -- Native Chat Completions +POST {{proxy-host}}/proxy/openai/chat/completions +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo", + "max_tokens": 20, + "stream": true, + "temperature": 1, + "seed": 123, + "messages": [ + { + "role": "user", + "content": "phrase one" + } + ] +} + +### +# @name Proxy / OpenAI -- Native Text Completions +POST {{proxy-host}}/proxy/openai/v1/turbo-instruct/chat/completions +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo-instruct", + "max_tokens": 20, + "temperature": 0, + "prompt": "Genshin Impact is a game about", + "stream": false +} + +### +# @name Proxy / OpenAI -- Chat-to-Text API Translation +# Accepts a chat completion request and reformats it to work with the text completion API. `model` is ignored. +POST {{proxy-host}}/proxy/openai/turbo-instruct/chat/completions +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "gpt-4", + "max_tokens": 20, + "stream": true, + "messages": [ + { + "role": "user", + "content": "What is the name of the fourth president of the united states?" + }, + { + "role": "assistant", + "content": "That would be George Washington."
+ }, + { + "role": "user", + "content": "I don't think that's right..." + } + ] +} + +### +# @name Proxy / OpenAI -- Create Embedding +POST {{proxy-host}}/proxy/openai/embeddings +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "text-embedding-ada-002", + "input": "This is a test embedding input." +} + + +### +# @name Proxy / Anthropic -- Native Completion (old API) +POST {{proxy-host}}/proxy/anthropic/v1/complete +Authorization: Bearer {{proxy-key}} +anthropic-version: 2023-01-01 +Content-Type: application/json + +{ + "model": "claude-v1.3", + "max_tokens_to_sample": 20, + "temperature": 0.2, + "stream": true, + "prompt": "What is genshin impact\n\n:Assistant:" +} + +### +# @name Proxy / Anthropic -- Native Completion (2023-06-01 API) +POST {{proxy-host}}/proxy/anthropic/v1/complete +Authorization: Bearer {{proxy-key}} +anthropic-version: 2023-06-01 +Content-Type: application/json + +{ + "model": "claude-v1.3", + "max_tokens_to_sample": 20, + "temperature": 0.2, + "stream": true, + "prompt": "What is genshin impact\n\n:Assistant:" +} + +### +# @name Proxy / Anthropic -- OpenAI-to-Anthropic API Translation +POST {{proxy-host}}/proxy/anthropic/v1/chat/completions +Authorization: Bearer {{proxy-key}} +#anthropic-version: 2023-06-01 +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo", + "max_tokens": 20, + "stream": false, + "temperature": 0, + "messages": [ + { + "role": "user", + "content": "What is genshin impact" + } + ] +} + +### +# @name Proxy / AWS Claude -- Native Completion +POST {{proxy-host}}/proxy/aws/claude/v1/complete +Authorization: Bearer {{proxy-key}} +anthropic-version: 2023-01-01 +Content-Type: application/json + +{ + "model": "claude-v2", + "max_tokens_to_sample": 10, + "temperature": 0, + "stream": true, + "prompt": "What is genshin impact\n\n:Assistant:" +} + +### +# @name Proxy / AWS Claude -- OpenAI-to-Anthropic API Translation +POST {{proxy-host}}/proxy/aws/claude/chat/completions +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "gpt-3.5-turbo", + "max_tokens": 50, + "stream": true, + "messages": [ + { + "role": "user", + "content": "What is genshin impact?" + } + ] +} + +### +# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation +POST {{proxy-host}}/proxy/google-palm/v1/chat/completions +Authorization: Bearer {{proxy-key}} +Content-Type: application/json + +{ + "model": "gpt-4", + "max_tokens": 42, + "messages": [ + { + "role": "user", + "content": "Hi what is the name of the fourth president of the united states?" 
+ } + ] +} diff --git a/scripts/test_concurrency.ps1 b/scripts/test_concurrency.ps1 new file mode 100644 index 0000000..b802dbe --- /dev/null +++ b/scripts/test_concurrency.ps1 @@ -0,0 +1,40 @@ +$NumThreads = 10 + +$runspacePool = [runspacefactory]::CreateRunspacePool(1, $NumThreads) +$runspacePool.Open() +$runspaces = @() + +$headers = @{ + "Authorization" = "Bearer test" + "anthropic-version" = "2023-01-01" + "Content-Type" = "application/json" +} + +$payload = @{ + model = "claude-v2" + max_tokens_to_sample = 40 + temperature = 0 + stream = $true + prompt = "Test prompt, please reply with lorem ipsum`n`n:Assistant:" +} | ConvertTo-Json + +for ($i = 1; $i -le $NumThreads; $i++) { + Write-Host "Starting thread $i" + $runspace = [powershell]::Create() + $runspace.AddScript({ + param($i, $headers, $payload) + $response = Invoke-WebRequest -Uri "http://localhost:7860/proxy/aws/claude/v1/complete" -Method Post -Headers $headers -Body $payload + Write-Host "Response from server: $($response.StatusCode)" + }).AddArgument($i).AddArgument($headers).AddArgument($payload) + + $runspace.RunspacePool = $runspacePool + $runspaces += [PSCustomObject]@{ Pipe = $runspace; Status = $runspace.BeginInvoke() } +} + +$runspaces | ForEach-Object { + $_.Pipe.EndInvoke($_.Status) + $_.Pipe.Dispose() +} + +$runspacePool.Close() +$runspacePool.Dispose() diff --git a/src/config.ts b/src/config.ts index 73b8253..9fd6da9 100644 --- a/src/config.ts +++ b/src/config.ts @@ -33,6 +33,17 @@ type Config = { * @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2` */ awsCredentials?: string; + /** + * Comma-delimited list of Azure OpenAI credentials. Each credential item + * should be a colon-delimited list of Azure resource name, deployment ID, and + * API key. + * + * The resource name is the subdomain in your Azure OpenAI deployment's URL, + * e.g. `https://resource-name.openai.azure.com` + * + * @example `AZURE_CREDENTIALS=resource_name_1:deployment_id_1:api_key_1,resource_name_2:deployment_id_2:api_key_2` + */ + azureCredentials?: string; /** * The proxy key to require for requests. Only applicable if the user * management mode is set to 'proxy_key', and required if so.
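Since the config stores these credentials as one opaque string, a hedged sketch may help show how a parsed entry is eventually used. The following TypeScript is illustrative only (the placeholder names come from the `@example` above, and the `api-version` string is the one used throughout this patch); it mirrors the request construction in `add-azure-key.ts` below and highlights that Azure authenticates with an `api-key` header rather than OpenAI's `Authorization: Bearer` scheme:

```ts
// Illustrative sketch, not part of this patch: the upstream request the proxy
// builds from one parsed AZURE_CREDENTIALS entry (cf. add-azure-key.ts).
async function callAzureUpstream(body: object) {
  const resourceName = "resource_name_1"; // placeholder
  const deploymentId = "deployment_id_1"; // placeholder
  const apiKey = "api_key_1"; // placeholder

  const url =
    `https://${resourceName}.openai.azure.com` +
    `/openai/deployments/${deploymentId}/chat/completions` +
    `?api-version=2023-09-01-preview`;

  // Azure uses an `api-key` header instead of `Authorization: Bearer ...`.
  const res = await fetch(url, {
    method: "POST",
    headers: { "content-type": "application/json", "api-key": apiKey },
    body: JSON.stringify(body),
  });
  return res.json();
}
```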
@@ -188,6 +199,7 @@ export const config: Config = { anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""), googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""), awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""), + azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""), adminKey: getEnvWithDefault("ADMIN_KEY", ""), gatekeeper: getEnvWithDefault("GATEKEEPER", "none"), @@ -219,6 +231,10 @@ export const config: Config = { "claude", "bison", "aws-claude", + "azure-turbo", + "azure-gpt4", + "azure-gpt4-turbo", + "azure-gpt4-32k", ]), rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")), rejectMessage: getEnvWithDefault( @@ -352,6 +368,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [ "anthropicKey", "googlePalmKey", "awsCredentials", + "azureCredentials", "proxyKey", "adminKey", "rejectPhrases", @@ -369,6 +386,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [ "useInsecureCookies", "staticServiceInfo", "checkKeys", + "allowedModelFamilies", ]; const getKeys = Object.keys as <T>(obj: T) => Array<keyof T>; @@ -417,6 +435,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T { "ANTHROPIC_KEY", "GOOGLE_PALM_KEY", "AWS_CREDENTIALS", + "AZURE_CREDENTIALS", ].includes(String(env)) ) { return value as unknown as T; diff --git a/src/info-page.ts b/src/info-page.ts index 4888cb8..4084dfe 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -1,3 +1,4 @@ +/** This whole module really sucks */ import fs from "fs"; import { Request, Response } from "express"; import showdown from "showdown"; @@ -5,11 +6,16 @@ import { config, listConfig } from "./config"; import { AnthropicKey, AwsBedrockKey, + AzureOpenAIKey, GooglePalmKey, keyPool, OpenAIKey, } from "./shared/key-management"; -import { ModelFamily, OpenAIModelFamily } from "./shared/models"; +import { + AzureOpenAIModelFamily, + ModelFamily, + OpenAIModelFamily, +} from "./shared/models"; import { getUniqueIps } from "./proxy/rate-limit"; import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue"; import { getTokenCostUsd, prettyTokens } from "./shared/stats"; @@ -23,6 +29,8 @@ let infoPageLastUpdated = 0; type KeyPoolKey = ReturnType<typeof keyPool.list>[0]; const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey => k.service === "openai"; +const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey => + k.service === "azure"; const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey => k.service === "anthropic"; const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey => k.service === "google-palm"; @@ -48,6 +56,7 @@ type ServiceAggregates = { anthropicKeys?: number; palmKeys?: number; awsKeys?: number; + azureKeys?: number; proompts: number; tokens: number; tokenCost: number; @@ -62,17 +71,15 @@ const serviceStats = new Map<keyof ServiceAggregates, number>(); export const handleInfoPage = (req: Request, res: Response) => { if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) { - res.send(infoPageHtml); - return; + return res.send(infoPageHtml); } - // Sometimes huggingface doesn't send the host header and makes us guess. const baseUrl = process.env.SPACE_ID && !req.get("host")?.includes("hf.space") ?
getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID) : req.protocol + "://" + req.get("host"); - infoPageHtml = buildInfoPageHtml(baseUrl); + infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy"); infoPageLastUpdated = Date.now(); res.send(infoPageHtml); @@ -95,6 +102,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) { const anthropicKeys = serviceStats.get("anthropicKeys") || 0; const palmKeys = serviceStats.get("palmKeys") || 0; const awsKeys = serviceStats.get("awsKeys") || 0; + const azureKeys = serviceStats.get("azureKeys") || 0; const proompts = serviceStats.get("proompts") || 0; const tokens = serviceStats.get("tokens") || 0; const tokenCost = serviceStats.get("tokenCost") || 0; @@ -102,16 +110,15 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) { const allowDalle = config.allowedModelFamilies.includes("dall-e"); const endpoints = { - ...(openaiKeys ? { openai: baseUrl + "/proxy/openai" } : {}), - ...(openaiKeys - ? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" } - : {}), + ...(openaiKeys ? { openai: baseUrl + "/openai" } : {}), + ...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}), ...(openaiKeys && allowDalle - ? { ["openai-image"]: baseUrl + "/proxy/openai-image" } + ? { ["openai-image"]: baseUrl + "/openai-image" } : {}), - ...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}), - ...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}), - ...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}), + ...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}), + ...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}), + ...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}), + ...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}), }; const stats = { @@ -120,13 +127,17 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) { ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}), }; - const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys }; + const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys }; + for (const key of Object.keys(keyInfo)) { + if (!(keyInfo as any)[key]) delete (keyInfo as any)[key]; + } const providerInfo = { ...(openaiKeys ? getOpenAIInfo() : {}), ...(anthropicKeys ? getAnthropicInfo() : {}), - ...(palmKeys ? { "palm-bison": getPalmInfo() } : {}), - ...(awsKeys ? { "aws-claude": getAwsInfo() } : {}), + ...(palmKeys ? getPalmInfo() : {}), + ...(awsKeys ? getAwsInfo() : {}), + ...(azureKeys ? getAzureInfo() : {}), }; if (hideFullInfo) { @@ -188,6 +199,7 @@ function addKeyToAggregates(k: KeyPoolKey) { increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0); increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0); increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0); + increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0); let sumTokens = 0; let sumCost = 0; @@ -201,17 +213,26 @@ function addKeyToAggregates(k: KeyPoolKey) { Boolean(k.lastChecked) ? 0 : 1 ); - // Technically this would not account for keys that have tokens recorded - // on models they aren't provisioned for, but that would be strange k.modelFamilies.forEach((f) => { const tokens = k[`${f}Tokens`]; sumTokens += tokens; sumCost += getTokenCostUsd(f, tokens); increment(modelStats, `${f}__tokens`, tokens); - increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0); increment(modelStats, `${f}__revoked`, k.isRevoked ? 
1 : 0); - increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0); increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); + increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0); + increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0); + }); + break; + case "azure": + if (!keyIsAzureKey(k)) throw new Error("Invalid key type"); + k.modelFamilies.forEach((f) => { + const tokens = k[`${f}Tokens`]; + sumTokens += tokens; + sumCost += getTokenCostUsd(f, tokens); + increment(modelStats, `${f}__tokens`, tokens); + increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); + increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); }); break; case "anthropic": { @@ -381,11 +402,13 @@ function getPalmInfo() { const cost = getTokenCostUsd("bison", tokens); return { - usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`, - activeKeys: bisonInfo.active, - revokedKeys: bisonInfo.revoked, - proomptersInQueue: bisonInfo.queued, - estimatedQueueTime: bisonInfo.queueTime, + bison: { + usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`, + activeKeys: bisonInfo.active, + revokedKeys: bisonInfo.revoked, + proomptersInQueue: bisonInfo.queued, + estimatedQueueTime: bisonInfo.queueTime, + }, }; } @@ -408,15 +431,59 @@ function getAwsInfo() { : `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`; return { - usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`, - activeKeys: awsInfo.active, - revokedKeys: awsInfo.revoked, - proomptersInQueue: awsInfo.queued, - estimatedQueueTime: awsInfo.queueTime, - ...(logged > 0 ? { privacy: logMsg } : {}), + "aws-claude": { + usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`, + activeKeys: awsInfo.active, + revokedKeys: awsInfo.revoked, + proomptersInQueue: awsInfo.queued, + estimatedQueueTime: awsInfo.queueTime, + ...(logged > 0 ? { privacy: logMsg } : {}), + }, }; } +function getAzureInfo() { + const azureFamilies = [ + "azure-turbo", + "azure-gpt4", + "azure-gpt4-turbo", + "azure-gpt4-32k", + ] as const; + + const azureInfo: { + [modelFamily in AzureOpenAIModelFamily]?: { + usage?: string; + activeKeys: number; + revokedKeys?: number; + proomptersInQueue?: number; + estimatedQueueTime?: string; + }; + } = {}; + for (const family of azureFamilies) { + const familyAllowed = config.allowedModelFamilies.includes(family); + const activeKeys = modelStats.get(`${family}__active`) || 0; + + if (!familyAllowed || activeKeys === 0) continue; + + azureInfo[family] = { + activeKeys, + revokedKeys: modelStats.get(`${family}__revoked`) || 0, + }; + + const queue = getQueueInformation(family); + azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue; + azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime; + + const tokens = modelStats.get(`${family}__tokens`) || 0; + const cost = getTokenCostUsd(family, tokens); + azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString( + cost + )}`; + } + + return azureInfo; +} + const customGreeting = fs.existsSync("greeting.md") ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}` : ""; @@ -430,10 +497,10 @@ function buildInfoPageHeader(converter: showdown.Converter, title: string) { let infoBody = ` # ${title}`; if (config.promptLogging) { - infoBody += `\n## Prompt logging is enabled! -The server operator has enabled prompt logging. The prompts you send to this proxy and the AI responses you receive may be saved. 
+ infoBody += `\n## Prompt Logging Enabled +This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps. -Logs are anonymous and do not contain IP addresses or timestamps. [You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/prompt-logging/index.ts). +[You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/shared/prompt-logging/index.ts). **If you are uncomfortable with this, don't send prompts to this proxy!**`; } @@ -570,8 +637,6 @@ function escapeHtml(unsafe: string) { } function getExternalUrlForHuggingfaceSpaceId(spaceId: string) { - // Huggingface broke their amazon elb config and no longer sends the - // x-forwarded-host header. This is a workaround. try { const [username, spacename] = spaceId.split("/"); return `https://${username}-${spacename.replace(/_/g, "-")}.hf.space`; diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts index 2f43762..9c8d6dd 100644 --- a/src/proxy/aws.ts +++ b/src/proxy/aws.ts @@ -11,7 +11,7 @@ import { createPreprocessorMiddleware, stripHeaders, signAwsRequest, - finalizeAwsRequest, + finalizeSignedRequest, createOnProxyReqHandler, blockZoomerOrigins, } from "./middleware/request"; @@ -30,7 +30,11 @@ const getModelsResponse = () => { if (!config.awsCredentials) return { object: "list", data: [] }; - const variants = ["anthropic.claude-v1", "anthropic.claude-v2"]; + const variants = [ + "anthropic.claude-v1", + "anthropic.claude-v2", + "anthropic.claude-v2:1", + ]; const models = variants.map((id) => ({ id, @@ -134,7 +138,7 @@ const awsProxy = createQueueMiddleware({ applyQuotaLimits, blockZoomerOrigins, stripHeaders, - finalizeAwsRequest, + finalizeSignedRequest, ], }), proxyRes: createOnProxyResHandler([awsResponseHandler]), @@ -183,7 +187,7 @@ function maybeReassignModel(req: Request) { req.body.model = "anthropic.claude-v1"; } else { // User's client requested v2 or possibly some OpenAI model, default to v2 - req.body.model = "anthropic.claude-v2"; + req.body.model = "anthropic.claude-v2:1"; } // TODO: Handle claude-instant } diff --git a/src/proxy/azure.ts b/src/proxy/azure.ts new file mode 100644 index 0000000..45b9e95 --- /dev/null +++ b/src/proxy/azure.ts @@ -0,0 +1,140 @@ +import { RequestHandler, Router } from "express"; +import { createProxyMiddleware } from "http-proxy-middleware"; +import { config } from "../config"; +import { keyPool } from "../shared/key-management"; +import { + ModelFamily, + AzureOpenAIModelFamily, + getAzureOpenAIModelFamily, +} from "../shared/models"; +import { logger } from "../logger"; +import { KNOWN_OPENAI_MODELS } from "./openai"; +import { createQueueMiddleware } from "./queue"; +import { ipLimiter } from "./rate-limit"; +import { handleProxyError } from "./middleware/common"; +import { + applyQuotaLimits, + blockZoomerOrigins, + createOnProxyReqHandler, + createPreprocessorMiddleware, + finalizeSignedRequest, + limitCompletions, + stripHeaders, +} from "./middleware/request"; +import { + createOnProxyResHandler, + ProxyResHandlerWithBody, +} from "./middleware/response"; +import { addAzureKey } from "./middleware/request/add-azure-key"; + +let modelsCache: any = null; +let modelsCacheTime = 0; + +function getModelsResponse() { + if (new Date().getTime() - modelsCacheTime < 1000 * 60) { + return modelsCache; + } + + let available = new Set<AzureOpenAIModelFamily>(); + for (const key of keyPool.list()) { + if
(key.isDisabled || key.service !== "azure") continue; + key.modelFamilies.forEach((family) => + available.add(family as AzureOpenAIModelFamily) + ); + } + const allowed = new Set(config.allowedModelFamilies); + available = new Set([...available].filter((x) => allowed.has(x))); + + const models = KNOWN_OPENAI_MODELS.map((id) => ({ + id, + object: "model", + created: new Date().getTime(), + owned_by: "azure", + permission: [ + { + id: "modelperm-" + id, + object: "model_permission", + created: new Date().getTime(), + organization: "*", + group: null, + is_blocking: false, + }, + ], + root: id, + parent: null, + })).filter((model) => available.has(getAzureOpenAIModelFamily(model.id))); + + modelsCache = { object: "list", data: models }; + modelsCacheTime = new Date().getTime(); + + return modelsCache; +} + +const handleModelRequest: RequestHandler = (_req, res) => { + res.status(200).json(getModelsResponse()); +}; + +const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + if (config.promptLogging) { + const host = req.get("host"); + body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`; + } + + if (req.tokenizerInfo) { + body.proxy_tokenizer = req.tokenizerInfo; + } + + res.status(200).json(body); +}; + +const azureOpenAIProxy = createQueueMiddleware({ + beforeProxy: addAzureKey, + proxyMiddleware: createProxyMiddleware({ + target: "will be set by router", + router: (req) => { + if (!req.signedRequest) throw new Error("signedRequest not set"); + const { hostname, path } = req.signedRequest; + return `https://${hostname}${path}`; + }, + changeOrigin: true, + selfHandleResponse: true, + logger, + on: { + proxyReq: createOnProxyReqHandler({ + pipeline: [ + applyQuotaLimits, + limitCompletions, + blockZoomerOrigins, + stripHeaders, + finalizeSignedRequest, + ], + }), + proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]), + error: handleProxyError, + }, + }), +}); + +const azureOpenAIRouter = Router(); +azureOpenAIRouter.get("/v1/models", handleModelRequest); +azureOpenAIRouter.post( + "/v1/chat/completions", + ipLimiter, + createPreprocessorMiddleware({ + inApi: "openai", + outApi: "openai", + service: "azure", + }), + azureOpenAIProxy +); + +export const azure = azureOpenAIRouter; diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index 24cfef9..52091c0 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -59,7 +59,7 @@ export function writeErrorResponse( res.write(`data: [DONE]\n\n`); res.end(); } else { - if (req.tokenizerInfo && errorPayload.error) { + if (req.tokenizerInfo && typeof errorPayload.error === "object") { errorPayload.error.proxy_tokenizer = req.tokenizerInfo; } res.status(statusCode).json(errorPayload); diff --git a/src/proxy/middleware/request/add-azure-key.ts b/src/proxy/middleware/request/add-azure-key.ts new file mode 100644 index 0000000..5c34d34 --- /dev/null +++ b/src/proxy/middleware/request/add-azure-key.ts @@ -0,0 +1,50 @@ +import { AzureOpenAIKey, keyPool } from "../../../shared/key-management"; +import { RequestPreprocessor } from "."; + +export const addAzureKey: RequestPreprocessor = (req) => { + const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai"; + const serviceValid = req.service === "azure"; + if (!apisValid || !serviceValid) { + throw new Error("addAzureKey called on invalid 
request"); + } + + if (!req.body?.model) { + throw new Error("You must specify a model with your request."); + } + + const model = req.body.model.startsWith("azure-") + ? req.body.model + : `azure-${req.body.model}`; + + req.key = keyPool.get(model); + req.body.model = model; + + req.log.info( + { key: req.key.hash, model }, + "Assigned Azure OpenAI key to request" + ); + + const cred = req.key as AzureOpenAIKey; + const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred); + + req.signedRequest = { + method: "POST", + protocol: "https:", + hostname: `${resourceName}.openai.azure.com`, + path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`, + headers: { + ["host"]: `${resourceName}.openai.azure.com`, + ["content-type"]: "application/json", + ["api-key"]: apiKey, + }, + body: JSON.stringify(req.body), + }; +}; + +function getCredentialsFromKey(key: AzureOpenAIKey) { + const [resourceName, deploymentId, apiKey] = key.key.split(":"); + if (!resourceName || !deploymentId || !apiKey) { + throw new Error("Assigned Azure OpenAI key is not in the correct format."); + } + return { resourceName, deploymentId, apiKey }; +} diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/add-key.ts index bbf38b7..49a7e88 100644 --- a/src/proxy/middleware/request/add-key.ts +++ b/src/proxy/middleware/request/add-key.ts @@ -80,6 +80,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { `?key=${assignedKey.key}` ); break; + case "azure": + const azureKey = assignedKey.key; + proxyReq.setHeader("api-key", azureKey); + break; case "aws": throw new Error( "add-key should not be used for AWS security credentials. Use sign-aws-request instead." diff --git a/src/proxy/middleware/request/finalize-aws-request.ts b/src/proxy/middleware/request/finalize-signed-request.ts similarity index 74% rename from src/proxy/middleware/request/finalize-aws-request.ts rename to src/proxy/middleware/request/finalize-signed-request.ts index 000a533..d8c6622 100644 --- a/src/proxy/middleware/request/finalize-aws-request.ts +++ b/src/proxy/middleware/request/finalize-signed-request.ts @@ -1,11 +1,11 @@ import type { ProxyRequestMiddleware } from "."; /** - * For AWS requests, the body is signed earlier in the request pipeline, before - * the proxy middleware. This function just assigns the path and headers to the - * proxy request. + * For AWS/Azure requests, the body is signed earlier in the request pipeline, + * before the proxy middleware. This function just assigns the path and headers + * to the proxy request. 
*/ -export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => { +export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => { if (!req.signedRequest) { throw new Error("Expected req.signedRequest to be set"); } diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts index 565b9ac..8b31ecf 100644 --- a/src/proxy/middleware/request/index.ts +++ b/src/proxy/middleware/request/index.ts @@ -22,7 +22,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key"; export { addAnthropicPreamble } from "./add-anthropic-preamble"; export { blockZoomerOrigins } from "./block-zoomer-origins"; export { finalizeBody } from "./finalize-body"; -export { finalizeAwsRequest } from "./finalize-aws-request"; +export { finalizeSignedRequest } from "./finalize-signed-request"; export { limitCompletions } from "./limit-completions"; export { stripHeaders } from "./strip-headers"; diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index c68d347..1dbfb0e 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -289,15 +289,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( switch (service) { case "openai": case "google-palm": - if (errorPayload.error?.code === "content_policy_violation") { - errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`; + case "azure": + const filteredCodes = ["content_policy_violation", "content_filter"]; + if (filteredCodes.includes(errorPayload.error?.code)) { + errorPayload.proxy_note = `Request was filtered by the upstream API's content moderation system. Modify your prompt and try again.`; refundLastAttempt(req); } else if (errorPayload.error?.code === "billing_hard_limit_reached") { // For some reason, some models return this 400 error instead of the // same 429 billing error that other models return. handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload); } else { - errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`; + errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`; } break; case "anthropic": @@ -342,7 +344,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( handleAwsRateLimitError(req, errorPayload); break; case "google-palm": - throw new Error("Rate limit handling not implemented for PaLM"); + case "azure": + errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. 
Try again in a few seconds.`; + break; default: assertNever(service); } @@ -369,6 +373,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( case "aws": errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`; break; + case "azure": + errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`; + break; default: assertNever(service); } diff --git a/src/proxy/middleware/response/streaming/sse-message-transformer.ts b/src/proxy/middleware/response/streaming/sse-message-transformer.ts index 6da55b9..9deb9a3 100644 --- a/src/proxy/middleware/response/streaming/sse-message-transformer.ts +++ b/src/proxy/middleware/response/streaming/sse-message-transformer.ts @@ -28,6 +28,7 @@ type SSEMessageTransformerOptions = TransformOptions & { export class SSEMessageTransformer extends Transform { private lastPosition: number; private msgCount: number; + private readonly inputFormat: APIFormat; private readonly transformFn: StreamingCompletionTransformer; private readonly log; private readonly fallbackId: string; @@ -42,6 +43,7 @@ export class SSEMessageTransformer extends Transform { options.inputFormat, options.inputApiVersion ); + this.inputFormat = options.inputFormat; this.fallbackId = options.requestId; this.fallbackModel = options.requestedModel; this.log.debug( @@ -67,6 +69,17 @@ export class SSEMessageTransformer extends Transform { }); this.lastPosition = newPosition; + // Special case for Azure OpenAI, which is 99% the same as OpenAI but + // sometimes emits an extra event at the beginning of the stream with the + // content moderation system's response to the prompt. A lot of frontends + // don't expect this and neither does our event aggregator so we drop it. + if (this.inputFormat === "openai" && this.msgCount <= 1) { + if (originalMessage.includes("prompt_filter_results")) { + this.log.debug("Dropping Azure OpenAI content moderation SSE event"); + return callback(); + } + } + this.emit("originalMessage", originalMessage); // Some events may not be transformed, e.g. ping events diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 74bcaa6..6617874 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -24,7 +24,7 @@ import { import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response"; // https://platform.openai.com/docs/models/overview -const KNOWN_MODELS = [ +export const KNOWN_OPENAI_MODELS = [ "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", @@ -46,7 +46,7 @@ let modelsCache: any = null; let modelsCacheTime = 0; -export function generateModelList(models = KNOWN_MODELS) { +export function generateModelList(models = KNOWN_OPENAI_MODELS) { let available = new Set<OpenAIModelFamily>(); for (const key of keyPool.list()) { if (key.isDisabled || key.service !== "openai") continue; diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index 8055c06..a971e30 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -26,6 +26,7 @@ import { assertNever } from "../shared/utils"; import { logger } from "../logger"; import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit"; import { RequestPreprocessor } from "./middleware/request"; +import { handleProxyError } from "./middleware/common"; const queue: Request[] = []; const log = logger.child({ module: "request-queue" }); @@ -34,7 +35,7 @@ const AGNAI_CONCURRENCY_LIMIT = 5; /** Maximum number of queue slots for individual users.
*/ const USER_CONCURRENCY_LIMIT = 1; -const MIN_HEARTBEAT_SIZE = 512; +const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512"); const MAX_HEARTBEAT_SIZE = 1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024"); const HEARTBEAT_INTERVAL = @@ -358,12 +359,16 @@ export function createQueueMiddleware({ return (req, res, next) => { req.proceed = async () => { if (beforeProxy) { - // Hack to let us run asynchronous middleware before the - // http-proxy-middleware handler. This is used to sign AWS requests - // before they are proxied, as the signing is asynchronous. - // Unlike RequestPreprocessors, this runs every time the request is - // dequeued, not just the first time. - await beforeProxy(req); + try { + // Hack to let us run asynchronous middleware before the + // http-proxy-middleware handler. This is used to sign AWS requests + // before they are proxied, as the signing is asynchronous. + // Unlike RequestPreprocessors, this runs every time the request is + // dequeued, not just the first time. + await beforeProxy(req); + } catch (err) { + return handleProxyError(err, req, res); + } } proxyMiddleware(req, res, next); }; diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 982b761..63483ce 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -6,6 +6,7 @@ import { openaiImage } from "./openai-image"; import { anthropic } from "./anthropic"; import { googlePalm } from "./palm"; import { aws } from "./aws"; +import { azure } from "./azure"; const proxyRouter = express.Router(); proxyRouter.use((req, _res, next) => { @@ -32,6 +33,7 @@ proxyRouter.use("/openai-image", addV1, openaiImage); proxyRouter.use("/anthropic", addV1, anthropic); proxyRouter.use("/google-palm", addV1, googlePalm); proxyRouter.use("/aws/claude", addV1, aws); +proxyRouter.use("/azure/openai", addV1, azure); // Redirect browser requests to the homepage. proxyRouter.get("*", (req, res, next) => { const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); diff --git a/src/shared/key-management/anthropic/checker.ts b/src/shared/key-management/anthropic/checker.ts index dbe8ce9..b930489 100644 --- a/src/shared/key-management/anthropic/checker.ts +++ b/src/shared/key-management/anthropic/checker.ts @@ -26,46 +26,23 @@ type AnthropicAPIError = { type UpdateFn = typeof AnthropicKeyProvider.prototype.update; export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> { - private readonly updateKey: UpdateFn; - constructor(keys: AnthropicKey[], updateKey: UpdateFn) { super(keys, { service: "anthropic", keyCheckPeriod: KEY_CHECK_PERIOD, minCheckInterval: MIN_CHECK_INTERVAL, + updateKey, }); - this.updateKey = updateKey; } - protected async checkKey(key: AnthropicKey) { - if (key.isDisabled) { - this.log.warn({ key: key.hash }, "Skipping check for disabled key."); - this.scheduleNextCheck(); - return; - } - - this.log.debug({ key: key.hash }, "Checking key..."); - let isInitialCheck = !key.lastChecked; - try { - const [{ pozzed }] = await Promise.all([this.testLiveness(key)]); - const updates = { isPozzed: pozzed }; - this.updateKey(key.hash, updates); - this.log.info( - { key: key.hash, models: key.modelFamilies }, - "Key check complete." - ); - } catch (error) { - // touch the key so we don't check it again for a while - this.updateKey(key.hash, {}); - this.handleAxiosError(key, error as AxiosError); - } - - this.lastCheck = Date.now(); - // Only enqueue the next check if this wasn't a startup check, since those - // are batched together elsewhere.
- if (!isInitialCheck) { - this.scheduleNextCheck(); - } + protected async testKeyOrFail(key: AnthropicKey) { + const [{ pozzed }] = await Promise.all([this.testLiveness(key)]); + const updates = { isPozzed: pozzed }; + this.updateKey(key.hash, updates); + this.log.info( + { key: key.hash, models: key.modelFamilies }, + "Checked key." + ); } protected handleAxiosError(key: AnthropicKey, error: AxiosError) { @@ -84,6 +61,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> { { key: key.hash, error: error.message }, "Key is rate limited. Rechecking in 10 seconds." ); const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000); this.updateKey(key.hash, { lastChecked: next }); break; diff --git a/src/shared/key-management/aws/checker.ts b/src/shared/key-management/aws/checker.ts index 4640735..e2b5751 100644 --- a/src/shared/key-management/aws/checker.ts +++ b/src/shared/key-management/aws/checker.ts @@ -32,58 +32,36 @@ type GetLoggingConfigResponse = { type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update; export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> { - private readonly updateKey: UpdateFn; - constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) { super(keys, { service: "aws", keyCheckPeriod: KEY_CHECK_PERIOD, minCheckInterval: MIN_CHECK_INTERVAL, + updateKey, }); - this.updateKey = updateKey; } - protected async checkKey(key: AwsBedrockKey) { - if (key.isDisabled) { - this.log.warn({ key: key.hash }, "Skipping check for disabled key."); - this.scheduleNextCheck(); - return; + protected async testKeyOrFail(key: AwsBedrockKey) { + // Only check models on startup. For now all models must be available to + // the proxy because we don't route requests to different keys. + const modelChecks: Promise<unknown>[] = []; + const isInitialCheck = !key.lastChecked; + if (isInitialCheck) { + modelChecks.push(this.invokeModel("anthropic.claude-v1", key)); + modelChecks.push(this.invokeModel("anthropic.claude-v2", key)); } - this.log.debug({ key: key.hash }, "Checking key..."); - let isInitialCheck = !key.lastChecked; - try { - // Only check models on startup. For now all models must be available to - // the proxy because we don't route requests to different keys. - const modelChecks: Promise<unknown>[] = []; - if (isInitialCheck) { - modelChecks.push(this.invokeModel("anthropic.claude-v1", key)); - modelChecks.push(this.invokeModel("anthropic.claude-v2", key)); - } + await Promise.all(modelChecks); + await this.checkLoggingConfiguration(key); - await Promise.all(modelChecks); - await this.checkLoggingConfiguration(key); - - this.log.info( - { - key: key.hash, - models: key.modelFamilies, - logged: key.awsLoggingStatus, - }, - "Key check complete." - ); - } catch (error) { - this.handleAxiosError(key, error as AxiosError); - } - - this.updateKey(key.hash, {}); - - this.lastCheck = Date.now(); - // Only enqueue the next check if this wasn't a startup check, since those - // are batched together elsewhere. - if (!isInitialCheck) { - this.scheduleNextCheck(); - } + this.log.info( + { + key: key.hash, + models: key.modelFamilies, + logged: key.awsLoggingStatus, + }, + "Checked key."
+ ); } protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) { diff --git a/src/shared/key-management/azure/checker.ts b/src/shared/key-management/azure/checker.ts new file mode 100644 index 0000000..c7705a3 --- /dev/null +++ b/src/shared/key-management/azure/checker.ts @@ -0,0 +1,149 @@ +import axios, { AxiosError } from "axios"; +import { KeyCheckerBase } from "../key-checker-base"; +import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider"; +import { getAzureOpenAIModelFamily } from "../../models"; + +const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds +const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes +const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com"; +const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) => + `https://${AZURE_HOST.replace( + "%RESOURCE_NAME%", + resourceName + )}/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`; + +type AzureError = { + error: { + message: string; + type: string | null; + param: string; + code: string; + status: number; + }; +}; +type UpdateFn = typeof AzureOpenAIKeyProvider.prototype.update; + +export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> { + constructor(keys: AzureOpenAIKey[], updateKey: UpdateFn) { + super(keys, { + service: "azure", + keyCheckPeriod: KEY_CHECK_PERIOD, + minCheckInterval: MIN_CHECK_INTERVAL, + recurringChecksEnabled: false, + updateKey, + }); + } + + protected async testKeyOrFail(key: AzureOpenAIKey) { + const model = await this.testModel(key); + this.log.info( + { key: key.hash, deploymentModel: model }, + "Checked key." + ); + this.updateKey(key.hash, { modelFamilies: [model] }); + } + + // provided api-key header isn't valid (401) + // { + // "error": { + // "code": "401", + // "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource." + // } + // } + + // api key correct but deployment id is wrong (404) + // { + // "error": { + // "code": "DeploymentNotFound", + // "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again." + // } + // } + + // resource name is wrong (node will throw ENOTFOUND) + + // rate limited (429) + // TODO: try to reproduce this + + protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) { + if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) { + const data = error.response.data; + const status = data.error.status; + const errorType = data.error.code || data.error.type; + switch (errorType) { + case "DeploymentNotFound": + this.log.warn( + { key: key.hash, errorType, error: error.response.data }, + "Key is revoked or deployment ID is incorrect. Disabling key." + ); + return this.updateKey(key.hash, { + isDisabled: true, + isRevoked: true, + }); + case "401": + this.log.warn( + { key: key.hash, errorType, error: error.response.data }, + "Key is disabled or incorrect. Disabling key." + ); + return this.updateKey(key.hash, { + isDisabled: true, + isRevoked: true, + }); + default: + this.log.error( + { key: key.hash, errorType, error: error.response.data, status }, + "Unknown Azure API error while checking key. Please report this."
+ ); + return this.updateKey(key.hash, { lastChecked: Date.now() }); + } + } + + const { response, code } = error; + if (code === "ENOTFOUND") { + this.log.warn( + { key: key.hash, error: error.message }, + "Resource name is probably incorrect. Disabling key." + ); + return this.updateKey(key.hash, { isDisabled: true, isRevoked: true }); + } + + const { headers, status, data } = response ?? {}; + this.log.error( + { key: key.hash, status, headers, data, error: error.message }, + "Network error while checking key; trying this key again in a minute." + ); + const oneMinute = 60 * 1000; + const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute); + this.updateKey(key.hash, { lastChecked: next }); + } + + private async testModel(key: AzureOpenAIKey) { + const { apiKey, deploymentId, resourceName } = + AzureOpenAIKeyChecker.getCredentialsFromKey(key); + const url = POST_CHAT_COMPLETIONS(resourceName, deploymentId); + const testRequest = { + max_tokens: 1, + stream: false, + messages: [{ role: "user", content: "" }], + }; + const { data } = await axios.post(url, testRequest, { + headers: { "Content-Type": "application/json", "api-key": apiKey }, + }); + + return getAzureOpenAIModelFamily(data.model); + } + + static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> { + const data = error.response?.data as any; + return data?.error?.code || data?.error?.type; + } + + static getCredentialsFromKey(key: AzureOpenAIKey) { + const [resourceName, deploymentId, apiKey] = key.key.split(":"); + if (!resourceName || !deploymentId || !apiKey) { + throw new Error( + "Invalid Azure credential format. Refer to .env.example and ensure your credentials are in the format RESOURCE_NAME:DEPLOYMENT_ID:API_KEY with commas between each credential set." + ); + } + return { resourceName, deploymentId, apiKey }; + } +} diff --git a/src/shared/key-management/azure/provider.ts b/src/shared/key-management/azure/provider.ts new file mode 100644 index 0000000..5256b7e --- /dev/null +++ b/src/shared/key-management/azure/provider.ts @@ -0,0 +1,210 @@ +import crypto from "crypto"; +import { Key, KeyProvider } from ".."; +import { config } from "../../../config"; +import { logger } from "../../../logger"; +import type { AzureOpenAIModelFamily } from "../../models"; +import { getAzureOpenAIModelFamily } from "../../models"; +import { OpenAIModel } from "../openai/provider"; +import { AzureOpenAIKeyChecker } from "./checker"; + +export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">; + +type AzureOpenAIKeyUsage = { + [K in AzureOpenAIModelFamily as `${K}Tokens`]: number; +}; + +export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage { + readonly service: "azure"; + readonly modelFamilies: AzureOpenAIModelFamily[]; + /** The time at which this key was last rate limited. */ + rateLimitedAt: number; + /** The time until which this key is rate limited. */ + rateLimitedUntil: number; + contentFiltering: boolean; +} + +/** + * Upon being rate limited, a key will be locked out for this many milliseconds + * while we wait for other concurrent requests to finish. + */ +const RATE_LIMIT_LOCKOUT = 4000; +/** + * Upon assigning a key, we will wait this many milliseconds before allowing it + * to be used again. This is to prevent the queue from flooding a key with too + * many requests while we wait to learn whether previous ones succeeded.
+ */ +const KEY_REUSE_DELAY = 250; + +export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> { + readonly service = "azure"; + + private keys: AzureOpenAIKey[] = []; + private checker?: AzureOpenAIKeyChecker; + private log = logger.child({ module: "key-provider", service: this.service }); + + constructor() { + const keyConfig = config.azureCredentials; + if (!keyConfig) { + this.log.warn( + "AZURE_CREDENTIALS is not set. Azure OpenAI API will not be available." + ); + return; + } + const bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; + for (const key of bareKeys) { + const newKey: AzureOpenAIKey = { + key, + service: this.service, + modelFamilies: ["azure-gpt4"], + isDisabled: false, + isRevoked: false, + promptCount: 0, + lastUsed: 0, + rateLimitedAt: 0, + rateLimitedUntil: 0, + contentFiltering: false, + hash: `azu-${crypto + .createHash("sha256") + .update(key) + .digest("hex") + .slice(0, 8)}`, + lastChecked: 0, + "azure-turboTokens": 0, + "azure-gpt4Tokens": 0, + "azure-gpt4-32kTokens": 0, + "azure-gpt4-turboTokens": 0, + }; + this.keys.push(newKey); + } + this.log.info({ keyCount: this.keys.length }, "Loaded Azure OpenAI keys."); + } + + public init() { + if (config.checkKeys) { + this.checker = new AzureOpenAIKeyChecker( + this.keys, + this.update.bind(this) + ); + this.checker.start(); + } + } + + public list() { + return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); + } + + public get(model: AzureOpenAIModel) { + const neededFamily = getAzureOpenAIModelFamily(model); + const availableKeys = this.keys.filter( + (k) => !k.isDisabled && k.modelFamilies.includes(neededFamily) + ); + if (availableKeys.length === 0) { + throw new Error(`No keys available for model family '${neededFamily}'.`); + } + + // (largely copied from the OpenAI provider, without trial key support) + // Select a key, from highest priority to lowest priority: + // 1. Keys which are not rate limited + // a. If all keys were rate limited recently, select the least-recently + // rate limited key. + // 2.
Keys which have not been used in the longest time + + const now = Date.now(); + + const keysByPriority = availableKeys.sort((a, b) => { + const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; + const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; + + if (aRateLimited && !bRateLimited) return 1; + if (!aRateLimited && bRateLimited) return -1; + if (aRateLimited && bRateLimited) { + return a.rateLimitedAt - b.rateLimitedAt; + } + + return a.lastUsed - b.lastUsed; + }); + + const selectedKey = keysByPriority[0]; + selectedKey.lastUsed = now; + this.throttle(selectedKey.hash); + return { ...selectedKey }; + } + + public disable(key: AzureOpenAIKey) { + const keyFromPool = this.keys.find((k) => k.hash === key.hash); + if (!keyFromPool || keyFromPool.isDisabled) return; + keyFromPool.isDisabled = true; + this.log.warn({ key: key.hash }, "Key disabled"); + } + + public update(hash: string, update: Partial<AzureOpenAIKey>) { + const keyFromPool = this.keys.find((k) => k.hash === hash)!; + Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); + } + + public available() { + return this.keys.filter((k) => !k.isDisabled).length; + } + + public incrementUsage(hash: string, model: string, tokens: number) { + const key = this.keys.find((k) => k.hash === hash); + if (!key) return; + key.promptCount++; + key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens; + } + + // TODO: all of this shit is duplicate code + + public getLockoutPeriod() { + const activeKeys = this.keys.filter((k) => !k.isDisabled); + // Don't lock out if there are no keys available or the queue will stall. + // Just let it through so the add-key middleware can throw an error. + if (activeKeys.length === 0) return 0; + + const now = Date.now(); + const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); + const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; + + if (anyNotRateLimited) return 0; + + // If all keys are rate-limited, return time until the first key is ready. + return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); + } + + /** + * This is called when we receive a 429, which means there are already five + * concurrent requests running on this key. We don't have any information on + * when these requests will resolve, so all we can do is wait a bit and try + * again. We will lock the key for RATE_LIMIT_LOCKOUT milliseconds after + * getting a 429 before retrying in order to give the other requests a + * chance to finish. + */ + public markRateLimited(keyHash: string) { + this.log.debug({ key: keyHash }, "Key rate limited"); + const key = this.keys.find((k) => k.hash === keyHash)!; + const now = Date.now(); + key.rateLimitedAt = now; + key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT; + } + + public recheck() { + this.keys.forEach(({ hash }) => + this.update(hash, { lastChecked: 0, isDisabled: false }) + ); + } + + /** + * Applies a short artificial delay to the key upon dequeueing, in order to + * prevent it from being immediately assigned to another request before the + * current one can be dispatched.
+ **/ + private throttle(hash: string) { + const now = Date.now(); + const key = this.keys.find((k) => k.hash === hash)!; + + const currentRateLimit = key.rateLimitedUntil; + const nextRateLimit = now + KEY_REUSE_DELAY; + + key.rateLimitedAt = now; + key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit); + } +} diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index c647a14..8fa95b1 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -2,6 +2,7 @@ import { OpenAIModel } from "./openai/provider"; import { AnthropicModel } from "./anthropic/provider"; import { GooglePalmModel } from "./palm/provider"; import { AwsBedrockModel } from "./aws/provider"; +import { AzureOpenAIModel } from "./azure/provider"; import { KeyPool } from "./key-pool"; import type { ModelFamily } from "../models"; @@ -13,12 +14,18 @@ export type APIFormat = | "openai-text" | "openai-image"; /** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */ -export type LLMService = "openai" | "anthropic" | "google-palm" | "aws"; +export type LLMService = | "openai" | "anthropic" | "google-palm" | "aws" | "azure"; export type Model = | OpenAIModel | AnthropicModel | GooglePalmModel - | AwsBedrockModel; + | AwsBedrockModel + | AzureOpenAIModel; export interface Key { /** The API key itself. Never log this, use `hash` instead. */ @@ -72,3 +79,4 @@ export { AnthropicKey } from "./anthropic/provider"; export { OpenAIKey } from "./openai/provider"; export { GooglePalmKey } from "./palm/provider"; export { AwsBedrockKey } from "./aws/provider"; +export { AzureOpenAIKey } from "./azure/provider"; diff --git a/src/shared/key-management/key-checker-base.ts b/src/shared/key-management/key-checker-base.ts index eb7819f..789a69d 100644 --- a/src/shared/key-management/key-checker-base.ts +++ b/src/shared/key-management/key-checker-base.ts @@ -3,14 +3,17 @@ import { logger } from "../../logger"; import { Key } from "./index"; import { AxiosError } from "axios"; -type KeyCheckerOptions = { +type KeyCheckerOptions<TKey extends Key> = { service: string; keyCheckPeriod: number; minCheckInterval: number; -} + recurringChecksEnabled?: boolean; + updateKey: (hash: string, props: Partial<TKey>) => void; +}; export abstract class KeyCheckerBase<TKey extends Key> { protected readonly service: string; + protected readonly RECURRING_CHECKS_ENABLED: boolean; /** Minimum time in between any two key checks. */ protected readonly MIN_CHECK_INTERVAL: number; /** * @@ -19,16 +22,19 @@ export abstract class KeyCheckerBase<TKey extends Key> { * than this. */ protected readonly KEY_CHECK_PERIOD: number; + protected readonly updateKey: (hash: string, props: Partial<TKey>) => void; protected readonly keys: TKey[] = []; protected log: pino.Logger; protected timeout?: NodeJS.Timeout; protected lastCheck = 0; - protected constructor(keys: TKey[], opts: KeyCheckerOptions) { + protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) { const { service, keyCheckPeriod, minCheckInterval } = opts; this.keys = keys; this.KEY_CHECK_PERIOD = keyCheckPeriod; this.MIN_CHECK_INTERVAL = minCheckInterval; + this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true; + this.updateKey = opts.updateKey; this.service = service; this.log = logger.child({ module: "key-checker", service }); } @@ -52,31 +58,34 @@ export abstract class KeyCheckerBase<TKey extends Key> { * the minimum check interval.
diff --git a/src/shared/key-management/key-checker-base.ts b/src/shared/key-management/key-checker-base.ts
index eb7819f..789a69d 100644
--- a/src/shared/key-management/key-checker-base.ts
+++ b/src/shared/key-management/key-checker-base.ts
@@ -3,14 +3,17 @@ import { logger } from "../../logger";
 import { Key } from "./index";
 import { AxiosError } from "axios";

-type KeyCheckerOptions = {
+type KeyCheckerOptions<TKey extends Key> = {
   service: string;
   keyCheckPeriod: number;
   minCheckInterval: number;
-}
+  recurringChecksEnabled?: boolean;
+  updateKey: (hash: string, props: Partial<TKey>) => void;
+};

 export abstract class KeyCheckerBase<TKey extends Key> {
   protected readonly service: string;
+  protected readonly RECURRING_CHECKS_ENABLED: boolean;
   /** Minimum time in between any two key checks. */
   protected readonly MIN_CHECK_INTERVAL: number;
   /**
@@ -19,16 +22,19 @@ export abstract class KeyCheckerBase<TKey extends Key> {
    * than this.
    */
   protected readonly KEY_CHECK_PERIOD: number;
+  protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
   protected readonly keys: TKey[] = [];
   protected log: pino.Logger;
   protected timeout?: NodeJS.Timeout;
   protected lastCheck = 0;

-  protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
+  protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
     const { service, keyCheckPeriod, minCheckInterval } = opts;
     this.keys = keys;
     this.KEY_CHECK_PERIOD = keyCheckPeriod;
     this.MIN_CHECK_INTERVAL = minCheckInterval;
+    this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
+    this.updateKey = opts.updateKey;
     this.service = service;
     this.log = logger.child({ module: "key-checker", service });
   }
@@ -52,31 +58,34 @@ export abstract class KeyCheckerBase<TKey extends Key> {
    * the minimum check interval.
    */
   public scheduleNextCheck() {
+    // Gives each concurrent check a correlation ID to make logs less confusing.
     const callId = Math.random().toString(36).slice(2, 8);
     const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
     const checkLog = this.log.child({ callId, timeoutId });

     const enabledKeys = this.keys.filter((key) => !key.isDisabled);
-    checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
+    const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
+    const numEnabled = enabledKeys.length;
+    const numUnchecked = uncheckedKeys.length;

     clearTimeout(this.timeout);
+    this.timeout = undefined;

-    if (enabledKeys.length === 0) {
-      checkLog.warn("All keys are disabled. Key checker stopping.");
+    if (!numEnabled) {
+      checkLog.warn("All keys are disabled. Stopping.");
       return;
     }

-    // Perform startup checks for any keys that haven't been checked yet.
-    const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
-    checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
-    if (uncheckedKeys.length > 0) {
-      const keysToCheck = uncheckedKeys.slice(0, 12);
+    checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
+
+    if (numUnchecked > 0) {
+      const keycheckBatch = uncheckedKeys.slice(0, 12);

       this.timeout = setTimeout(async () => {
         try {
-          await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
+          await Promise.all(keycheckBatch.map((key) => this.checkKey(key)));
         } catch (error) {
-          this.log.error({ error }, "Error checking one or more keys.");
+          checkLog.error({ error }, "Error checking one or more keys.");
         }
         checkLog.info("Batch complete.");
         this.scheduleNextCheck();
@@ -84,11 +93,18 @@
       checkLog.info(
         {
-          batch: keysToCheck.map((k) => k.hash),
-          remaining: uncheckedKeys.length - keysToCheck.length,
+          batch: keycheckBatch.map((k) => k.hash),
+          remaining: uncheckedKeys.length - keycheckBatch.length,
           newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
         },
-        "Scheduled batch check."
+        "Scheduled batch of initial checks."
+      );
+      return;
+    }
+
+    if (!this.RECURRING_CHECKS_ENABLED) {
+      checkLog.info(
+        "Initial checks complete and recurring checks are disabled for this service. Stopping."
       );
       return;
     }
@@ -106,14 +122,35 @@
     );
     const delay = nextCheck - Date.now();
-    this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
+    this.timeout = setTimeout(
+      () => this.checkKey(oldestKey).then(() => this.scheduleNextCheck()),
+      delay
+    );

     checkLog.debug(
       { key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
-      "Scheduled single key check."
+      "Scheduled next recurring check."
     );
   }

-  protected abstract checkKey(key: TKey): Promise<void>;
+  public async checkKey(key: TKey): Promise<void> {
+    if (key.isDisabled) {
+      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
+      this.scheduleNextCheck();
+      return;
+    }
+
+    this.log.debug({ key: key.hash }, "Checking key...");
+
+    try {
+      await this.testKeyOrFail(key);
+    } catch (error) {
+      // Touch the key so it isn't immediately rechecked after a failure.
+      this.updateKey(key.hash, {});
+      this.handleAxiosError(key, error as AxiosError);
+    }
+
+    this.lastCheck = Date.now();
+  }
+
+  protected abstract testKeyOrFail(key: TKey): Promise<void>;
+
   protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
-}
\ No newline at end of file
+}
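The refactor above turns `checkKey` into a template method: the base class owns scheduling, disabled-key skips, error routing, and `lastCheck` bookkeeping, while subclasses supply only `testKeyOrFail` and `handleAxiosError`. A minimal sketch of what a new provider's checker could look like; the `ExampleKey` alias, endpoint, and error policy are hypothetical, and only the option names and hook signatures come from this diff:

```ts
import axios, { AxiosError } from "axios";
import { Key } from "./index";
import { KeyCheckerBase } from "./key-checker-base";

type ExampleKey = Key; // stand-in for a real provider's key type
type UpdateFn = (hash: string, props: Partial<ExampleKey>) => void;

class ExampleKeyChecker extends KeyCheckerBase<ExampleKey> {
  constructor(keys: ExampleKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "example",
      keyCheckPeriod: 60 * 60 * 1000,
      minCheckInterval: 30 * 1000,
      recurringChecksEnabled: false, // initial checks only, as OpenAI now does
      updateKey,
    });
  }

  // Throw on failure; the base class routes the error to handleAxiosError
  // and touches the key so it isn't immediately rechecked.
  protected async testKeyOrFail(key: ExampleKey): Promise<void> {
    await axios.get("https://api.example.com/v1/models", {
      headers: { Authorization: `Bearer ${key.key}` },
    });
    this.updateKey(key.hash, {});
  }

  protected handleAxiosError(key: ExampleKey, error: AxiosError): void {
    if (error.response?.status === 401) {
      this.updateKey(key.hash, { isDisabled: true });
    }
  }
}
```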
diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts
index 95e7f26..7cf7de2 100644
--- a/src/shared/key-management/key-pool.ts
+++ b/src/shared/key-management/key-pool.ts
@@ -11,6 +11,7 @@ import { GooglePalmKeyProvider } from "./palm/provider";
 import { AwsBedrockKeyProvider } from "./aws/provider";
 import { ModelFamily } from "../models";
 import { assertNever } from "../utils";
+import { AzureOpenAIKeyProvider } from "./azure/provider";

 type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -25,6 +26,7 @@ export class KeyPool {
     this.keyProviders.push(new AnthropicKeyProvider());
     this.keyProviders.push(new GooglePalmKeyProvider());
     this.keyProviders.push(new AwsBedrockKeyProvider());
+    this.keyProviders.push(new AzureOpenAIKeyProvider());
   }

   public init() {
@@ -124,6 +126,8 @@ export class KeyPool {
       // AWS offers models from a few providers
       // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
       return "aws";
+    } else if (model.startsWith("azure")) {
+      return "azure";
     }
     throw new Error(`Unknown service for model '${model}'`);
   }
@@ -142,6 +146,11 @@
       return "google-palm";
     case "aws-claude":
       return "aws";
+    case "azure-turbo":
+    case "azure-gpt4":
+    case "azure-gpt4-32k":
+    case "azure-gpt4-turbo":
+      return "azure";
     default:
       assertNever(modelFamily);
   }
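Both routing paths added here rely on the `azure-` prefix: the request middleware prepends it to the model name (see the `models.ts` comment later in this diff), so either the model-ID prefix check or the model-family switch lands on the Azure provider. A small illustration with hypothetical variables:

```ts
// Client sends an Azure-style model ID (no period, per Azure's naming rules).
const incomingModel = "gpt-35-turbo";

// The Azure middleware prepends "azure-" before key selection...
const routedModel = `azure-${incomingModel}`; // "azure-gpt-35-turbo"

// ...so the prefix check above resolves the request to the Azure provider.
const isAzure = routedModel.startsWith("azure"); // true -> LLMService "azure"
```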
diff --git a/src/shared/key-management/openai/checker.ts b/src/shared/key-management/openai/checker.ts
index d67afbd..4eb4cce 100644
--- a/src/shared/key-management/openai/checker.ts
+++ b/src/shared/key-management/openai/checker.ts
@@ -27,65 +27,41 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;

 export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
   private readonly cloneKey: CloneFn;
-  private readonly updateKey: UpdateFn;

   constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
     super(keys, {
       service: "openai",
       keyCheckPeriod: KEY_CHECK_PERIOD,
       minCheckInterval: MIN_CHECK_INTERVAL,
+      recurringChecksEnabled: false,
+      updateKey,
     });
     this.cloneKey = cloneFn;
-    this.updateKey = updateKey;
   }

-  protected async checkKey(key: OpenAIKey) {
-    if (key.isDisabled) {
-      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
-      this.scheduleNextCheck();
-      return;
-    }
-
-    this.log.debug({ key: key.hash }, "Checking key...");
-    let isInitialCheck = !key.lastChecked;
-    try {
-      // We only need to check for provisioned models on the initial check.
-      if (isInitialCheck) {
-        const [provisionedModels, livenessTest] = await Promise.all([
-          this.getProvisionedModels(key),
-          this.testLiveness(key),
-          this.maybeCreateOrganizationClones(key),
-        ]);
-        const updates = {
-          modelFamilies: provisionedModels,
-          isTrial: livenessTest.rateLimit <= 250,
-        };
-        this.updateKey(key.hash, updates);
-      } else {
-        // No updates needed as models and trial status generally don't change.
-        const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
-        this.updateKey(key.hash, {});
-      }
-      this.log.info(
-        { key: key.hash, models: key.modelFamilies, trial: key.isTrial },
-        "Key check complete."
-      );
-    } catch (error) {
-      // touch the key so we don't check it again for a while
+  protected async testKeyOrFail(key: OpenAIKey) {
+    // We only need to check for provisioned models on the initial check.
+    const isInitialCheck = !key.lastChecked;
+    if (isInitialCheck) {
+      const [provisionedModels, livenessTest] = await Promise.all([
+        this.getProvisionedModels(key),
+        this.testLiveness(key),
+        this.maybeCreateOrganizationClones(key),
+      ]);
+      const updates = {
+        modelFamilies: provisionedModels,
+        isTrial: livenessTest.rateLimit <= 250,
+      };
+      this.updateKey(key.hash, updates);
+    } else {
+      // No updates needed as models and trial status generally don't change.
+      const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
       this.updateKey(key.hash, {});
-      this.handleAxiosError(key, error as AxiosError);
-    }
-
-    this.lastCheck = Date.now();
-    // Only enqueue the next check if this wasn't a startup check, since those
-    // are batched together elsewhere.
-    if (!isInitialCheck) {
-      this.log.info(
-        { key: key.hash },
-        "Recurring keychecks are disabled, no-op."
-      );
-      // this.scheduleNextCheck();
     }
+    this.log.info(
+      { key: key.hash, models: key.modelFamilies, trial: key.isTrial },
+      "Checked key."
+    );
   }

   private async getProvisionedModels(
@@ -138,6 +114,17 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
       .filter(({ is_default }) => !is_default)
       .map(({ id }) => id);
     this.cloneKey(key.hash, ids);
+
+    // The key checker may already have stopped if all non-cloned keys
+    // happened to be unusable, in which case this clone will never be
+    // checked unless we restart the key checker.
+    if (!this.timeout) {
+      this.log.warn(
+        { parent: key.hash },
+        "Restarting key checker to check cloned keys."
+      );
+      this.scheduleNextCheck();
+    }
   }

   protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
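One behavior preserved by the refactor is the trial-key heuristic: `isTrial` is true when the liveness test reports a rate limit of 250 or less. `testLiveness` is not shown in this diff; the sketch below assumes it surfaces the requests-per-minute limit from OpenAI's `x-ratelimit-limit-requests` response header:

```ts
import { AxiosResponse } from "axios";

function inferTrialStatus(response: AxiosResponse): boolean {
  // Trial keys get very low requests-per-minute limits; 250 is the
  // threshold used in the hunk above.
  const rateLimit = Number(response.headers["x-ratelimit-limit-requests"] ?? 0);
  return rateLimit <= 250;
}
```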
+      );
+      this.scheduleNextCheck();
+    }
   }

   protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts
index e307043..db2055d 100644
--- a/src/shared/key-management/openai/provider.ts
+++ b/src/shared/key-management/openai/provider.ts
@@ -217,17 +217,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
       return a.lastUsed - b.lastUsed;
     });

-    // logger.debug(
-    //   {
-    //     byPriority: keysByPriority.map((k) => ({
-    //       hash: k.hash,
-    //       isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
-    //       modelFamilies: k.modelFamilies,
-    //     })),
-    //   },
-    //   "Keys sorted by priority"
-    // );
-
     const selectedKey = keysByPriority[0];
     selectedKey.lastUsed = now;
     this.throttle(selectedKey.hash);
diff --git a/src/shared/models.ts b/src/shared/models.ts
index 991499e..c91ab03 100644
--- a/src/shared/models.ts
+++ b/src/shared/models.ts
@@ -2,15 +2,25 @@ import pino from "pino";

-export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
+export type OpenAIModelFamily =
+  | "turbo"
+  | "gpt4"
+  | "gpt4-32k"
+  | "gpt4-turbo"
+  | "dall-e";
 export type AnthropicModelFamily = "claude";
 export type GooglePalmModelFamily = "bison";
 export type AwsBedrockModelFamily = "aws-claude";
+export type AzureOpenAIModelFamily = `azure-${Exclude<
+  OpenAIModelFamily,
+  "dall-e"
+>}`;

 export type ModelFamily =
   | OpenAIModelFamily
   | AnthropicModelFamily
   | GooglePalmModelFamily
-  | AwsBedrockModelFamily;
+  | AwsBedrockModelFamily
+  | AzureOpenAIModelFamily;

 export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@@ -23,6 +33,10 @@
   "claude",
   "bison",
   "aws-claude",
+  "azure-turbo",
+  "azure-gpt4",
+  "azure-gpt4-32k",
+  "azure-gpt4-turbo",
 ] as const);

 export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
@@ -64,6 +78,24 @@ export function getAwsBedrockModelFamily(_model: string): ModelFamily {
   return "aws-claude";
 }

+export function getAzureOpenAIModelFamily(
+  model: string,
+  defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
+): AzureOpenAIModelFamily {
+  // Azure model names omit periods. addAzureKey also prepends "azure-" to the
+  // model name to route the request to the correct key provider, so we need
+  // to strip that prefix as well.
+  const modified = model
+    .replace("gpt-35-turbo", "gpt-3.5-turbo")
+    .replace("azure-", "");
+  for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
+    if (modified.match(regex)) {
+      return `azure-${family}` as AzureOpenAIModelFamily;
+    }
+  }
+  return defaultFamily;
+}
+
 export function assertIsKnownModelFamily(
   modelFamily: string
 ): asserts modelFamily is ModelFamily {
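Given that normalization plus the existing `OPENAI_MODEL_FAMILY_MAP` regexes (which this diff does not change), the function should map IDs like so; these calls are illustrative:

```ts
console.log(getAzureOpenAIModelFamily("azure-gpt-35-turbo")); // "azure-turbo"
console.log(getAzureOpenAIModelFamily("azure-gpt-4"));        // "azure-gpt4"
console.log(getAzureOpenAIModelFamily("azure-gpt-4-32k"));    // "azure-gpt4-32k"
console.log(getAzureOpenAIModelFamily("my-own-deployment"));  // "azure-gpt4" (defaultFamily)
```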
diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts
index 8fbdb58..e42d56b 100644
--- a/src/shared/users/user-store.ts
+++ b/src/shared/users/user-store.ts
@@ -12,6 +12,7 @@ import schedule from "node-schedule";
 import { v4 as uuid } from "uuid";
 import { config, getFirebaseApp } from "../../config";
 import {
+  getAzureOpenAIModelFamily,
   getClaudeModelFamily,
   getGooglePalmModelFamily,
   getOpenAIModelFamily,
@@ -34,6 +35,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
   claude: 0,
   bison: 0,
   "aws-claude": 0,
+  "azure-turbo": 0,
+  "azure-gpt4": 0,
+  "azure-gpt4-turbo": 0,
+  "azure-gpt4-32k": 0,
 };

 const users: Map<string, User> = new Map();
@@ -382,6 +387,9 @@ function getModelFamilyForQuotaUsage(
   model: string,
   api: APIFormat
 ): ModelFamily {
+  // TODO: this seems incorrect
+  if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
+
   switch (api) {
     case "openai":
     case "openai-text":