Azure OpenAI suport (khanon/oai-reverse-proxy!48)
This commit is contained in:
parent
cd1b9d0e0c
commit
fbdea30264
|
@ -34,10 +34,10 @@
|
|||
|
||||
# Which model types users are allowed to access.
|
||||
# The following model families are recognized:
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
|
||||
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
|
||||
# generation, uncomment the line below and add 'dall-e' to the list.
|
||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude
|
||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
|
||||
|
||||
# URLs from which requests will be blocked.
|
||||
# BLOCKED_ORIGINS=reddit.com,9gag.com
|
||||
|
@ -114,6 +114,8 @@ OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
|||
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
# See `docs/aws-configuration.md` for more information, there may be additional steps required to set up AWS.
|
||||
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
|
||||
# See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure.
|
||||
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
|
||||
|
||||
# With proxy_key gatekeeper, the password users must provide to access the API.
|
||||
# PROXY_KEY=your-secret-key
|
||||
|
|
|
@ -5,3 +5,4 @@
|
|||
build
|
||||
greeting.md
|
||||
node_modules
|
||||
http-client.private.env.json
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# Configuring the proxy for Azure
|
||||
|
||||
The proxy supports Azure OpenAI Service via the `/proxy/azure/openai` endpoint. The process of setting it up is slightly different from regular OpenAI.
|
||||
|
||||
- [Setting keys](#setting-keys)
|
||||
- [Model assignment](#model-assignment)
|
||||
|
||||
## Setting keys
|
||||
|
||||
Use the `AZURE_CREDENTIALS` environment variable to set the Azure API keys.
|
||||
|
||||
Like other APIs, you can provide multiple keys separated by commas. Each Azure key, however, is a set of values including the Resource Name, Deployment ID, and API key. These are separated by a colon (`:`).
|
||||
|
||||
For example:
|
||||
```
|
||||
AZURE_CREDENTIALS=contoso-ml:gpt4-8k:0123456789abcdef0123456789abcdef,northwind-corp:testdeployment:0123456789abcdef0123456789abcdef
|
||||
```
|
||||
|
||||
## Model assignment
|
||||
Note that each Azure deployment is assigned a model when you create it in the Microsoft Cognitive Services portal. If you want to use a different model, you'll need to create a new deployment, and therefore a new key to be added to the AZURE_CREDENTIALS environment variable. Each credential only grants access to one model.
|
||||
|
||||
### Supported model IDs
|
||||
Users can send normal OpenAI model IDs to the proxy to invoke the corresponding models. For the most part they work the same with Azure. GPT-3.5 Turbo has an ID of "gpt-35-turbo" because Azure doesn't allow periods in model names, but the proxy should automatically convert this to the correct ID.
|
||||
|
||||
As noted above, you can only use model IDs for which a deployment has been created and added to the proxy.
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"dev": {
|
||||
"proxy-host": "http://localhost:7860",
|
||||
"oai-key-1": "override in http-client.private.env.json",
|
||||
"proxy-key": "override in http-client.private.env.json",
|
||||
"azu-resource-name": "override in http-client.private.env.json",
|
||||
"azu-deployment-id": "override in http-client.private.env.json"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,248 @@
|
|||
# OAI Reverse Proxy
|
||||
|
||||
###
|
||||
# @name OpenAI -- Chat Completions
|
||||
POST https://api.openai.com/v1/chat/completions
|
||||
Authorization: Bearer {{oai-key-1}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"max_tokens": 30,
|
||||
"stream": false,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test prompt."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name OpenAI -- Text Completions
|
||||
POST https://api.openai.com/v1/completions
|
||||
Authorization: Bearer {{oai-key-1}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"max_tokens": 30,
|
||||
"stream": false,
|
||||
"prompt": "This is a test prompt where"
|
||||
}
|
||||
|
||||
###
|
||||
# @name OpenAI -- Create Embedding
|
||||
POST https://api.openai.com/v1/embeddings
|
||||
Authorization: Bearer {{oai-key-1}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": "This is a test embedding input."
|
||||
}
|
||||
|
||||
###
|
||||
# @name OpenAI -- Get Organizations
|
||||
GET https://api.openai.com/v1/organizations
|
||||
Authorization: Bearer {{oai-key-1}}
|
||||
|
||||
###
|
||||
# @name OpenAI -- Get Models
|
||||
GET https://api.openai.com/v1/models
|
||||
Authorization: Bearer {{oai-key-1}}
|
||||
|
||||
###
|
||||
# @name Azure OpenAI -- Chat Completions
|
||||
POST https://{{azu-resource-name}}.openai.azure.com/openai/deployments/{{azu-deployment-id}}/chat/completions?api-version=2023-09-01-preview
|
||||
api-key: {{azu-key-1}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"max_tokens": 1,
|
||||
"stream": false,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test prompt."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / OpenAI -- Get Models
|
||||
GET {{proxy-host}}/proxy/openai/v1/models
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
|
||||
###
|
||||
# @name Proxy / OpenAI -- Native Chat Completions
|
||||
POST {{proxy-host}}/proxy/openai/chat/completions
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"max_tokens": 20,
|
||||
"stream": true,
|
||||
"temperature": 1,
|
||||
"seed": 123,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "phrase one"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / OpenAI -- Native Text Completions
|
||||
POST {{proxy-host}}/proxy/openai/v1/turbo-instruct/chat/completions
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"max_tokens": 20,
|
||||
"temperature": 0,
|
||||
"prompt": "Genshin Impact is a game about",
|
||||
"stream": false
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / OpenAI -- Chat-to-Text API Translation
|
||||
# Accepts a chat completion request and reformats it to work with the text completion API. `model` is ignored.
|
||||
POST {{proxy-host}}/proxy/openai/turbo-instruct/chat/completions
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"max_tokens": 20,
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the name of the fourth president of the united states?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "That would be George Washington."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I don't think that's right..."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / OpenAI -- Create Embedding
|
||||
POST {{proxy-host}}/proxy/openai/embeddings
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": "This is a test embedding input."
|
||||
}
|
||||
|
||||
|
||||
###
|
||||
# @name Proxy / Anthropic -- Native Completion (old API)
|
||||
POST {{proxy-host}}/proxy/anthropic/v1/complete
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
anthropic-version: 2023-01-01
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v1.3",
|
||||
"max_tokens_to_sample": 20,
|
||||
"temperature": 0.2,
|
||||
"stream": true,
|
||||
"prompt": "What is genshin impact\n\n:Assistant:"
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / Anthropic -- Native Completion (2023-06-01 API)
|
||||
POST {{proxy-host}}/proxy/anthropic/v1/complete
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
anthropic-version: 2023-06-01
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v1.3",
|
||||
"max_tokens_to_sample": 20,
|
||||
"temperature": 0.2,
|
||||
"stream": true,
|
||||
"prompt": "What is genshin impact\n\n:Assistant:"
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / Anthropic -- OpenAI-to-Anthropic API Translation
|
||||
POST {{proxy-host}}/proxy/anthropic/v1/chat/completions
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
#anthropic-version: 2023-06-01
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"max_tokens": 20,
|
||||
"stream": false,
|
||||
"temperature": 0,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is genshin impact"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / AWS Claude -- Native Completion
|
||||
POST {{proxy-host}}/proxy/aws/claude/v1/complete
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
anthropic-version: 2023-01-01
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v2",
|
||||
"max_tokens_to_sample": 10,
|
||||
"temperature": 0,
|
||||
"stream": true,
|
||||
"prompt": "What is genshin impact\n\n:Assistant:"
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / AWS Claude -- OpenAI-to-Anthropic API Translation
|
||||
POST {{proxy-host}}/proxy/aws/claude/chat/completions
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"max_tokens": 50,
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is genshin impact?"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation
|
||||
POST {{proxy-host}}/proxy/google-palm/v1/chat/completions
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"max_tokens": 42,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi what is the name of the fourth president of the united states?"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
$NumThreads = 10
|
||||
|
||||
$runspacePool = [runspacefactory]::CreateRunspacePool(1, $NumThreads)
|
||||
$runspacePool.Open()
|
||||
$runspaces = @()
|
||||
|
||||
$headers = @{
|
||||
"Authorization" = "Bearer test"
|
||||
"anthropic-version" = "2023-01-01"
|
||||
"Content-Type" = "application/json"
|
||||
}
|
||||
|
||||
$payload = @{
|
||||
model = "claude-v2"
|
||||
max_tokens_to_sample = 40
|
||||
temperature = 0
|
||||
stream = $true
|
||||
prompt = "Test prompt, please reply with lorem ipsum`n`n:Assistant:"
|
||||
} | ConvertTo-Json
|
||||
|
||||
for ($i = 1; $i -le $NumThreads; $i++) {
|
||||
Write-Host "Starting thread $i"
|
||||
$runspace = [powershell]::Create()
|
||||
$runspace.AddScript({
|
||||
param($i, $headers, $payload)
|
||||
$response = Invoke-WebRequest -Uri "http://localhost:7860/proxy/aws/claude/v1/complete" -Method Post -Headers $headers -Body $payload
|
||||
Write-Host "Response from server: $($response.StatusCode)"
|
||||
}).AddArgument($i).AddArgument($headers).AddArgument($payload)
|
||||
|
||||
$runspace.RunspacePool = $runspacePool
|
||||
$runspaces += [PSCustomObject]@{ Pipe = $runspace; Status = $runspace.BeginInvoke() }
|
||||
}
|
||||
|
||||
$runspaces | ForEach-Object {
|
||||
$_.Pipe.EndInvoke($_.Status)
|
||||
$_.Pipe.Dispose()
|
||||
}
|
||||
|
||||
$runspacePool.Close()
|
||||
$runspacePool.Dispose()
|
|
@ -33,6 +33,17 @@ type Config = {
|
|||
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
|
||||
*/
|
||||
awsCredentials?: string;
|
||||
/**
|
||||
* Comma-delimited list of Azure OpenAI credentials. Each credential item
|
||||
* should be a colon-delimited list of Azure resource name, deployment ID, and
|
||||
* API key.
|
||||
*
|
||||
* The resource name is the subdomain in your Azure OpenAI deployment's URL,
|
||||
* e.g. `https://resource-name.openai.azure.com
|
||||
*
|
||||
* @example `AZURE_CREDENTIALS=resource_name_1:deployment_id_1:api_key_1,resource_name_2:deployment_id_2:api_key_2`
|
||||
*/
|
||||
azureCredentials?: string;
|
||||
/**
|
||||
* The proxy key to require for requests. Only applicable if the user
|
||||
* management mode is set to 'proxy_key', and required if so.
|
||||
|
@ -188,6 +199,7 @@ export const config: Config = {
|
|||
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
|
||||
googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
|
||||
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
|
||||
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
|
||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
|
||||
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
|
||||
|
@ -219,6 +231,10 @@ export const config: Config = {
|
|||
"claude",
|
||||
"bison",
|
||||
"aws-claude",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-turbo",
|
||||
"azure-gpt4-32k",
|
||||
]),
|
||||
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
|
||||
rejectMessage: getEnvWithDefault(
|
||||
|
@ -352,6 +368,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
|
|||
"anthropicKey",
|
||||
"googlePalmKey",
|
||||
"awsCredentials",
|
||||
"azureCredentials",
|
||||
"proxyKey",
|
||||
"adminKey",
|
||||
"rejectPhrases",
|
||||
|
@ -369,6 +386,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
|
|||
"useInsecureCookies",
|
||||
"staticServiceInfo",
|
||||
"checkKeys",
|
||||
"allowedModelFamilies",
|
||||
];
|
||||
|
||||
const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;
|
||||
|
@ -417,6 +435,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
|
|||
"ANTHROPIC_KEY",
|
||||
"GOOGLE_PALM_KEY",
|
||||
"AWS_CREDENTIALS",
|
||||
"AZURE_CREDENTIALS",
|
||||
].includes(String(env))
|
||||
) {
|
||||
return value as unknown as T;
|
||||
|
|
137
src/info-page.ts
137
src/info-page.ts
|
@ -1,3 +1,4 @@
|
|||
/** This whole module really sucks */
|
||||
import fs from "fs";
|
||||
import { Request, Response } from "express";
|
||||
import showdown from "showdown";
|
||||
|
@ -5,11 +6,16 @@ import { config, listConfig } from "./config";
|
|||
import {
|
||||
AnthropicKey,
|
||||
AwsBedrockKey,
|
||||
AzureOpenAIKey,
|
||||
GooglePalmKey,
|
||||
keyPool,
|
||||
OpenAIKey,
|
||||
} from "./shared/key-management";
|
||||
import { ModelFamily, OpenAIModelFamily } from "./shared/models";
|
||||
import {
|
||||
AzureOpenAIModelFamily,
|
||||
ModelFamily,
|
||||
OpenAIModelFamily,
|
||||
} from "./shared/models";
|
||||
import { getUniqueIps } from "./proxy/rate-limit";
|
||||
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
|
||||
import { getTokenCostUsd, prettyTokens } from "./shared/stats";
|
||||
|
@ -23,6 +29,8 @@ let infoPageLastUpdated = 0;
|
|||
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
|
||||
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
|
||||
k.service === "openai";
|
||||
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
|
||||
k.service === "azure";
|
||||
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
|
||||
k.service === "anthropic";
|
||||
const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
|
||||
|
@ -48,6 +56,7 @@ type ServiceAggregates = {
|
|||
anthropicKeys?: number;
|
||||
palmKeys?: number;
|
||||
awsKeys?: number;
|
||||
azureKeys?: number;
|
||||
proompts: number;
|
||||
tokens: number;
|
||||
tokenCost: number;
|
||||
|
@ -62,17 +71,15 @@ const serviceStats = new Map<keyof ServiceAggregates, number>();
|
|||
|
||||
export const handleInfoPage = (req: Request, res: Response) => {
|
||||
if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
|
||||
res.send(infoPageHtml);
|
||||
return;
|
||||
return res.send(infoPageHtml);
|
||||
}
|
||||
|
||||
// Sometimes huggingface doesn't send the host header and makes us guess.
|
||||
const baseUrl =
|
||||
process.env.SPACE_ID && !req.get("host")?.includes("hf.space")
|
||||
? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
|
||||
: req.protocol + "://" + req.get("host");
|
||||
|
||||
infoPageHtml = buildInfoPageHtml(baseUrl);
|
||||
infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
|
||||
infoPageLastUpdated = Date.now();
|
||||
|
||||
res.send(infoPageHtml);
|
||||
|
@ -95,6 +102,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
|
|||
const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
|
||||
const palmKeys = serviceStats.get("palmKeys") || 0;
|
||||
const awsKeys = serviceStats.get("awsKeys") || 0;
|
||||
const azureKeys = serviceStats.get("azureKeys") || 0;
|
||||
const proompts = serviceStats.get("proompts") || 0;
|
||||
const tokens = serviceStats.get("tokens") || 0;
|
||||
const tokenCost = serviceStats.get("tokenCost") || 0;
|
||||
|
@ -102,16 +110,15 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
|
|||
const allowDalle = config.allowedModelFamilies.includes("dall-e");
|
||||
|
||||
const endpoints = {
|
||||
...(openaiKeys ? { openai: baseUrl + "/proxy/openai" } : {}),
|
||||
...(openaiKeys
|
||||
? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" }
|
||||
: {}),
|
||||
...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
|
||||
...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
|
||||
...(openaiKeys && allowDalle
|
||||
? { ["openai-image"]: baseUrl + "/proxy/openai-image" }
|
||||
? { ["openai-image"]: baseUrl + "/openai-image" }
|
||||
: {}),
|
||||
...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}),
|
||||
...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}),
|
||||
...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}),
|
||||
...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
|
||||
...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
|
||||
...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
|
||||
...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
|
||||
};
|
||||
|
||||
const stats = {
|
||||
|
@ -120,13 +127,17 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
|
|||
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
|
||||
};
|
||||
|
||||
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys };
|
||||
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
|
||||
for (const key of Object.keys(keyInfo)) {
|
||||
if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
|
||||
}
|
||||
|
||||
const providerInfo = {
|
||||
...(openaiKeys ? getOpenAIInfo() : {}),
|
||||
...(anthropicKeys ? getAnthropicInfo() : {}),
|
||||
...(palmKeys ? { "palm-bison": getPalmInfo() } : {}),
|
||||
...(awsKeys ? { "aws-claude": getAwsInfo() } : {}),
|
||||
...(palmKeys ? getPalmInfo() : {}),
|
||||
...(awsKeys ? getAwsInfo() : {}),
|
||||
...(azureKeys ? getAzureInfo() : {}),
|
||||
};
|
||||
|
||||
if (hideFullInfo) {
|
||||
|
@ -188,6 +199,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||
increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
|
||||
increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
|
||||
increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
|
||||
increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
|
||||
|
||||
let sumTokens = 0;
|
||||
let sumCost = 0;
|
||||
|
@ -201,17 +213,26 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||
Boolean(k.lastChecked) ? 0 : 1
|
||||
);
|
||||
|
||||
// Technically this would not account for keys that have tokens recorded
|
||||
// on models they aren't provisioned for, but that would be strange
|
||||
k.modelFamilies.forEach((f) => {
|
||||
const tokens = k[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
increment(modelStats, `${f}__tokens`, tokens);
|
||||
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
|
||||
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
|
||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
|
||||
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
|
||||
});
|
||||
break;
|
||||
case "azure":
|
||||
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
|
||||
k.modelFamilies.forEach((f) => {
|
||||
const tokens = k[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
increment(modelStats, `${f}__tokens`, tokens);
|
||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
});
|
||||
break;
|
||||
case "anthropic": {
|
||||
|
@ -381,11 +402,13 @@ function getPalmInfo() {
|
|||
const cost = getTokenCostUsd("bison", tokens);
|
||||
|
||||
return {
|
||||
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
|
||||
activeKeys: bisonInfo.active,
|
||||
revokedKeys: bisonInfo.revoked,
|
||||
proomptersInQueue: bisonInfo.queued,
|
||||
estimatedQueueTime: bisonInfo.queueTime,
|
||||
bison: {
|
||||
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
|
||||
activeKeys: bisonInfo.active,
|
||||
revokedKeys: bisonInfo.revoked,
|
||||
proomptersInQueue: bisonInfo.queued,
|
||||
estimatedQueueTime: bisonInfo.queueTime,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -408,15 +431,59 @@ function getAwsInfo() {
|
|||
: `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
|
||||
|
||||
return {
|
||||
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
|
||||
activeKeys: awsInfo.active,
|
||||
revokedKeys: awsInfo.revoked,
|
||||
proomptersInQueue: awsInfo.queued,
|
||||
estimatedQueueTime: awsInfo.queueTime,
|
||||
...(logged > 0 ? { privacy: logMsg } : {}),
|
||||
"aws-claude": {
|
||||
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
|
||||
activeKeys: awsInfo.active,
|
||||
revokedKeys: awsInfo.revoked,
|
||||
proomptersInQueue: awsInfo.queued,
|
||||
estimatedQueueTime: awsInfo.queueTime,
|
||||
...(logged > 0 ? { privacy: logMsg } : {}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function getAzureInfo() {
|
||||
const azureFamilies = [
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-turbo",
|
||||
"azure-gpt4-32k",
|
||||
] as const;
|
||||
|
||||
const azureInfo: {
|
||||
[modelFamily in AzureOpenAIModelFamily]?: {
|
||||
usage?: string;
|
||||
activeKeys: number;
|
||||
revokedKeys?: number;
|
||||
proomptersInQueue?: number;
|
||||
estimatedQueueTime?: string;
|
||||
};
|
||||
} = {};
|
||||
for (const family of azureFamilies) {
|
||||
const familyAllowed = config.allowedModelFamilies.includes(family);
|
||||
const activeKeys = modelStats.get(`${family}__active`) || 0;
|
||||
|
||||
if (!familyAllowed || activeKeys === 0) continue;
|
||||
|
||||
azureInfo[family] = {
|
||||
activeKeys,
|
||||
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
|
||||
};
|
||||
|
||||
const queue = getQueueInformation(family);
|
||||
azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue;
|
||||
azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime;
|
||||
|
||||
const tokens = modelStats.get(`${family}__tokens`) || 0;
|
||||
const cost = getTokenCostUsd(family, tokens);
|
||||
azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString(
|
||||
cost
|
||||
)}`;
|
||||
}
|
||||
|
||||
return azureInfo;
|
||||
}
|
||||
|
||||
const customGreeting = fs.existsSync("greeting.md")
|
||||
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
|
||||
: "";
|
||||
|
@ -430,10 +497,10 @@ function buildInfoPageHeader(converter: showdown.Converter, title: string) {
|
|||
let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
|
||||
# ${title}`;
|
||||
if (config.promptLogging) {
|
||||
infoBody += `\n## Prompt logging is enabled!
|
||||
The server operator has enabled prompt logging. The prompts you send to this proxy and the AI responses you receive may be saved.
|
||||
infoBody += `\n## Prompt Logging Enabled
|
||||
This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
|
||||
|
||||
Logs are anonymous and do not contain IP addresses or timestamps. [You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/prompt-logging/index.ts).
|
||||
[You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/shared/prompt-logging/index.ts).
|
||||
|
||||
**If you are uncomfortable with this, don't send prompts to this proxy!**`;
|
||||
}
|
||||
|
@ -570,8 +637,6 @@ function escapeHtml(unsafe: string) {
|
|||
}
|
||||
|
||||
function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
|
||||
// Huggingface broke their amazon elb config and no longer sends the
|
||||
// x-forwarded-host header. This is a workaround.
|
||||
try {
|
||||
const [username, spacename] = spaceId.split("/");
|
||||
return `https://${username}-${spacename.replace(/_/g, "-")}.hf.space`;
|
||||
|
|
|
@ -11,7 +11,7 @@ import {
|
|||
createPreprocessorMiddleware,
|
||||
stripHeaders,
|
||||
signAwsRequest,
|
||||
finalizeAwsRequest,
|
||||
finalizeSignedRequest,
|
||||
createOnProxyReqHandler,
|
||||
blockZoomerOrigins,
|
||||
} from "./middleware/request";
|
||||
|
@ -30,7 +30,11 @@ const getModelsResponse = () => {
|
|||
|
||||
if (!config.awsCredentials) return { object: "list", data: [] };
|
||||
|
||||
const variants = ["anthropic.claude-v1", "anthropic.claude-v2"];
|
||||
const variants = [
|
||||
"anthropic.claude-v1",
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-v2:1",
|
||||
];
|
||||
|
||||
const models = variants.map((id) => ({
|
||||
id,
|
||||
|
@ -134,7 +138,7 @@ const awsProxy = createQueueMiddleware({
|
|||
applyQuotaLimits,
|
||||
blockZoomerOrigins,
|
||||
stripHeaders,
|
||||
finalizeAwsRequest,
|
||||
finalizeSignedRequest,
|
||||
],
|
||||
}),
|
||||
proxyRes: createOnProxyResHandler([awsResponseHandler]),
|
||||
|
@ -183,7 +187,7 @@ function maybeReassignModel(req: Request) {
|
|||
req.body.model = "anthropic.claude-v1";
|
||||
} else {
|
||||
// User's client requested v2 or possibly some OpenAI model, default to v2
|
||||
req.body.model = "anthropic.claude-v2";
|
||||
req.body.model = "anthropic.claude-v2:1";
|
||||
}
|
||||
// TODO: Handle claude-instant
|
||||
}
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
import { RequestHandler, Router } from "express";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { config } from "../config";
|
||||
import { keyPool } from "../shared/key-management";
|
||||
import {
|
||||
ModelFamily,
|
||||
AzureOpenAIModelFamily,
|
||||
getAzureOpenAIModelFamily,
|
||||
} from "../shared/models";
|
||||
import { logger } from "../logger";
|
||||
import { KNOWN_OPENAI_MODELS } from "./openai";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
applyQuotaLimits,
|
||||
blockZoomerOrigins,
|
||||
createOnProxyReqHandler,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeSignedRequest,
|
||||
limitCompletions,
|
||||
stripHeaders,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
import { addAzureKey } from "./middleware/request/add-azure-key";
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
function getModelsResponse() {
|
||||
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
let available = new Set<AzureOpenAIModelFamily>();
|
||||
for (const key of keyPool.list()) {
|
||||
if (key.isDisabled || key.service !== "azure") continue;
|
||||
key.modelFamilies.forEach((family) =>
|
||||
available.add(family as AzureOpenAIModelFamily)
|
||||
);
|
||||
}
|
||||
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
|
||||
available = new Set([...available].filter((x) => allowed.has(x)));
|
||||
|
||||
const models = KNOWN_OPENAI_MODELS.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: "azure",
|
||||
permission: [
|
||||
{
|
||||
id: "modelperm-" + id,
|
||||
object: "model_permission",
|
||||
created: new Date().getTime(),
|
||||
organization: "*",
|
||||
group: null,
|
||||
is_blocking: false,
|
||||
},
|
||||
],
|
||||
root: id,
|
||||
parent: null,
|
||||
})).filter((model) => available.has(getAzureOpenAIModelFamily(model.id)));
|
||||
|
||||
modelsCache = { object: "list", data: models };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
const handleModelRequest: RequestHandler = (_req, res) => {
|
||||
res.status(200).json(getModelsResponse());
|
||||
};
|
||||
|
||||
const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
if (config.promptLogging) {
|
||||
const host = req.get("host");
|
||||
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
|
||||
}
|
||||
|
||||
if (req.tokenizerInfo) {
|
||||
body.proxy_tokenizer = req.tokenizerInfo;
|
||||
}
|
||||
|
||||
res.status(200).json(body);
|
||||
};
|
||||
|
||||
const azureOpenAIProxy = createQueueMiddleware({
|
||||
beforeProxy: addAzureKey,
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "will be set by router",
|
||||
router: (req) => {
|
||||
if (!req.signedRequest) throw new Error("signedRequest not set");
|
||||
const { hostname, path } = req.signedRequest;
|
||||
return `https://${hostname}${path}`;
|
||||
},
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({
|
||||
pipeline: [
|
||||
applyQuotaLimits,
|
||||
limitCompletions,
|
||||
blockZoomerOrigins,
|
||||
stripHeaders,
|
||||
finalizeSignedRequest,
|
||||
],
|
||||
}),
|
||||
proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
const azureOpenAIRouter = Router();
|
||||
azureOpenAIRouter.get("/v1/models", handleModelRequest);
|
||||
azureOpenAIRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "openai",
|
||||
outApi: "openai",
|
||||
service: "azure",
|
||||
}),
|
||||
azureOpenAIProxy
|
||||
);
|
||||
|
||||
export const azure = azureOpenAIRouter;
|
|
@ -59,7 +59,7 @@ export function writeErrorResponse(
|
|||
res.write(`data: [DONE]\n\n`);
|
||||
res.end();
|
||||
} else {
|
||||
if (req.tokenizerInfo && errorPayload.error) {
|
||||
if (req.tokenizerInfo && typeof errorPayload.error === "object") {
|
||||
errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
|
||||
}
|
||||
res.status(statusCode).json(errorPayload);
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import { AzureOpenAIKey, keyPool } from "../../../shared/key-management";
|
||||
import { RequestPreprocessor } from ".";
|
||||
|
||||
export const addAzureKey: RequestPreprocessor = (req) => {
|
||||
const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
|
||||
const serviceValid = req.service === "azure";
|
||||
if (!apisValid || !serviceValid) {
|
||||
throw new Error("addAzureKey called on invalid request");
|
||||
}
|
||||
|
||||
if (!req.body?.model) {
|
||||
throw new Error("You must specify a model with your request.");
|
||||
}
|
||||
|
||||
const model = req.body.model.startsWith("azure-")
|
||||
? req.body.model
|
||||
: `azure-${req.body.model}`;
|
||||
|
||||
req.key = keyPool.get(model);
|
||||
req.body.model = model;
|
||||
|
||||
req.log.info(
|
||||
{ key: req.key.hash, model },
|
||||
"Assigned Azure OpenAI key to request"
|
||||
);
|
||||
|
||||
const cred = req.key as AzureOpenAIKey;
|
||||
const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);
|
||||
|
||||
req.signedRequest = {
|
||||
method: "POST",
|
||||
protocol: "https:",
|
||||
hostname: `${resourceName}.openai.azure.com`,
|
||||
path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
|
||||
headers: {
|
||||
["host"]: `${resourceName}.openai.azure.com`,
|
||||
["content-type"]: "application/json",
|
||||
["api-key"]: apiKey,
|
||||
},
|
||||
body: JSON.stringify(req.body),
|
||||
};
|
||||
};
|
||||
|
||||
function getCredentialsFromKey(key: AzureOpenAIKey) {
|
||||
const [resourceName, deploymentId, apiKey] = key.key.split(":");
|
||||
if (!resourceName || !deploymentId || !apiKey) {
|
||||
throw new Error("Assigned Azure OpenAI key is not in the correct format.");
|
||||
}
|
||||
return { resourceName, deploymentId, apiKey };
|
||||
}
|
|
@ -80,6 +80,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
|
|||
`?key=${assignedKey.key}`
|
||||
);
|
||||
break;
|
||||
case "azure":
|
||||
const azureKey = assignedKey.key;
|
||||
proxyReq.setHeader("api-key", azureKey);
|
||||
break;
|
||||
case "aws":
|
||||
throw new Error(
|
||||
"add-key should not be used for AWS security credentials. Use sign-aws-request instead."
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import type { ProxyRequestMiddleware } from ".";
|
||||
|
||||
/**
|
||||
* For AWS requests, the body is signed earlier in the request pipeline, before
|
||||
* the proxy middleware. This function just assigns the path and headers to the
|
||||
* proxy request.
|
||||
* For AWS/Azure requests, the body is signed earlier in the request pipeline,
|
||||
* before the proxy middleware. This function just assigns the path and headers
|
||||
* to the proxy request.
|
||||
*/
|
||||
export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => {
|
||||
export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => {
|
||||
if (!req.signedRequest) {
|
||||
throw new Error("Expected req.signedRequest to be set");
|
||||
}
|
|
@ -22,7 +22,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
|
|||
export { addAnthropicPreamble } from "./add-anthropic-preamble";
|
||||
export { blockZoomerOrigins } from "./block-zoomer-origins";
|
||||
export { finalizeBody } from "./finalize-body";
|
||||
export { finalizeAwsRequest } from "./finalize-aws-request";
|
||||
export { finalizeSignedRequest } from "./finalize-signed-request";
|
||||
export { limitCompletions } from "./limit-completions";
|
||||
export { stripHeaders } from "./strip-headers";
|
||||
|
||||
|
|
|
@ -289,15 +289,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||
switch (service) {
|
||||
case "openai":
|
||||
case "google-palm":
|
||||
if (errorPayload.error?.code === "content_policy_violation") {
|
||||
errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`;
|
||||
case "azure":
|
||||
const filteredCodes = ["content_policy_violation", "content_filter"];
|
||||
if (filteredCodes.includes(errorPayload.error?.code)) {
|
||||
errorPayload.proxy_note = `Request was filtered by the upstream API's content moderation system. Modify your prompt and try again.`;
|
||||
refundLastAttempt(req);
|
||||
} else if (errorPayload.error?.code === "billing_hard_limit_reached") {
|
||||
// For some reason, some models return this 400 error instead of the
|
||||
// same 429 billing error that other models return.
|
||||
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
|
||||
} else {
|
||||
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
|
||||
errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
|
||||
}
|
||||
break;
|
||||
case "anthropic":
|
||||
|
@ -342,7 +344,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||
handleAwsRateLimitError(req, errorPayload);
|
||||
break;
|
||||
case "google-palm":
|
||||
throw new Error("Rate limit handling not implemented for PaLM");
|
||||
case "azure":
|
||||
errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
|
||||
break;
|
||||
default:
|
||||
assertNever(service);
|
||||
}
|
||||
|
@ -369,6 +373,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||
case "aws":
|
||||
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
||||
break;
|
||||
case "azure":
|
||||
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
|
||||
break;
|
||||
default:
|
||||
assertNever(service);
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ type SSEMessageTransformerOptions = TransformOptions & {
|
|||
export class SSEMessageTransformer extends Transform {
|
||||
private lastPosition: number;
|
||||
private msgCount: number;
|
||||
private readonly inputFormat: APIFormat;
|
||||
private readonly transformFn: StreamingCompletionTransformer;
|
||||
private readonly log;
|
||||
private readonly fallbackId: string;
|
||||
|
@ -42,6 +43,7 @@ export class SSEMessageTransformer extends Transform {
|
|||
options.inputFormat,
|
||||
options.inputApiVersion
|
||||
);
|
||||
this.inputFormat = options.inputFormat;
|
||||
this.fallbackId = options.requestId;
|
||||
this.fallbackModel = options.requestedModel;
|
||||
this.log.debug(
|
||||
|
@ -67,6 +69,17 @@ export class SSEMessageTransformer extends Transform {
|
|||
});
|
||||
this.lastPosition = newPosition;
|
||||
|
||||
// Special case for Azure OpenAI, which is 99% the same as OpenAI but
|
||||
// sometimes emits an extra event at the beginning of the stream with the
|
||||
// content moderation system's response to the prompt. A lot of frontends
|
||||
// don't expect this and neither does our event aggregator so we drop it.
|
||||
if (this.inputFormat === "openai" && this.msgCount <= 1) {
|
||||
if (originalMessage.includes("prompt_filter_results")) {
|
||||
this.log.debug("Dropping Azure OpenAI content moderation SSE event");
|
||||
return callback();
|
||||
}
|
||||
}
|
||||
|
||||
this.emit("originalMessage", originalMessage);
|
||||
|
||||
// Some events may not be transformed, e.g. ping events
|
||||
|
|
|
@ -24,7 +24,7 @@ import {
|
|||
import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";
|
||||
|
||||
// https://platform.openai.com/docs/models/overview
|
||||
const KNOWN_MODELS = [
|
||||
export const KNOWN_OPENAI_MODELS = [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
|
@ -46,7 +46,7 @@ const KNOWN_MODELS = [
|
|||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
export function generateModelList(models = KNOWN_MODELS) {
|
||||
export function generateModelList(models = KNOWN_OPENAI_MODELS) {
|
||||
let available = new Set<OpenAIModelFamily>();
|
||||
for (const key of keyPool.list()) {
|
||||
if (key.isDisabled || key.service !== "openai") continue;
|
||||
|
|
|
@ -26,6 +26,7 @@ import { assertNever } from "../shared/utils";
|
|||
import { logger } from "../logger";
|
||||
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
|
||||
import { RequestPreprocessor } from "./middleware/request";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
|
||||
const queue: Request[] = [];
|
||||
const log = logger.child({ module: "request-queue" });
|
||||
|
@ -34,7 +35,7 @@ const log = logger.child({ module: "request-queue" });
|
|||
const AGNAI_CONCURRENCY_LIMIT = 5;
|
||||
/** Maximum number of queue slots for individual users. */
|
||||
const USER_CONCURRENCY_LIMIT = 1;
|
||||
const MIN_HEARTBEAT_SIZE = 512;
|
||||
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
|
||||
const MAX_HEARTBEAT_SIZE =
|
||||
1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
|
||||
const HEARTBEAT_INTERVAL =
|
||||
|
@ -358,12 +359,16 @@ export function createQueueMiddleware({
|
|||
return (req, res, next) => {
|
||||
req.proceed = async () => {
|
||||
if (beforeProxy) {
|
||||
// Hack to let us run asynchronous middleware before the
|
||||
// http-proxy-middleware handler. This is used to sign AWS requests
|
||||
// before they are proxied, as the signing is asynchronous.
|
||||
// Unlike RequestPreprocessors, this runs every time the request is
|
||||
// dequeued, not just the first time.
|
||||
await beforeProxy(req);
|
||||
try {
|
||||
// Hack to let us run asynchronous middleware before the
|
||||
// http-proxy-middleware handler. This is used to sign AWS requests
|
||||
// before they are proxied, as the signing is asynchronous.
|
||||
// Unlike RequestPreprocessors, this runs every time the request is
|
||||
// dequeued, not just the first time.
|
||||
await beforeProxy(req);
|
||||
} catch (err) {
|
||||
return handleProxyError(err, req, res);
|
||||
}
|
||||
}
|
||||
proxyMiddleware(req, res, next);
|
||||
};
|
||||
|
|
|
@ -6,6 +6,7 @@ import { openaiImage } from "./openai-image";
|
|||
import { anthropic } from "./anthropic";
|
||||
import { googlePalm } from "./palm";
|
||||
import { aws } from "./aws";
|
||||
import { azure } from "./azure";
|
||||
|
||||
const proxyRouter = express.Router();
|
||||
proxyRouter.use((req, _res, next) => {
|
||||
|
@ -32,6 +33,7 @@ proxyRouter.use("/openai-image", addV1, openaiImage);
|
|||
proxyRouter.use("/anthropic", addV1, anthropic);
|
||||
proxyRouter.use("/google-palm", addV1, googlePalm);
|
||||
proxyRouter.use("/aws/claude", addV1, aws);
|
||||
proxyRouter.use("/azure/openai", addV1, azure);
|
||||
// Redirect browser requests to the homepage.
|
||||
proxyRouter.get("*", (req, res, next) => {
|
||||
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
|
||||
|
|
|
@ -26,46 +26,23 @@ type AnthropicAPIError = {
|
|||
type UpdateFn = typeof AnthropicKeyProvider.prototype.update;
|
||||
|
||||
export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
||||
private readonly updateKey: UpdateFn;
|
||||
|
||||
constructor(keys: AnthropicKey[], updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "anthropic",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
updateKey,
|
||||
});
|
||||
this.updateKey = updateKey;
|
||||
}
|
||||
|
||||
protected async checkKey(key: AnthropicKey) {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
let isInitialCheck = !key.lastChecked;
|
||||
try {
|
||||
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
|
||||
const updates = { isPozzed: pozzed };
|
||||
this.updateKey(key.hash, updates);
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies },
|
||||
"Key check complete."
|
||||
);
|
||||
} catch (error) {
|
||||
// touch the key so we don't check it again for a while
|
||||
this.updateKey(key.hash, {});
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
// Only enqueue the next check if this wasn't a startup check, since those
|
||||
// are batched together elsewhere.
|
||||
if (!isInitialCheck) {
|
||||
this.scheduleNextCheck();
|
||||
}
|
||||
protected async testKeyOrFail(key: AnthropicKey) {
|
||||
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
|
||||
const updates = { isPozzed: pozzed };
|
||||
this.updateKey(key.hash, updates);
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies },
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
|
||||
|
@ -84,6 +61,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
|||
{ key: key.hash, error: error.message },
|
||||
"Key is rate limited. Rechecking in 10 seconds."
|
||||
);
|
||||
0;
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
break;
|
||||
|
|
|
@ -32,58 +32,36 @@ type GetLoggingConfigResponse = {
|
|||
type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;
|
||||
|
||||
export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
private readonly updateKey: UpdateFn;
|
||||
|
||||
constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "aws",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
updateKey,
|
||||
});
|
||||
this.updateKey = updateKey;
|
||||
}
|
||||
|
||||
protected async checkKey(key: AwsBedrockKey) {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
protected async testKeyOrFail(key: AwsBedrockKey) {
|
||||
// Only check models on startup. For now all models must be available to
|
||||
// the proxy because we don't route requests to different keys.
|
||||
const modelChecks: Promise<unknown>[] = [];
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
|
||||
}
|
||||
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
let isInitialCheck = !key.lastChecked;
|
||||
try {
|
||||
// Only check models on startup. For now all models must be available to
|
||||
// the proxy because we don't route requests to different keys.
|
||||
const modelChecks: Promise<unknown>[] = [];
|
||||
if (isInitialCheck) {
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
|
||||
}
|
||||
await Promise.all(modelChecks);
|
||||
await this.checkLoggingConfiguration(key);
|
||||
|
||||
await Promise.all(modelChecks);
|
||||
await this.checkLoggingConfiguration(key);
|
||||
|
||||
this.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
models: key.modelFamilies,
|
||||
logged: key.awsLoggingStatus,
|
||||
},
|
||||
"Key check complete."
|
||||
);
|
||||
} catch (error) {
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.updateKey(key.hash, {});
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
// Only enqueue the next check if this wasn't a startup check, since those
|
||||
// are batched together elsewhere.
|
||||
if (!isInitialCheck) {
|
||||
this.scheduleNextCheck();
|
||||
}
|
||||
this.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
models: key.modelFamilies,
|
||||
logged: key.awsLoggingStatus,
|
||||
},
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
import axios, { AxiosError } from "axios";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
|
||||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
|
||||
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
|
||||
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) =>
|
||||
`https://${AZURE_HOST.replace(
|
||||
"%RESOURCE_NAME%",
|
||||
resourceName
|
||||
)}/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`;
|
||||
|
||||
type AzureError = {
|
||||
error: {
|
||||
message: string;
|
||||
type: string | null;
|
||||
param: string;
|
||||
code: string;
|
||||
status: number;
|
||||
};
|
||||
};
|
||||
type UpdateFn = typeof AzureOpenAIKeyProvider.prototype.update;
|
||||
|
||||
export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
|
||||
constructor(keys: AzureOpenAIKey[], updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "azure",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
recurringChecksEnabled: false,
|
||||
updateKey,
|
||||
});
|
||||
}
|
||||
|
||||
protected async testKeyOrFail(key: AzureOpenAIKey) {
|
||||
const model = await this.testModel(key);
|
||||
this.log.info(
|
||||
{ key: key.hash, deploymentModel: model },
|
||||
"Checked key."
|
||||
);
|
||||
this.updateKey(key.hash, { modelFamilies: [model] });
|
||||
}
|
||||
|
||||
// provided api-key header isn't valid (401)
|
||||
// {
|
||||
// "error": {
|
||||
// "code": "401",
|
||||
// "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
|
||||
// }
|
||||
// }
|
||||
|
||||
// api key correct but deployment id is wrong (404)
|
||||
// {
|
||||
// "error": {
|
||||
// "code": "DeploymentNotFound",
|
||||
// "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
|
||||
// }
|
||||
// }
|
||||
|
||||
// resource name is wrong (node will throw ENOTFOUND)
|
||||
|
||||
// rate limited (429)
|
||||
// TODO: try to reproduce this
|
||||
|
||||
protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
|
||||
if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
|
||||
const data = error.response.data;
|
||||
const status = data.error.status;
|
||||
const errorType = data.error.code || data.error.type;
|
||||
switch (errorType) {
|
||||
case "DeploymentNotFound":
|
||||
this.log.warn(
|
||||
{ key: key.hash, errorType, error: error.response.data },
|
||||
"Key is revoked or deployment ID is incorrect. Disabling key."
|
||||
);
|
||||
return this.updateKey(key.hash, {
|
||||
isDisabled: true,
|
||||
isRevoked: true,
|
||||
});
|
||||
case "401":
|
||||
this.log.warn(
|
||||
{ key: key.hash, errorType, error: error.response.data },
|
||||
"Key is disabled or incorrect. Disabling key."
|
||||
);
|
||||
return this.updateKey(key.hash, {
|
||||
isDisabled: true,
|
||||
isRevoked: true,
|
||||
});
|
||||
default:
|
||||
this.log.error(
|
||||
{ key: key.hash, errorType, error: error.response.data, status },
|
||||
"Unknown Azure API error while checking key. Please report this."
|
||||
);
|
||||
return this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||
}
|
||||
}
|
||||
|
||||
const { response, code } = error;
|
||||
if (code === "ENOTFOUND") {
|
||||
this.log.warn(
|
||||
{ key: key.hash, error: error.message },
|
||||
"Resource name is probably incorrect. Disabling key."
|
||||
);
|
||||
return this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
|
||||
}
|
||||
|
||||
const { headers, status, data } = response ?? {};
|
||||
this.log.error(
|
||||
{ key: key.hash, status, headers, data, error: error.message },
|
||||
"Network error while checking key; trying this key again in a minute."
|
||||
);
|
||||
const oneMinute = 60 * 1000;
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
}
|
||||
|
||||
private async testModel(key: AzureOpenAIKey) {
|
||||
const { apiKey, deploymentId, resourceName } =
|
||||
AzureOpenAIKeyChecker.getCredentialsFromKey(key);
|
||||
const url = POST_CHAT_COMPLETIONS(resourceName, deploymentId);
|
||||
const testRequest = {
|
||||
max_tokens: 1,
|
||||
stream: false,
|
||||
messages: [{ role: "user", content: "" }],
|
||||
};
|
||||
const { data } = await axios.post(url, testRequest, {
|
||||
headers: { "Content-Type": "application/json", "api-key": apiKey },
|
||||
});
|
||||
|
||||
return getAzureOpenAIModelFamily(data.model);
|
||||
}
|
||||
|
||||
static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
|
||||
const data = error.response?.data as any;
|
||||
return data?.error?.code || data?.error?.type;
|
||||
}
|
||||
|
||||
static getCredentialsFromKey(key: AzureOpenAIKey) {
|
||||
const [resourceName, deploymentId, apiKey] = key.key.split(":");
|
||||
if (!resourceName || !deploymentId || !apiKey) {
|
||||
throw new Error(
|
||||
"Invalid Azure credential format. Refer to .env.example and ensure your credentials are in the format RESOURCE_NAME:DEPLOYMENT_ID:API_KEY with commas between each credential set."
|
||||
);
|
||||
}
|
||||
return { resourceName, deploymentId, apiKey };
|
||||
}
|
||||
}
|
|
@ -0,0 +1,212 @@
|
|||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import type { AzureOpenAIModelFamily } from "../../models";
|
||||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
import { OpenAIModel } from "../openai/provider";
|
||||
import { AzureOpenAIKeyChecker } from "./checker";
|
||||
import { AwsKeyChecker } from "../aws/checker";
|
||||
|
||||
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
|
||||
|
||||
type AzureOpenAIKeyUsage = {
|
||||
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
|
||||
};
|
||||
|
||||
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
|
||||
readonly service: "azure";
|
||||
readonly modelFamilies: AzureOpenAIModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
contentFiltering: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upon being rate limited, a key will be locked out for this many milliseconds
|
||||
* while we wait for other concurrent requests to finish.
|
||||
*/
|
||||
const RATE_LIMIT_LOCKOUT = 4000;
|
||||
/**
|
||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
||||
* to be used again. This is to prevent the queue from flooding a key with too
|
||||
* many requests while we wait to learn whether previous ones succeeded.
|
||||
*/
|
||||
const KEY_REUSE_DELAY = 250;
|
||||
|
||||
export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
|
||||
readonly service = "azure";
|
||||
|
||||
private keys: AzureOpenAIKey[] = [];
|
||||
private checker?: AzureOpenAIKeyChecker;
|
||||
private log = logger.child({ module: "key-provider", service: this.service });
|
||||
|
||||
constructor() {
|
||||
const keyConfig = config.azureCredentials;
|
||||
if (!keyConfig) {
|
||||
this.log.warn(
|
||||
"AZURE_CREDENTIALS is not set. Azure OpenAI API will not be available."
|
||||
);
|
||||
return;
|
||||
}
|
||||
let bareKeys: string[];
|
||||
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
||||
for (const key of bareKeys) {
|
||||
const newKey: AzureOpenAIKey = {
|
||||
key,
|
||||
service: this.service,
|
||||
modelFamilies: ["azure-gpt4"],
|
||||
isDisabled: false,
|
||||
isRevoked: false,
|
||||
promptCount: 0,
|
||||
lastUsed: 0,
|
||||
rateLimitedAt: 0,
|
||||
rateLimitedUntil: 0,
|
||||
contentFiltering: false,
|
||||
hash: `azu-${crypto
|
||||
.createHash("sha256")
|
||||
.update(key)
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0,
|
||||
"azure-turboTokens": 0,
|
||||
"azure-gpt4Tokens": 0,
|
||||
"azure-gpt4-32kTokens": 0,
|
||||
"azure-gpt4-turboTokens": 0,
|
||||
};
|
||||
this.keys.push(newKey);
|
||||
}
|
||||
this.log.info({ keyCount: this.keys.length }, "Loaded Azure OpenAI keys.");
|
||||
}
|
||||
|
||||
public init() {
|
||||
if (config.checkKeys) {
|
||||
this.checker = new AzureOpenAIKeyChecker(
|
||||
this.keys,
|
||||
this.update.bind(this)
|
||||
);
|
||||
this.checker.start();
|
||||
}
|
||||
}
|
||||
|
||||
public list() {
|
||||
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
||||
}
|
||||
|
||||
public get(model: AzureOpenAIModel) {
|
||||
const neededFamily = getAzureOpenAIModelFamily(model);
|
||||
const availableKeys = this.keys.filter(
|
||||
(k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
|
||||
);
|
||||
if (availableKeys.length === 0) {
|
||||
throw new Error(`No keys available for model family '${neededFamily}'.`);
|
||||
}
|
||||
|
||||
// (largely copied from the OpenAI provider, without trial key support)
|
||||
// Select a key, from highest priority to lowest priority:
|
||||
// 1. Keys which are not rate limited
|
||||
// a. If all keys were rate limited recently, select the least-recently
|
||||
// rate limited key.
|
||||
// 3. Keys which have not been used in the longest time
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
const keysByPriority = availableKeys.sort((a, b) => {
|
||||
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
|
||||
if (aRateLimited && !bRateLimited) return 1;
|
||||
if (!aRateLimited && bRateLimited) return -1;
|
||||
if (aRateLimited && bRateLimited) {
|
||||
return a.rateLimitedAt - b.rateLimitedAt;
|
||||
}
|
||||
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = now;
|
||||
this.throttle(selectedKey.hash);
|
||||
return { ...selectedKey };
|
||||
}
|
||||
|
||||
public disable(key: AzureOpenAIKey) {
|
||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
||||
if (!keyFromPool || keyFromPool.isDisabled) return;
|
||||
keyFromPool.isDisabled = true;
|
||||
this.log.warn({ key: key.hash }, "Key disabled");
|
||||
}
|
||||
|
||||
public update(hash: string, update: Partial<AzureOpenAIKey>) {
|
||||
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
||||
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
||||
}
|
||||
|
||||
public available() {
|
||||
return this.keys.filter((k) => !k.isDisabled).length;
|
||||
}
|
||||
|
||||
public incrementUsage(hash: string, model: string, tokens: number) {
|
||||
const key = this.keys.find((k) => k.hash === hash);
|
||||
if (!key) return;
|
||||
key.promptCount++;
|
||||
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
|
||||
}
|
||||
|
||||
// TODO: all of this shit is duplicate code
|
||||
|
||||
public getLockoutPeriod() {
|
||||
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
// Don't lock out if there are no keys available or the queue will stall.
|
||||
// Just let it through so the add-key middleware can throw an error.
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
// If all keys are rate-limited, return time until the first key is ready.
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
* concurrent requests running on this key. We don't have any information on
|
||||
* when these requests will resolve, so all we can do is wait a bit and try
|
||||
* again. We will lock the key for 2 seconds after getting a 429 before
|
||||
* retrying in order to give the other requests a chance to finish.
|
||||
*/
|
||||
public markRateLimited(keyHash: string) {
|
||||
this.log.debug({ key: keyHash }, "Key rate limited");
|
||||
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||
const now = Date.now();
|
||||
key.rateLimitedAt = now;
|
||||
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
|
||||
}
|
||||
|
||||
public recheck() {
|
||||
this.keys.forEach(({ hash }) =>
|
||||
this.update(hash, { lastChecked: 0, isDisabled: false })
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies a short artificial delay to the key upon dequeueing, in order to
|
||||
* prevent it from being immediately assigned to another request before the
|
||||
* current one can be dispatched.
|
||||
**/
|
||||
private throttle(hash: string) {
|
||||
const now = Date.now();
|
||||
const key = this.keys.find((k) => k.hash === hash)!;
|
||||
|
||||
const currentRateLimit = key.rateLimitedUntil;
|
||||
const nextRateLimit = now + KEY_REUSE_DELAY;
|
||||
|
||||
key.rateLimitedAt = now;
|
||||
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ import { OpenAIModel } from "./openai/provider";
|
|||
import { AnthropicModel } from "./anthropic/provider";
|
||||
import { GooglePalmModel } from "./palm/provider";
|
||||
import { AwsBedrockModel } from "./aws/provider";
|
||||
import { AzureOpenAIModel } from "./azure/provider";
|
||||
import { KeyPool } from "./key-pool";
|
||||
import type { ModelFamily } from "../models";
|
||||
|
||||
|
@ -13,12 +14,18 @@ export type APIFormat =
|
|||
| "openai-text"
|
||||
| "openai-image";
|
||||
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
|
||||
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
|
||||
export type LLMService =
|
||||
| "openai"
|
||||
| "anthropic"
|
||||
| "google-palm"
|
||||
| "aws"
|
||||
| "azure";
|
||||
export type Model =
|
||||
| OpenAIModel
|
||||
| AnthropicModel
|
||||
| GooglePalmModel
|
||||
| AwsBedrockModel;
|
||||
| AwsBedrockModel
|
||||
| AzureOpenAIModel;
|
||||
|
||||
export interface Key {
|
||||
/** The API key itself. Never log this, use `hash` instead. */
|
||||
|
@ -72,3 +79,4 @@ export { AnthropicKey } from "./anthropic/provider";
|
|||
export { OpenAIKey } from "./openai/provider";
|
||||
export { GooglePalmKey } from "./palm/provider";
|
||||
export { AwsBedrockKey } from "./aws/provider";
|
||||
export { AzureOpenAIKey } from "./azure/provider";
|
||||
|
|
|
@ -3,14 +3,17 @@ import { logger } from "../../logger";
|
|||
import { Key } from "./index";
|
||||
import { AxiosError } from "axios";
|
||||
|
||||
type KeyCheckerOptions = {
|
||||
type KeyCheckerOptions<TKey extends Key = Key> = {
|
||||
service: string;
|
||||
keyCheckPeriod: number;
|
||||
minCheckInterval: number;
|
||||
}
|
||||
recurringChecksEnabled?: boolean;
|
||||
updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||
};
|
||||
|
||||
export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
protected readonly service: string;
|
||||
protected readonly RECURRING_CHECKS_ENABLED: boolean;
|
||||
/** Minimum time in between any two key checks. */
|
||||
protected readonly MIN_CHECK_INTERVAL: number;
|
||||
/**
|
||||
|
@ -19,16 +22,19 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
|||
* than this.
|
||||
*/
|
||||
protected readonly KEY_CHECK_PERIOD: number;
|
||||
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||
protected readonly keys: TKey[] = [];
|
||||
protected log: pino.Logger;
|
||||
protected timeout?: NodeJS.Timeout;
|
||||
protected lastCheck = 0;
|
||||
|
||||
protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
|
||||
protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
|
||||
const { service, keyCheckPeriod, minCheckInterval } = opts;
|
||||
this.keys = keys;
|
||||
this.KEY_CHECK_PERIOD = keyCheckPeriod;
|
||||
this.MIN_CHECK_INTERVAL = minCheckInterval;
|
||||
this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
|
||||
this.updateKey = opts.updateKey;
|
||||
this.service = service;
|
||||
this.log = logger.child({ module: "key-checker", service });
|
||||
}
|
||||
|
@ -52,31 +58,34 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
|||
* the minimum check interval.
|
||||
*/
|
||||
public scheduleNextCheck() {
|
||||
// Gives each concurrent check a correlation ID to make logs less confusing.
|
||||
const callId = Math.random().toString(36).slice(2, 8);
|
||||
const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
|
||||
const checkLog = this.log.child({ callId, timeoutId });
|
||||
|
||||
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
|
||||
checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
|
||||
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
|
||||
const numEnabled = enabledKeys.length;
|
||||
const numUnchecked = uncheckedKeys.length;
|
||||
|
||||
clearTimeout(this.timeout);
|
||||
this.timeout = undefined;
|
||||
|
||||
if (enabledKeys.length === 0) {
|
||||
checkLog.warn("All keys are disabled. Key checker stopping.");
|
||||
if (!numEnabled) {
|
||||
checkLog.warn("All keys are disabled. Stopping.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Perform startup checks for any keys that haven't been checked yet.
|
||||
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
|
||||
checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
|
||||
if (uncheckedKeys.length > 0) {
|
||||
const keysToCheck = uncheckedKeys.slice(0, 12);
|
||||
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
|
||||
|
||||
if (numUnchecked > 0) {
|
||||
const keycheckBatch = uncheckedKeys.slice(0, 12);
|
||||
|
||||
this.timeout = setTimeout(async () => {
|
||||
try {
|
||||
await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
|
||||
await Promise.all(keycheckBatch.map((key) => this.checkKey(key)));
|
||||
} catch (error) {
|
||||
this.log.error({ error }, "Error checking one or more keys.");
|
||||
checkLog.error({ error }, "Error checking one or more keys.");
|
||||
}
|
||||
checkLog.info("Batch complete.");
|
||||
this.scheduleNextCheck();
|
||||
|
@ -84,11 +93,18 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
|||
|
||||
checkLog.info(
|
||||
{
|
||||
batch: keysToCheck.map((k) => k.hash),
|
||||
remaining: uncheckedKeys.length - keysToCheck.length,
|
||||
batch: keycheckBatch.map((k) => k.hash),
|
||||
remaining: uncheckedKeys.length - keycheckBatch.length,
|
||||
newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
|
||||
},
|
||||
"Scheduled batch check."
|
||||
"Scheduled batch of initial checks."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!this.RECURRING_CHECKS_ENABLED) {
|
||||
checkLog.info(
|
||||
"Initial checks complete and recurring checks are disabled for this service. Stopping."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
@ -106,14 +122,35 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
|||
);
|
||||
|
||||
const delay = nextCheck - Date.now();
|
||||
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
|
||||
this.timeout = setTimeout(
|
||||
() => this.checkKey(oldestKey).then(() => this.scheduleNextCheck()),
|
||||
delay
|
||||
);
|
||||
checkLog.debug(
|
||||
{ key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
|
||||
"Scheduled single key check."
|
||||
"Scheduled next recurring check."
|
||||
);
|
||||
}
|
||||
|
||||
protected abstract checkKey(key: TKey): Promise<void>;
|
||||
public async checkKey(key: TKey): Promise<void> {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
}
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
|
||||
try {
|
||||
await this.testKeyOrFail(key);
|
||||
} catch (error) {
|
||||
this.updateKey(key.hash, {});
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
}
|
||||
|
||||
protected abstract testKeyOrFail(key: TKey): Promise<void>;
|
||||
|
||||
protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import { GooglePalmKeyProvider } from "./palm/provider";
|
|||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||
import { ModelFamily } from "../models";
|
||||
import { assertNever } from "../utils";
|
||||
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
||||
|
||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
||||
|
||||
|
@ -25,6 +26,7 @@ export class KeyPool {
|
|||
this.keyProviders.push(new AnthropicKeyProvider());
|
||||
this.keyProviders.push(new GooglePalmKeyProvider());
|
||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
||||
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
||||
}
|
||||
|
||||
public init() {
|
||||
|
@ -124,6 +126,8 @@ export class KeyPool {
|
|||
// AWS offers models from a few providers
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
return "aws";
|
||||
} else if (model.startsWith("azure")) {
|
||||
return "azure";
|
||||
}
|
||||
throw new Error(`Unknown service for model '${model}'`);
|
||||
}
|
||||
|
@ -142,6 +146,11 @@ export class KeyPool {
|
|||
return "google-palm";
|
||||
case "aws-claude":
|
||||
return "aws";
|
||||
case "azure-turbo":
|
||||
case "azure-gpt4":
|
||||
case "azure-gpt4-32k":
|
||||
case "azure-gpt4-turbo":
|
||||
return "azure";
|
||||
default:
|
||||
assertNever(modelFamily);
|
||||
}
|
||||
|
|
|
@ -27,65 +27,41 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
|
|||
|
||||
export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
|
||||
private readonly cloneKey: CloneFn;
|
||||
private readonly updateKey: UpdateFn;
|
||||
|
||||
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "openai",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
recurringChecksEnabled: false,
|
||||
updateKey,
|
||||
});
|
||||
this.cloneKey = cloneFn;
|
||||
this.updateKey = updateKey;
|
||||
}
|
||||
|
||||
protected async checkKey(key: OpenAIKey) {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
let isInitialCheck = !key.lastChecked;
|
||||
try {
|
||||
// We only need to check for provisioned models on the initial check.
|
||||
if (isInitialCheck) {
|
||||
const [provisionedModels, livenessTest] = await Promise.all([
|
||||
this.getProvisionedModels(key),
|
||||
this.testLiveness(key),
|
||||
this.maybeCreateOrganizationClones(key),
|
||||
]);
|
||||
const updates = {
|
||||
modelFamilies: provisionedModels,
|
||||
isTrial: livenessTest.rateLimit <= 250,
|
||||
};
|
||||
this.updateKey(key.hash, updates);
|
||||
} else {
|
||||
// No updates needed as models and trial status generally don't change.
|
||||
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
|
||||
this.updateKey(key.hash, {});
|
||||
}
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
|
||||
"Key check complete."
|
||||
);
|
||||
} catch (error) {
|
||||
// touch the key so we don't check it again for a while
|
||||
protected async testKeyOrFail(key: OpenAIKey) {
|
||||
// We only need to check for provisioned models on the initial check.
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
const [provisionedModels, livenessTest] = await Promise.all([
|
||||
this.getProvisionedModels(key),
|
||||
this.testLiveness(key),
|
||||
this.maybeCreateOrganizationClones(key),
|
||||
]);
|
||||
const updates = {
|
||||
modelFamilies: provisionedModels,
|
||||
isTrial: livenessTest.rateLimit <= 250,
|
||||
};
|
||||
this.updateKey(key.hash, updates);
|
||||
} else {
|
||||
// No updates needed as models and trial status generally don't change.
|
||||
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
|
||||
this.updateKey(key.hash, {});
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
// Only enqueue the next check if this wasn't a startup check, since those
|
||||
// are batched together elsewhere.
|
||||
if (!isInitialCheck) {
|
||||
this.log.info(
|
||||
{ key: key.hash },
|
||||
"Recurring keychecks are disabled, no-op."
|
||||
);
|
||||
// this.scheduleNextCheck();
|
||||
}
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
private async getProvisionedModels(
|
||||
|
@ -138,6 +114,17 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
|
|||
.filter(({ is_default }) => !is_default)
|
||||
.map(({ id }) => id);
|
||||
this.cloneKey(key.hash, ids);
|
||||
|
||||
// It's possible that the keychecker may be stopped if all non-cloned keys
|
||||
// happened to be unusable, in which case this clnoe will never be checked
|
||||
// unless we restart the keychecker.
|
||||
if (!this.timeout) {
|
||||
this.log.warn(
|
||||
{ parent: key.hash },
|
||||
"Restarting key checker to check cloned keys."
|
||||
);
|
||||
this.scheduleNextCheck();
|
||||
}
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
|
||||
|
|
|
@ -217,17 +217,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
|||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
// logger.debug(
|
||||
// {
|
||||
// byPriority: keysByPriority.map((k) => ({
|
||||
// hash: k.hash,
|
||||
// isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
|
||||
// modelFamilies: k.modelFamilies,
|
||||
// })),
|
||||
// },
|
||||
// "Keys sorted by priority"
|
||||
// );
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = now;
|
||||
this.throttle(selectedKey.hash);
|
||||
|
|
|
@ -2,15 +2,25 @@
|
|||
|
||||
import pino from "pino";
|
||||
|
||||
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
|
||||
export type OpenAIModelFamily =
|
||||
| "turbo"
|
||||
| "gpt4"
|
||||
| "gpt4-32k"
|
||||
| "gpt4-turbo"
|
||||
| "dall-e";
|
||||
export type AnthropicModelFamily = "claude";
|
||||
export type GooglePalmModelFamily = "bison";
|
||||
export type AwsBedrockModelFamily = "aws-claude";
|
||||
export type AzureOpenAIModelFamily = `azure-${Exclude<
|
||||
OpenAIModelFamily,
|
||||
"dall-e"
|
||||
>}`;
|
||||
export type ModelFamily =
|
||||
| OpenAIModelFamily
|
||||
| AnthropicModelFamily
|
||||
| GooglePalmModelFamily
|
||||
| AwsBedrockModelFamily;
|
||||
| AwsBedrockModelFamily
|
||||
| AzureOpenAIModelFamily;
|
||||
|
||||
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
|
||||
|
@ -23,6 +33,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
|||
"claude",
|
||||
"bison",
|
||||
"aws-claude",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-32k",
|
||||
"azure-gpt4-turbo",
|
||||
] as const);
|
||||
|
||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
|
@ -64,6 +78,24 @@ export function getAwsBedrockModelFamily(_model: string): ModelFamily {
|
|||
return "aws-claude";
|
||||
}
|
||||
|
||||
export function getAzureOpenAIModelFamily(
|
||||
model: string,
|
||||
defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
|
||||
): AzureOpenAIModelFamily {
|
||||
// Azure model names omit periods. addAzureKey also prepends "azure-" to the
|
||||
// model name to route the request the correct keyprovider, so we need to
|
||||
// remove that as well.
|
||||
const modified = model
|
||||
.replace("gpt-35-turbo", "gpt-3.5-turbo")
|
||||
.replace("azure-", "");
|
||||
for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
|
||||
if (modified.match(regex)) {
|
||||
return `azure-${family}` as AzureOpenAIModelFamily;
|
||||
}
|
||||
}
|
||||
return defaultFamily;
|
||||
}
|
||||
|
||||
export function assertIsKnownModelFamily(
|
||||
modelFamily: string
|
||||
): asserts modelFamily is ModelFamily {
|
||||
|
|
|
@ -12,6 +12,7 @@ import schedule from "node-schedule";
|
|||
import { v4 as uuid } from "uuid";
|
||||
import { config, getFirebaseApp } from "../../config";
|
||||
import {
|
||||
getAzureOpenAIModelFamily,
|
||||
getClaudeModelFamily,
|
||||
getGooglePalmModelFamily,
|
||||
getOpenAIModelFamily,
|
||||
|
@ -34,6 +35,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
|
|||
claude: 0,
|
||||
bison: 0,
|
||||
"aws-claude": 0,
|
||||
"azure-turbo": 0,
|
||||
"azure-gpt4": 0,
|
||||
"azure-gpt4-turbo": 0,
|
||||
"azure-gpt4-32k": 0,
|
||||
};
|
||||
|
||||
const users: Map<string, User> = new Map();
|
||||
|
@ -382,6 +387,9 @@ function getModelFamilyForQuotaUsage(
|
|||
model: string,
|
||||
api: APIFormat
|
||||
): ModelFamily {
|
||||
// TODO: this seems incorrect
|
||||
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
|
||||
|
||||
switch (api) {
|
||||
case "openai":
|
||||
case "openai-text":
|
||||
|
|
Loading…
Reference in New Issue