Azure OpenAI support (khanon/oai-reverse-proxy!48)
This commit is contained in:
parent cd1b9d0e0c
commit fbdea30264
@@ -34,10 +34,10 @@
 # Which model types users are allowed to access.
 # The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
 # By default, all models are allowed except for 'dall-e'. To allow DALL-E image
 # generation, uncomment the line below and add 'dall-e' to the list.
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo

 # URLs from which requests will be blocked.
 # BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -114,6 +114,8 @@ OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 # See `docs/aws-configuration.md` for more information, there may be additional steps required to set up AWS.
 AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
+# See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure.
+AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key

 # With proxy_key gatekeeper, the password users must provide to access the API.
 # PROXY_KEY=your-secret-key
@@ -5,3 +5,4 @@
 build
 greeting.md
 node_modules
+http-client.private.env.json
@@ -0,0 +1,25 @@
# Configuring the proxy for Azure

The proxy supports Azure OpenAI Service via the `/proxy/azure/openai` endpoint. The process of setting it up is slightly different from regular OpenAI.

- [Setting keys](#setting-keys)
- [Model assignment](#model-assignment)

## Setting keys

Use the `AZURE_CREDENTIALS` environment variable to set the Azure API keys.

Like other APIs, you can provide multiple keys separated by commas. Each Azure key, however, is a set of values including the Resource Name, Deployment ID, and API key. These are separated by a colon (`:`).

For example:
```
AZURE_CREDENTIALS=contoso-ml:gpt4-8k:0123456789abcdef0123456789abcdef,northwind-corp:testdeployment:0123456789abcdef0123456789abcdef
```
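To make the colon-and-comma layout concrete, here is a minimal TypeScript sketch of splitting such a string into its parts. This is an editorial illustration of the documented format, not the proxy's actual parser; the names in it are invented for the example.

```
// Hypothetical helper illustrating the documented AZURE_CREDENTIALS format.
interface ParsedAzureCredential {
  resourceName: string; // subdomain of https://<resourceName>.openai.azure.com
  deploymentId: string; // the deployment created in the Azure portal
  apiKey: string;
}

function parseAzureCredentials(env: string): ParsedAzureCredential[] {
  return env.split(",").map((item) => {
    const [resourceName, deploymentId, apiKey] = item.split(":");
    if (!resourceName || !deploymentId || !apiKey) {
      throw new Error(`Malformed Azure credential: ${item}`);
    }
    return { resourceName, deploymentId, apiKey };
  });
}

// The first credential in the example above parses to:
// { resourceName: "contoso-ml", deploymentId: "gpt4-8k", apiKey: "0123..." }
```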
## Model assignment

Note that each Azure deployment is assigned a model when you create it in the Microsoft Cognitive Services portal. If you want to use a different model, you'll need to create a new deployment, and therefore a new key to be added to the AZURE_CREDENTIALS environment variable. Each credential only grants access to one model.

### Supported model IDs

Users can send normal OpenAI model IDs to the proxy to invoke the corresponding models. For the most part they work the same with Azure. GPT-3.5 Turbo has an ID of "gpt-35-turbo" because Azure doesn't allow periods in model names, but the proxy should automatically convert this to the correct ID.

As noted above, you can only use model IDs for which a deployment has been created and added to the proxy.
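As a rough sketch of that ID conversion (an assumption: a simple character substitution; the proxy's real mapping may differ):

```
// Hypothetical sketch of the OpenAI-to-Azure model ID conversion described
// above. Azure model names cannot contain periods, so "gpt-3.5-turbo"
// becomes "gpt-35-turbo".
function toAzureModelId(openaiModelId: string): string {
  return openaiModelId.replace(/\./g, "");
}

console.log(toAzureModelId("gpt-3.5-turbo")); // "gpt-35-turbo"
```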
@@ -0,0 +1,9 @@
{
  "dev": {
    "proxy-host": "http://localhost:7860",
    "oai-key-1": "override in http-client.private.env.json",
    "proxy-key": "override in http-client.private.env.json",
    "azu-resource-name": "override in http-client.private.env.json",
    "azu-deployment-id": "override in http-client.private.env.json"
  }
}
@@ -0,0 +1,248 @@
# OAI Reverse Proxy

###
# @name OpenAI -- Chat Completions
POST https://api.openai.com/v1/chat/completions
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 30,
  "stream": false,
  "messages": [
    {
      "role": "user",
      "content": "This is a test prompt."
    }
  ]
}

###
# @name OpenAI -- Text Completions
POST https://api.openai.com/v1/completions
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo-instruct",
  "max_tokens": 30,
  "stream": false,
  "prompt": "This is a test prompt where"
}

###
# @name OpenAI -- Create Embedding
POST https://api.openai.com/v1/embeddings
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json

{
  "model": "text-embedding-ada-002",
  "input": "This is a test embedding input."
}

###
# @name OpenAI -- Get Organizations
GET https://api.openai.com/v1/organizations
Authorization: Bearer {{oai-key-1}}

###
# @name OpenAI -- Get Models
GET https://api.openai.com/v1/models
Authorization: Bearer {{oai-key-1}}

###
# @name Azure OpenAI -- Chat Completions
POST https://{{azu-resource-name}}.openai.azure.com/openai/deployments/{{azu-deployment-id}}/chat/completions?api-version=2023-09-01-preview
api-key: {{azu-key-1}}
Content-Type: application/json

{
  "max_tokens": 1,
  "stream": false,
  "messages": [
    {
      "role": "user",
      "content": "This is a test prompt."
    }
  ]
}

###
# @name Proxy / OpenAI -- Get Models
GET {{proxy-host}}/proxy/openai/v1/models
Authorization: Bearer {{proxy-key}}

###
# @name Proxy / OpenAI -- Native Chat Completions
POST {{proxy-host}}/proxy/openai/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 20,
  "stream": true,
  "temperature": 1,
  "seed": 123,
  "messages": [
    {
      "role": "user",
      "content": "phrase one"
    }
  ]
}

###
# @name Proxy / OpenAI -- Native Text Completions
POST {{proxy-host}}/proxy/openai/v1/turbo-instruct/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo-instruct",
  "max_tokens": 20,
  "temperature": 0,
  "prompt": "Genshin Impact is a game about",
  "stream": false
}

###
# @name Proxy / OpenAI -- Chat-to-Text API Translation
# Accepts a chat completion request and reformats it to work with the text completion API. `model` is ignored.
POST {{proxy-host}}/proxy/openai/turbo-instruct/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-4",
  "max_tokens": 20,
  "stream": true,
  "messages": [
    {
      "role": "user",
      "content": "What is the name of the fourth president of the united states?"
    },
    {
      "role": "assistant",
      "content": "That would be George Washington."
    },
    {
      "role": "user",
      "content": "I don't think that's right..."
    }
  ]
}

###
# @name Proxy / OpenAI -- Create Embedding
POST {{proxy-host}}/proxy/openai/embeddings
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "text-embedding-ada-002",
  "input": "This is a test embedding input."
}


###
# @name Proxy / Anthropic -- Native Completion (old API)
POST {{proxy-host}}/proxy/anthropic/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json

{
  "model": "claude-v1.3",
  "max_tokens_to_sample": 20,
  "temperature": 0.2,
  "stream": true,
  "prompt": "What is genshin impact\n\n:Assistant:"
}

###
# @name Proxy / Anthropic -- Native Completion (2023-06-01 API)
POST {{proxy-host}}/proxy/anthropic/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-06-01
Content-Type: application/json

{
  "model": "claude-v1.3",
  "max_tokens_to_sample": 20,
  "temperature": 0.2,
  "stream": true,
  "prompt": "What is genshin impact\n\n:Assistant:"
}

###
# @name Proxy / Anthropic -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/anthropic/v1/chat/completions
Authorization: Bearer {{proxy-key}}
#anthropic-version: 2023-06-01
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 20,
  "stream": false,
  "temperature": 0,
  "messages": [
    {
      "role": "user",
      "content": "What is genshin impact"
    }
  ]
}

###
# @name Proxy / AWS Claude -- Native Completion
POST {{proxy-host}}/proxy/aws/claude/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json

{
  "model": "claude-v2",
  "max_tokens_to_sample": 10,
  "temperature": 0,
  "stream": true,
  "prompt": "What is genshin impact\n\n:Assistant:"
}

###
# @name Proxy / AWS Claude -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/aws/claude/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 50,
  "stream": true,
  "messages": [
    {
      "role": "user",
      "content": "What is genshin impact?"
    }
  ]
}

###
# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation
POST {{proxy-host}}/proxy/google-palm/v1/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-4",
  "max_tokens": 42,
  "messages": [
    {
      "role": "user",
      "content": "Hi what is the name of the fourth president of the united states?"
    }
  ]
}
@@ -0,0 +1,40 @@
$NumThreads = 10

$runspacePool = [runspacefactory]::CreateRunspacePool(1, $NumThreads)
$runspacePool.Open()
$runspaces = @()

$headers = @{
    "Authorization" = "Bearer test"
    "anthropic-version" = "2023-01-01"
    "Content-Type" = "application/json"
}

$payload = @{
    model = "claude-v2"
    max_tokens_to_sample = 40
    temperature = 0
    stream = $true
    prompt = "Test prompt, please reply with lorem ipsum`n`n:Assistant:"
} | ConvertTo-Json

for ($i = 1; $i -le $NumThreads; $i++) {
    Write-Host "Starting thread $i"
    $runspace = [powershell]::Create()
    $runspace.AddScript({
        param($i, $headers, $payload)
        $response = Invoke-WebRequest -Uri "http://localhost:7860/proxy/aws/claude/v1/complete" -Method Post -Headers $headers -Body $payload
        Write-Host "Response from server: $($response.StatusCode)"
    }).AddArgument($i).AddArgument($headers).AddArgument($payload)

    $runspace.RunspacePool = $runspacePool
    $runspaces += [PSCustomObject]@{ Pipe = $runspace; Status = $runspace.BeginInvoke() }
}

$runspaces | ForEach-Object {
    $_.Pipe.EndInvoke($_.Status)
    $_.Pipe.Dispose()
}

$runspacePool.Close()
$runspacePool.Dispose()
@@ -33,6 +33,17 @@ type Config = {
    * @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
    */
   awsCredentials?: string;
+  /**
+   * Comma-delimited list of Azure OpenAI credentials. Each credential item
+   * should be a colon-delimited list of Azure resource name, deployment ID, and
+   * API key.
+   *
+   * The resource name is the subdomain in your Azure OpenAI deployment's URL,
+   * e.g. `https://resource-name.openai.azure.com`
+   *
+   * @example `AZURE_CREDENTIALS=resource_name_1:deployment_id_1:api_key_1,resource_name_2:deployment_id_2:api_key_2`
+   */
+  azureCredentials?: string;
   /**
    * The proxy key to require for requests. Only applicable if the user
    * management mode is set to 'proxy_key', and required if so.
@@ -188,6 +199,7 @@ export const config: Config = {
  anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
  googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
  awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
+ azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
  proxyKey: getEnvWithDefault("PROXY_KEY", ""),
  adminKey: getEnvWithDefault("ADMIN_KEY", ""),
  gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
@@ -219,6 +231,10 @@ export const config: Config = {
     "claude",
     "bison",
     "aws-claude",
+    "azure-turbo",
+    "azure-gpt4",
+    "azure-gpt4-turbo",
+    "azure-gpt4-32k",
   ]),
  rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
  rejectMessage: getEnvWithDefault(
@@ -352,6 +368,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
   "anthropicKey",
   "googlePalmKey",
   "awsCredentials",
+  "azureCredentials",
   "proxyKey",
   "adminKey",
   "rejectPhrases",
@@ -369,6 +386,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
   "useInsecureCookies",
   "staticServiceInfo",
   "checkKeys",
+  "allowedModelFamilies",
 ];

 const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;
@@ -417,6 +435,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
       "ANTHROPIC_KEY",
       "GOOGLE_PALM_KEY",
       "AWS_CREDENTIALS",
+      "AZURE_CREDENTIALS",
     ].includes(String(env))
   ) {
     return value as unknown as T;
src/info-page.ts
@@ -1,3 +1,4 @@
+/** This whole module really sucks */
 import fs from "fs";
 import { Request, Response } from "express";
 import showdown from "showdown";
@@ -5,11 +6,16 @@ import { config, listConfig } from "./config";
 import {
   AnthropicKey,
   AwsBedrockKey,
+  AzureOpenAIKey,
   GooglePalmKey,
   keyPool,
   OpenAIKey,
 } from "./shared/key-management";
-import { ModelFamily, OpenAIModelFamily } from "./shared/models";
+import {
+  AzureOpenAIModelFamily,
+  ModelFamily,
+  OpenAIModelFamily,
+} from "./shared/models";
 import { getUniqueIps } from "./proxy/rate-limit";
 import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
 import { getTokenCostUsd, prettyTokens } from "./shared/stats";
@@ -23,6 +29,8 @@ let infoPageLastUpdated = 0;
 type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
 const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
   k.service === "openai";
+const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
+  k.service === "azure";
 const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
   k.service === "anthropic";
 const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
@@ -48,6 +56,7 @@ type ServiceAggregates = {
   anthropicKeys?: number;
   palmKeys?: number;
   awsKeys?: number;
+  azureKeys?: number;
   proompts: number;
   tokens: number;
   tokenCost: number;
@@ -62,17 +71,15 @@ const serviceStats = new Map<keyof ServiceAggregates, number>();

 export const handleInfoPage = (req: Request, res: Response) => {
   if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
-    res.send(infoPageHtml);
-    return;
+    return res.send(infoPageHtml);
   }

-  // Sometimes huggingface doesn't send the host header and makes us guess.
   const baseUrl =
     process.env.SPACE_ID && !req.get("host")?.includes("hf.space")
       ? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
       : req.protocol + "://" + req.get("host");

-  infoPageHtml = buildInfoPageHtml(baseUrl);
+  infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
   infoPageLastUpdated = Date.now();

   res.send(infoPageHtml);
@@ -95,6 +102,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
   const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
   const palmKeys = serviceStats.get("palmKeys") || 0;
   const awsKeys = serviceStats.get("awsKeys") || 0;
+  const azureKeys = serviceStats.get("azureKeys") || 0;
   const proompts = serviceStats.get("proompts") || 0;
   const tokens = serviceStats.get("tokens") || 0;
   const tokenCost = serviceStats.get("tokenCost") || 0;
@@ -102,16 +110,15 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
   const allowDalle = config.allowedModelFamilies.includes("dall-e");

   const endpoints = {
-    ...(openaiKeys ? { openai: baseUrl + "/proxy/openai" } : {}),
-    ...(openaiKeys
-      ? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" }
-      : {}),
+    ...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
+    ...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
     ...(openaiKeys && allowDalle
-      ? { ["openai-image"]: baseUrl + "/proxy/openai-image" }
+      ? { ["openai-image"]: baseUrl + "/openai-image" }
       : {}),
-    ...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}),
-    ...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}),
-    ...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}),
+    ...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
+    ...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
+    ...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
+    ...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
   };

   const stats = {
@@ -120,13 +127,17 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
     ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
   };

-  const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys };
+  const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
+  for (const key of Object.keys(keyInfo)) {
+    if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
+  }

   const providerInfo = {
     ...(openaiKeys ? getOpenAIInfo() : {}),
     ...(anthropicKeys ? getAnthropicInfo() : {}),
-    ...(palmKeys ? { "palm-bison": getPalmInfo() } : {}),
-    ...(awsKeys ? { "aws-claude": getAwsInfo() } : {}),
+    ...(palmKeys ? getPalmInfo() : {}),
+    ...(awsKeys ? getAwsInfo() : {}),
+    ...(azureKeys ? getAzureInfo() : {}),
   };

   if (hideFullInfo) {
@@ -188,6 +199,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
   increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
   increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
   increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
+  increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);

   let sumTokens = 0;
   let sumCost = 0;
@@ -201,17 +213,26 @@ function addKeyToAggregates(k: KeyPoolKey) {
         Boolean(k.lastChecked) ? 0 : 1
       );

-      // Technically this would not account for keys that have tokens recorded
-      // on models they aren't provisioned for, but that would be strange
       k.modelFamilies.forEach((f) => {
         const tokens = k[`${f}Tokens`];
         sumTokens += tokens;
         sumCost += getTokenCostUsd(f, tokens);
         increment(modelStats, `${f}__tokens`, tokens);
-        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
         increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
-        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
         increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
+        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
+        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
+      });
+      break;
+    case "azure":
+      if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
+      k.modelFamilies.forEach((f) => {
+        const tokens = k[`${f}Tokens`];
+        sumTokens += tokens;
+        sumCost += getTokenCostUsd(f, tokens);
+        increment(modelStats, `${f}__tokens`, tokens);
+        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
+        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
       });
       break;
     case "anthropic": {
@@ -381,11 +402,13 @@ function getPalmInfo() {
   const cost = getTokenCostUsd("bison", tokens);

   return {
-    usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-    activeKeys: bisonInfo.active,
-    revokedKeys: bisonInfo.revoked,
-    proomptersInQueue: bisonInfo.queued,
-    estimatedQueueTime: bisonInfo.queueTime,
+    bison: {
+      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
+      activeKeys: bisonInfo.active,
+      revokedKeys: bisonInfo.revoked,
+      proomptersInQueue: bisonInfo.queued,
+      estimatedQueueTime: bisonInfo.queueTime,
+    },
   };
 }
@@ -408,15 +431,59 @@ function getAwsInfo() {
     : `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;

   return {
-    usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-    activeKeys: awsInfo.active,
-    revokedKeys: awsInfo.revoked,
-    proomptersInQueue: awsInfo.queued,
-    estimatedQueueTime: awsInfo.queueTime,
-    ...(logged > 0 ? { privacy: logMsg } : {}),
+    "aws-claude": {
+      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
+      activeKeys: awsInfo.active,
+      revokedKeys: awsInfo.revoked,
+      proomptersInQueue: awsInfo.queued,
+      estimatedQueueTime: awsInfo.queueTime,
+      ...(logged > 0 ? { privacy: logMsg } : {}),
+    },
   };
 }

+function getAzureInfo() {
+  const azureFamilies = [
+    "azure-turbo",
+    "azure-gpt4",
+    "azure-gpt4-turbo",
+    "azure-gpt4-32k",
+  ] as const;
+
+  const azureInfo: {
+    [modelFamily in AzureOpenAIModelFamily]?: {
+      usage?: string;
+      activeKeys: number;
+      revokedKeys?: number;
+      proomptersInQueue?: number;
+      estimatedQueueTime?: string;
+    };
+  } = {};
+  for (const family of azureFamilies) {
+    const familyAllowed = config.allowedModelFamilies.includes(family);
+    const activeKeys = modelStats.get(`${family}__active`) || 0;
+
+    if (!familyAllowed || activeKeys === 0) continue;
+
+    azureInfo[family] = {
+      activeKeys,
+      revokedKeys: modelStats.get(`${family}__revoked`) || 0,
+    };
+
+    const queue = getQueueInformation(family);
+    azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue;
+    azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime;
+
+    const tokens = modelStats.get(`${family}__tokens`) || 0;
+    const cost = getTokenCostUsd(family, tokens);
+    azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString(
+      cost
+    )}`;
+  }
+
+  return azureInfo;
+}
+
 const customGreeting = fs.existsSync("greeting.md")
   ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
   : "";
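The per-family object that getAzureInfo builds is spread directly into providerInfo above, so the info page would render something like the sketch below. All values here are invented for illustration, and the usage string format is an assumption since prettyTokens/getCostString are not shown in this diff.

```
// Invented illustration of getAzureInfo()'s return shape; every value is
// made up, and the usage string format is an assumption.
const exampleAzureInfo = {
  "azure-gpt4": {
    usage: "1.21m tokens ($36.30)",
    activeKeys: 2,
    revokedKeys: 0,
    proomptersInQueue: 0,
    estimatedQueueTime: "no wait",
  },
};
```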
@@ -430,10 +497,10 @@ function buildInfoPageHeader(converter: showdown.Converter, title: string) {
   let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
 # ${title}`;
   if (config.promptLogging) {
-    infoBody += `\n## Prompt logging is enabled!
-The server operator has enabled prompt logging. The prompts you send to this proxy and the AI responses you receive may be saved.
+    infoBody += `\n## Prompt Logging Enabled
+This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.

-Logs are anonymous and do not contain IP addresses or timestamps. [You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/prompt-logging/index.ts).
+[You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/shared/prompt-logging/index.ts).

 **If you are uncomfortable with this, don't send prompts to this proxy!**`;
   }
@@ -570,8 +637,6 @@ function escapeHtml(unsafe: string) {
 }

 function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
-  // Huggingface broke their amazon elb config and no longer sends the
-  // x-forwarded-host header. This is a workaround.
   try {
     const [username, spacename] = spaceId.split("/");
     return `https://${username}-${spacename.replace(/_/g, "-")}.hf.space`;
@@ -11,7 +11,7 @@ import {
   createPreprocessorMiddleware,
   stripHeaders,
   signAwsRequest,
-  finalizeAwsRequest,
+  finalizeSignedRequest,
   createOnProxyReqHandler,
   blockZoomerOrigins,
 } from "./middleware/request";
@@ -30,7 +30,11 @@ const getModelsResponse = () => {

   if (!config.awsCredentials) return { object: "list", data: [] };

-  const variants = ["anthropic.claude-v1", "anthropic.claude-v2"];
+  const variants = [
+    "anthropic.claude-v1",
+    "anthropic.claude-v2",
+    "anthropic.claude-v2:1",
+  ];

   const models = variants.map((id) => ({
     id,
@@ -134,7 +138,7 @@ const awsProxy = createQueueMiddleware({
         applyQuotaLimits,
         blockZoomerOrigins,
         stripHeaders,
-        finalizeAwsRequest,
+        finalizeSignedRequest,
       ],
     }),
     proxyRes: createOnProxyResHandler([awsResponseHandler]),
@@ -183,7 +187,7 @@ function maybeReassignModel(req: Request) {
     req.body.model = "anthropic.claude-v1";
   } else {
     // User's client requested v2 or possibly some OpenAI model, default to v2
-    req.body.model = "anthropic.claude-v2";
+    req.body.model = "anthropic.claude-v2:1";
   }
   // TODO: Handle claude-instant
 }
@@ -0,0 +1,140 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
  ModelFamily,
  AzureOpenAIModelFamily,
  getAzureOpenAIModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { KNOWN_OPENAI_MODELS } from "./openai";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  applyQuotaLimits,
  blockZoomerOrigins,
  createOnProxyReqHandler,
  createPreprocessorMiddleware,
  finalizeSignedRequest,
  limitCompletions,
  stripHeaders,
} from "./middleware/request";
import {
  createOnProxyResHandler,
  ProxyResHandlerWithBody,
} from "./middleware/response";
import { addAzureKey } from "./middleware/request/add-azure-key";

let modelsCache: any = null;
let modelsCacheTime = 0;

function getModelsResponse() {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }

  let available = new Set<AzureOpenAIModelFamily>();
  for (const key of keyPool.list()) {
    if (key.isDisabled || key.service !== "azure") continue;
    key.modelFamilies.forEach((family) =>
      available.add(family as AzureOpenAIModelFamily)
    );
  }
  const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
  available = new Set([...available].filter((x) => allowed.has(x)));

  const models = KNOWN_OPENAI_MODELS.map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "azure",
    permission: [
      {
        id: "modelperm-" + id,
        object: "model_permission",
        created: new Date().getTime(),
        organization: "*",
        group: null,
        is_blocking: false,
      },
    ],
    root: id,
    parent: null,
  })).filter((model) => available.has(getAzureOpenAIModelFamily(model.id)));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();

  return modelsCache;
}

const handleModelRequest: RequestHandler = (_req, res) => {
  res.status(200).json(getModelsResponse());
};

const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  }

  res.status(200).json(body);
};

const azureOpenAIProxy = createQueueMiddleware({
  beforeProxy: addAzureKey,
  proxyMiddleware: createProxyMiddleware({
    target: "will be set by router",
    router: (req) => {
      if (!req.signedRequest) throw new Error("signedRequest not set");
      const { hostname, path } = req.signedRequest;
      return `https://${hostname}${path}`;
    },
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [
          applyQuotaLimits,
          limitCompletions,
          blockZoomerOrigins,
          stripHeaders,
          finalizeSignedRequest,
        ],
      }),
      proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
      error: handleProxyError,
    },
  }),
});

const azureOpenAIRouter = Router();
azureOpenAIRouter.get("/v1/models", handleModelRequest);
azureOpenAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware({
    inApi: "openai",
    outApi: "openai",
    service: "azure",
  }),
  azureOpenAIProxy
);

export const azure = azureOpenAIRouter;
@@ -59,7 +59,7 @@ export function writeErrorResponse(
     res.write(`data: [DONE]\n\n`);
     res.end();
   } else {
-    if (req.tokenizerInfo && errorPayload.error) {
+    if (req.tokenizerInfo && typeof errorPayload.error === "object") {
       errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
     }
     res.status(statusCode).json(errorPayload);
@@ -0,0 +1,50 @@
import { AzureOpenAIKey, keyPool } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";

export const addAzureKey: RequestPreprocessor = (req) => {
  const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
  const serviceValid = req.service === "azure";
  if (!apisValid || !serviceValid) {
    throw new Error("addAzureKey called on invalid request");
  }

  if (!req.body?.model) {
    throw new Error("You must specify a model with your request.");
  }

  const model = req.body.model.startsWith("azure-")
    ? req.body.model
    : `azure-${req.body.model}`;

  req.key = keyPool.get(model);
  req.body.model = model;

  req.log.info(
    { key: req.key.hash, model },
    "Assigned Azure OpenAI key to request"
  );

  const cred = req.key as AzureOpenAIKey;
  const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);

  req.signedRequest = {
    method: "POST",
    protocol: "https:",
    hostname: `${resourceName}.openai.azure.com`,
    path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
    headers: {
      ["host"]: `${resourceName}.openai.azure.com`,
      ["content-type"]: "application/json",
      ["api-key"]: apiKey,
    },
    body: JSON.stringify(req.body),
  };
};

function getCredentialsFromKey(key: AzureOpenAIKey) {
  const [resourceName, deploymentId, apiKey] = key.key.split(":");
  if (!resourceName || !deploymentId || !apiKey) {
    throw new Error("Assigned Azure OpenAI key is not in the correct format.");
  }
  return { resourceName, deploymentId, apiKey };
}
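To make the effect of this preprocessor concrete: with the example credential `contoso-ml:gpt4-8k:...` from docs/azure-configuration.md, the signed request targets the deployment-specific Azure URL. A small sketch reproducing the same string assembly as `addAzureKey` above:

```
// Illustration only: reproduces the hostname/path assembly from addAzureKey
// for the docs' example credential. Not part of the actual codebase.
const resourceName = "contoso-ml"; // first field of the credential
const deploymentId = "gpt4-8k";    // second field of the credential

const hostname = `${resourceName}.openai.azure.com`;
const path = `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`;

console.log(`https://${hostname}${path}`);
// => https://contoso-ml.openai.azure.com/openai/deployments/gpt4-8k/chat/completions?api-version=2023-09-01-preview
```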
@@ -80,6 +80,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
         `?key=${assignedKey.key}`
       );
       break;
+    case "azure":
+      const azureKey = assignedKey.key;
+      proxyReq.setHeader("api-key", azureKey);
+      break;
     case "aws":
       throw new Error(
         "add-key should not be used for AWS security credentials. Use sign-aws-request instead."
@@ -1,11 +1,11 @@
 import type { ProxyRequestMiddleware } from ".";

 /**
- * For AWS requests, the body is signed earlier in the request pipeline, before
- * the proxy middleware. This function just assigns the path and headers to the
- * proxy request.
+ * For AWS/Azure requests, the body is signed earlier in the request pipeline,
+ * before the proxy middleware. This function just assigns the path and headers
+ * to the proxy request.
  */
-export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => {
+export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => {
   if (!req.signedRequest) {
     throw new Error("Expected req.signedRequest to be set");
   }
@@ -22,7 +22,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
 export { addAnthropicPreamble } from "./add-anthropic-preamble";
 export { blockZoomerOrigins } from "./block-zoomer-origins";
 export { finalizeBody } from "./finalize-body";
-export { finalizeAwsRequest } from "./finalize-aws-request";
+export { finalizeSignedRequest } from "./finalize-signed-request";
 export { limitCompletions } from "./limit-completions";
 export { stripHeaders } from "./strip-headers";
@@ -289,15 +289,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       switch (service) {
         case "openai":
         case "google-palm":
-          if (errorPayload.error?.code === "content_policy_violation") {
-            errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`;
+        case "azure":
+          const filteredCodes = ["content_policy_violation", "content_filter"];
+          if (filteredCodes.includes(errorPayload.error?.code)) {
+            errorPayload.proxy_note = `Request was filtered by the upstream API's content moderation system. Modify your prompt and try again.`;
             refundLastAttempt(req);
           } else if (errorPayload.error?.code === "billing_hard_limit_reached") {
             // For some reason, some models return this 400 error instead of the
             // same 429 billing error that other models return.
             handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
           } else {
-            errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
+            errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
           }
           break;
         case "anthropic":
@@ -342,7 +344,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       handleAwsRateLimitError(req, errorPayload);
       break;
     case "google-palm":
-      throw new Error("Rate limit handling not implemented for PaLM");
+    case "azure":
+      errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
+      break;
     default:
       assertNever(service);
   }
@@ -369,6 +373,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     case "aws":
       errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
       break;
+    case "azure":
+      errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
+      break;
     default:
       assertNever(service);
   }
@@ -28,6 +28,7 @@ type SSEMessageTransformerOptions = TransformOptions & {
 export class SSEMessageTransformer extends Transform {
   private lastPosition: number;
   private msgCount: number;
+  private readonly inputFormat: APIFormat;
   private readonly transformFn: StreamingCompletionTransformer;
   private readonly log;
   private readonly fallbackId: string;
@@ -42,6 +43,7 @@ export class SSEMessageTransformer extends Transform {
       options.inputFormat,
       options.inputApiVersion
     );
+    this.inputFormat = options.inputFormat;
     this.fallbackId = options.requestId;
     this.fallbackModel = options.requestedModel;
     this.log.debug(
@@ -67,6 +69,17 @@ export class SSEMessageTransformer extends Transform {
       });
       this.lastPosition = newPosition;

+      // Special case for Azure OpenAI, which is 99% the same as OpenAI but
+      // sometimes emits an extra event at the beginning of the stream with the
+      // content moderation system's response to the prompt. A lot of frontends
+      // don't expect this and neither does our event aggregator so we drop it.
+      if (this.inputFormat === "openai" && this.msgCount <= 1) {
+        if (originalMessage.includes("prompt_filter_results")) {
+          this.log.debug("Dropping Azure OpenAI content moderation SSE event");
+          return callback();
+        }
+      }
+
       this.emit("originalMessage", originalMessage);

       // Some events may not be transformed, e.g. ping events
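For reference, the dropped event looks roughly like the sketch below. Only the `prompt_filter_results` marker that the code above matches on is taken from the diff; the rest of the payload shape is an assumption about Azure's stream format.

```
// Hypothetical first SSE event from Azure; the exact field layout is an
// assumption, but the substring check mirrors the transformer above.
const firstAzureEvent =
  'data: {"choices":[],"prompt_filter_results":[{"prompt_index":0}]}';

const shouldDrop = (msgCount: number, message: string) =>
  msgCount <= 1 && message.includes("prompt_filter_results");

console.log(shouldDrop(1, firstAzureEvent)); // true -> event is swallowed
```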
@@ -24,7 +24,7 @@ import {
 import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";

 // https://platform.openai.com/docs/models/overview
-const KNOWN_MODELS = [
+export const KNOWN_OPENAI_MODELS = [
   "gpt-4-1106-preview",
   "gpt-4-vision-preview",
   "gpt-4",
@@ -46,7 +46,7 @@ const KNOWN_MODELS = [
 let modelsCache: any = null;
 let modelsCacheTime = 0;

-export function generateModelList(models = KNOWN_MODELS) {
+export function generateModelList(models = KNOWN_OPENAI_MODELS) {
   let available = new Set<OpenAIModelFamily>();
   for (const key of keyPool.list()) {
     if (key.isDisabled || key.service !== "openai") continue;
@@ -26,6 +26,7 @@ import { assertNever } from "../shared/utils";
 import { logger } from "../logger";
 import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
 import { RequestPreprocessor } from "./middleware/request";
+import { handleProxyError } from "./middleware/common";

 const queue: Request[] = [];
 const log = logger.child({ module: "request-queue" });
@@ -34,7 +35,7 @@ const log = logger.child({ module: "request-queue" });
 const AGNAI_CONCURRENCY_LIMIT = 5;
 /** Maximum number of queue slots for individual users. */
 const USER_CONCURRENCY_LIMIT = 1;
-const MIN_HEARTBEAT_SIZE = 512;
+const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
 const MAX_HEARTBEAT_SIZE =
   1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
 const HEARTBEAT_INTERVAL =
@@ -358,12 +359,16 @@ export function createQueueMiddleware({
   return (req, res, next) => {
     req.proceed = async () => {
       if (beforeProxy) {
-        // Hack to let us run asynchronous middleware before the
-        // http-proxy-middleware handler. This is used to sign AWS requests
-        // before they are proxied, as the signing is asynchronous.
-        // Unlike RequestPreprocessors, this runs every time the request is
-        // dequeued, not just the first time.
-        await beforeProxy(req);
+        try {
+          // Hack to let us run asynchronous middleware before the
+          // http-proxy-middleware handler. This is used to sign AWS requests
+          // before they are proxied, as the signing is asynchronous.
+          // Unlike RequestPreprocessors, this runs every time the request is
+          // dequeued, not just the first time.
+          await beforeProxy(req);
+        } catch (err) {
+          return handleProxyError(err, req, res);
+        }
       }
       proxyMiddleware(req, res, next);
     };
@@ -6,6 +6,7 @@ import { openaiImage } from "./openai-image";
 import { anthropic } from "./anthropic";
 import { googlePalm } from "./palm";
 import { aws } from "./aws";
+import { azure } from "./azure";

 const proxyRouter = express.Router();
 proxyRouter.use((req, _res, next) => {
@@ -32,6 +33,7 @@ proxyRouter.use("/openai-image", addV1, openaiImage);
 proxyRouter.use("/anthropic", addV1, anthropic);
 proxyRouter.use("/google-palm", addV1, googlePalm);
 proxyRouter.use("/aws/claude", addV1, aws);
+proxyRouter.use("/azure/openai", addV1, azure);
 // Redirect browser requests to the homepage.
 proxyRouter.get("*", (req, res, next) => {
   const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
@@ -26,46 +26,23 @@ type AnthropicAPIError = {
 type UpdateFn = typeof AnthropicKeyProvider.prototype.update;

 export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
-  private readonly updateKey: UpdateFn;
-
   constructor(keys: AnthropicKey[], updateKey: UpdateFn) {
     super(keys, {
       service: "anthropic",
       keyCheckPeriod: KEY_CHECK_PERIOD,
       minCheckInterval: MIN_CHECK_INTERVAL,
+      updateKey,
     });
-    this.updateKey = updateKey;
   }

-  protected async checkKey(key: AnthropicKey) {
-    if (key.isDisabled) {
-      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
-      this.scheduleNextCheck();
-      return;
-    }
-
-    this.log.debug({ key: key.hash }, "Checking key...");
-    let isInitialCheck = !key.lastChecked;
-    try {
-      const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
-      const updates = { isPozzed: pozzed };
-      this.updateKey(key.hash, updates);
-      this.log.info(
-        { key: key.hash, models: key.modelFamilies },
-        "Key check complete."
-      );
-    } catch (error) {
-      // touch the key so we don't check it again for a while
-      this.updateKey(key.hash, {});
-      this.handleAxiosError(key, error as AxiosError);
-    }
-
-    this.lastCheck = Date.now();
-    // Only enqueue the next check if this wasn't a startup check, since those
-    // are batched together elsewhere.
-    if (!isInitialCheck) {
-      this.scheduleNextCheck();
-    }
+  protected async testKeyOrFail(key: AnthropicKey) {
+    const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
+    const updates = { isPozzed: pozzed };
+    this.updateKey(key.hash, updates);
+    this.log.info(
+      { key: key.hash, models: key.modelFamilies },
+      "Checked key."
+    );
   }

   protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
@ -84,6 +61,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
||||||
{ key: key.hash, error: error.message },
|
{ key: key.hash, error: error.message },
|
||||||
"Key is rate limited. Rechecking in 10 seconds."
|
"Key is rate limited. Rechecking in 10 seconds."
|
||||||
);
|
);
|
||||||
|
0;
|
||||||
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
|
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
|
||||||
this.updateKey(key.hash, { lastChecked: next });
|
this.updateKey(key.hash, { lastChecked: next });
|
||||||
break;
|
break;
|
||||||
|
|
|
@@ -32,58 +32,36 @@ type GetLoggingConfigResponse = {
 type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;

 export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
-  private readonly updateKey: UpdateFn;
-
   constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
     super(keys, {
       service: "aws",
       keyCheckPeriod: KEY_CHECK_PERIOD,
       minCheckInterval: MIN_CHECK_INTERVAL,
+      updateKey,
     });
-    this.updateKey = updateKey;
   }

-  protected async checkKey(key: AwsBedrockKey) {
-    if (key.isDisabled) {
-      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
-      this.scheduleNextCheck();
-      return;
-    }
-
-    this.log.debug({ key: key.hash }, "Checking key...");
-    let isInitialCheck = !key.lastChecked;
-    try {
-      // Only check models on startup. For now all models must be available to
-      // the proxy because we don't route requests to different keys.
-      const modelChecks: Promise<unknown>[] = [];
-      if (isInitialCheck) {
-        modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
-        modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
-      }
-
-      await Promise.all(modelChecks);
-      await this.checkLoggingConfiguration(key);
-
-      this.log.info(
-        {
-          key: key.hash,
-          models: key.modelFamilies,
-          logged: key.awsLoggingStatus,
-        },
-        "Key check complete."
-      );
-    } catch (error) {
-      this.handleAxiosError(key, error as AxiosError);
-    }
-
-    this.updateKey(key.hash, {});
-
-    this.lastCheck = Date.now();
-    // Only enqueue the next check if this wasn't a startup check, since those
-    // are batched together elsewhere.
-    if (!isInitialCheck) {
-      this.scheduleNextCheck();
-    }
+  protected async testKeyOrFail(key: AwsBedrockKey) {
+    // Only check models on startup. For now all models must be available to
+    // the proxy because we don't route requests to different keys.
+    const modelChecks: Promise<unknown>[] = [];
+    const isInitialCheck = !key.lastChecked;
+    if (isInitialCheck) {
+      modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
+      modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
+    }
+
+    await Promise.all(modelChecks);
+    await this.checkLoggingConfiguration(key);
+
+    this.log.info(
+      {
+        key: key.hash,
+        models: key.modelFamilies,
+        logged: key.awsLoggingStatus,
+      },
+      "Checked key."
+    );
   }

   protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {
@@ -0,0 +1,149 @@
import axios, { AxiosError } from "axios";
import { KeyCheckerBase } from "../key-checker-base";
import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
import { getAzureOpenAIModelFamily } from "../../models";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) =>
  `https://${AZURE_HOST.replace(
    "%RESOURCE_NAME%",
    resourceName
  )}/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`;

type AzureError = {
  error: {
    message: string;
    type: string | null;
    param: string;
    code: string;
    status: number;
  };
};
type UpdateFn = typeof AzureOpenAIKeyProvider.prototype.update;

export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
  constructor(keys: AzureOpenAIKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "azure",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      recurringChecksEnabled: false,
      updateKey,
    });
  }

  protected async testKeyOrFail(key: AzureOpenAIKey) {
    const model = await this.testModel(key);
    this.log.info({ key: key.hash, deploymentModel: model }, "Checked key.");
    this.updateKey(key.hash, { modelFamilies: [model] });
  }

  // provided api-key header isn't valid (401)
  // {
  //   "error": {
  //     "code": "401",
  //     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
  //   }
  // }

  // api key correct but deployment id is wrong (404)
  // {
  //   "error": {
  //     "code": "DeploymentNotFound",
  //     "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
  //   }
  // }

  // resource name is wrong (node will throw ENOTFOUND)

  // rate limited (429)
  // TODO: try to reproduce this

  protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
    if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
      const data = error.response.data;
      const status = data.error.status;
      const errorType = data.error.code || data.error.type;
      switch (errorType) {
        case "DeploymentNotFound":
          this.log.warn(
            { key: key.hash, errorType, error: error.response.data },
            "Key is revoked or deployment ID is incorrect. Disabling key."
          );
          return this.updateKey(key.hash, {
            isDisabled: true,
            isRevoked: true,
          });
        case "401":
          this.log.warn(
            { key: key.hash, errorType, error: error.response.data },
            "Key is disabled or incorrect. Disabling key."
          );
          return this.updateKey(key.hash, {
            isDisabled: true,
            isRevoked: true,
          });
        default:
          this.log.error(
            { key: key.hash, errorType, error: error.response.data, status },
            "Unknown Azure API error while checking key. Please report this."
          );
          return this.updateKey(key.hash, { lastChecked: Date.now() });
      }
    }

    const { response, code } = error;
    if (code === "ENOTFOUND") {
      this.log.warn(
        { key: key.hash, error: error.message },
        "Resource name is probably incorrect. Disabling key."
      );
      return this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
    }

    const { headers, status, data } = response ?? {};
    this.log.error(
      { key: key.hash, status, headers, data, error: error.message },
      "Network error while checking key; trying this key again in a minute."
    );
    const oneMinute = 60 * 1000;
    const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
    this.updateKey(key.hash, { lastChecked: next });
  }

  private async testModel(key: AzureOpenAIKey) {
    const { apiKey, deploymentId, resourceName } =
      AzureOpenAIKeyChecker.getCredentialsFromKey(key);
    const url = POST_CHAT_COMPLETIONS(resourceName, deploymentId);
    const testRequest = {
      max_tokens: 1,
      stream: false,
      messages: [{ role: "user", content: "" }],
    };
    const { data } = await axios.post(url, testRequest, {
      headers: { "Content-Type": "application/json", "api-key": apiKey },
    });

    return getAzureOpenAIModelFamily(data.model);
  }

  static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
    const data = error.response?.data as any;
    return data?.error?.code || data?.error?.type;
  }

  static getCredentialsFromKey(key: AzureOpenAIKey) {
    const [resourceName, deploymentId, apiKey] = key.key.split(":");
    if (!resourceName || !deploymentId || !apiKey) {
      throw new Error(
        "Invalid Azure credential format. Refer to .env.example and ensure your credentials are in the format RESOURCE_NAME:DEPLOYMENT_ID:API_KEY with commas between each credential set."
      );
    }
    return { resourceName, deploymentId, apiKey };
  }
}
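To make the URL construction above concrete: one comma-separated `AZURE_CREDENTIALS` entry splits on `:` into the three values the checker needs, and the resource name and deployment ID are interpolated into the probe URL. A standalone sketch using the example credentials from the docs:

```ts
// Standalone sketch: how one AZURE_CREDENTIALS entry becomes the probe URL.
const credential = "contoso-ml:gpt4-8k:0123456789abcdef0123456789abcdef";
const [resourceName, deploymentId, apiKey] = credential.split(":");

const url =
  `https://${resourceName}.openai.azure.com/openai/deployments/` +
  `${deploymentId}/chat/completions?api-version=2023-09-01-preview`;

console.log(url);
// https://contoso-ml.openai.azure.com/openai/deployments/gpt4-8k/chat/completions?api-version=2023-09-01-preview
// The third segment (apiKey) is sent as the "api-key" request header.
```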
@@ -0,0 +1,212 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";

export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;

type AzureOpenAIKeyUsage = {
  [K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
};

export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
  readonly service: "azure";
  readonly modelFamilies: AzureOpenAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  contentFiltering: boolean;
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 4000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 250;

export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
  readonly service = "azure";

  private keys: AzureOpenAIKey[] = [];
  private checker?: AzureOpenAIKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.azureCredentials;
    if (!keyConfig) {
      this.log.warn(
        "AZURE_CREDENTIALS is not set. Azure OpenAI API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: AzureOpenAIKey = {
        key,
        service: this.service,
        modelFamilies: ["azure-gpt4"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        contentFiltering: false,
        hash: `azu-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        "azure-turboTokens": 0,
        "azure-gpt4Tokens": 0,
        "azure-gpt4-32kTokens": 0,
        "azure-gpt4-turboTokens": 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded Azure OpenAI keys.");
  }

  public init() {
    if (config.checkKeys) {
      this.checker = new AzureOpenAIKeyChecker(
        this.keys,
        this.update.bind(this)
      );
      this.checker.start();
    }
  }

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(model: AzureOpenAIModel) {
    const neededFamily = getAzureOpenAIModelFamily(model);
    const availableKeys = this.keys.filter(
      (k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
    );
    if (availableKeys.length === 0) {
      throw new Error(`No keys available for model family '${neededFamily}'.`);
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }

  public disable(key: AzureOpenAIKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<AzureOpenAIKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
  }

  // TODO: all of this shit is duplicate code

  public getLockoutPeriod() {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }

  /**
   * This is called when we receive a 429, which means there are already five
   * concurrent requests running on this key. We don't have any information on
   * when these requests will resolve, so all we can do is wait a bit and try
   * again. We will lock the key for 2 seconds after getting a 429 before
   * retrying in order to give the other requests a chance to finish.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
  }

  public recheck() {
    this.keys.forEach(({ hash }) =>
      this.update(hash, { lastChecked: 0, isDisabled: false })
    );
  }

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
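A quick illustration (not in the commit) of the selection order the comparator produces, using stub keys: non-rate-limited keys come first, ordered by longest idle time; among rate-limited keys, the least-recently-limited wins.

```ts
// Illustration only: mirrors the comparator in AzureOpenAIKeyProvider.get().
type StubKey = { hash: string; rateLimitedAt: number; lastUsed: number };

const RATE_LIMIT_LOCKOUT = 4000;
const now = 10_000;
const keys: StubKey[] = [
  { hash: "azu-a", rateLimitedAt: now - 1000, lastUsed: now - 1 }, // still locked out
  { hash: "azu-b", rateLimitedAt: 0, lastUsed: now - 500 },        // free, used recently
  { hash: "azu-c", rateLimitedAt: 0, lastUsed: now - 9000 },       // free, idle longest
];

keys.sort((a, b) => {
  const aLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
  const bLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
  if (aLimited && !bLimited) return 1;
  if (!aLimited && bLimited) return -1;
  if (aLimited && bLimited) return a.rateLimitedAt - b.rateLimitedAt;
  return a.lastUsed - b.lastUsed;
});

console.log(keys.map((k) => k.hash)); // [ 'azu-c', 'azu-b', 'azu-a' ]
```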
@@ -2,6 +2,7 @@ import { OpenAIModel } from "./openai/provider";
 import { AnthropicModel } from "./anthropic/provider";
 import { GooglePalmModel } from "./palm/provider";
 import { AwsBedrockModel } from "./aws/provider";
+import { AzureOpenAIModel } from "./azure/provider";
 import { KeyPool } from "./key-pool";
 import type { ModelFamily } from "../models";

@@ -13,12 +14,18 @@ export type APIFormat =
   | "openai-text"
   | "openai-image";
 /** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
-export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
+export type LLMService =
+  | "openai"
+  | "anthropic"
+  | "google-palm"
+  | "aws"
+  | "azure";
 export type Model =
   | OpenAIModel
   | AnthropicModel
   | GooglePalmModel
-  | AwsBedrockModel;
+  | AwsBedrockModel
+  | AzureOpenAIModel;

 export interface Key {
   /** The API key itself. Never log this, use `hash` instead. */
@@ -72,3 +79,4 @@ export { AnthropicKey } from "./anthropic/provider";
 export { OpenAIKey } from "./openai/provider";
 export { GooglePalmKey } from "./palm/provider";
 export { AwsBedrockKey } from "./aws/provider";
+export { AzureOpenAIKey } from "./azure/provider";
@@ -3,14 +3,17 @@ import { logger } from "../../logger";
 import { Key } from "./index";
 import { AxiosError } from "axios";

-type KeyCheckerOptions = {
+type KeyCheckerOptions<TKey extends Key = Key> = {
   service: string;
   keyCheckPeriod: number;
   minCheckInterval: number;
-}
+  recurringChecksEnabled?: boolean;
+  updateKey: (hash: string, props: Partial<TKey>) => void;
+};

 export abstract class KeyCheckerBase<TKey extends Key> {
   protected readonly service: string;
+  protected readonly RECURRING_CHECKS_ENABLED: boolean;
   /** Minimum time in between any two key checks. */
   protected readonly MIN_CHECK_INTERVAL: number;
   /**
@@ -19,16 +22,19 @@ export abstract class KeyCheckerBase<TKey extends Key> {
    * than this.
    */
   protected readonly KEY_CHECK_PERIOD: number;
+  protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
   protected readonly keys: TKey[] = [];
   protected log: pino.Logger;
   protected timeout?: NodeJS.Timeout;
   protected lastCheck = 0;

-  protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
+  protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
     const { service, keyCheckPeriod, minCheckInterval } = opts;
     this.keys = keys;
     this.KEY_CHECK_PERIOD = keyCheckPeriod;
     this.MIN_CHECK_INTERVAL = minCheckInterval;
+    this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
+    this.updateKey = opts.updateKey;
     this.service = service;
     this.log = logger.child({ module: "key-checker", service });
   }
@@ -52,31 +58,34 @@ export abstract class KeyCheckerBase<TKey extends Key> {
    * the minimum check interval.
    */
   public scheduleNextCheck() {
+    // Gives each concurrent check a correlation ID to make logs less confusing.
     const callId = Math.random().toString(36).slice(2, 8);
     const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
     const checkLog = this.log.child({ callId, timeoutId });

     const enabledKeys = this.keys.filter((key) => !key.isDisabled);
-    checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
+    const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
+    const numEnabled = enabledKeys.length;
+    const numUnchecked = uncheckedKeys.length;

     clearTimeout(this.timeout);
+    this.timeout = undefined;

-    if (enabledKeys.length === 0) {
-      checkLog.warn("All keys are disabled. Key checker stopping.");
+    if (!numEnabled) {
+      checkLog.warn("All keys are disabled. Stopping.");
       return;
     }

-    // Perform startup checks for any keys that haven't been checked yet.
-    const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
-    checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
-    if (uncheckedKeys.length > 0) {
-      const keysToCheck = uncheckedKeys.slice(0, 12);
+    checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
+
+    if (numUnchecked > 0) {
+      const keycheckBatch = uncheckedKeys.slice(0, 12);

       this.timeout = setTimeout(async () => {
         try {
-          await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
+          await Promise.all(keycheckBatch.map((key) => this.checkKey(key)));
         } catch (error) {
-          this.log.error({ error }, "Error checking one or more keys.");
+          checkLog.error({ error }, "Error checking one or more keys.");
         }
         checkLog.info("Batch complete.");
         this.scheduleNextCheck();
@@ -84,11 +93,18 @@ export abstract class KeyCheckerBase<TKey extends Key> {

       checkLog.info(
         {
-          batch: keysToCheck.map((k) => k.hash),
-          remaining: uncheckedKeys.length - keysToCheck.length,
+          batch: keycheckBatch.map((k) => k.hash),
+          remaining: uncheckedKeys.length - keycheckBatch.length,
           newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
         },
-        "Scheduled batch check."
+        "Scheduled batch of initial checks."
+      );
+      return;
+    }
+
+    if (!this.RECURRING_CHECKS_ENABLED) {
+      checkLog.info(
+        "Initial checks complete and recurring checks are disabled for this service. Stopping."
       );
       return;
     }
@@ -106,14 +122,35 @@ export abstract class KeyCheckerBase<TKey extends Key> {
     );

     const delay = nextCheck - Date.now();
-    this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
+    this.timeout = setTimeout(
+      () => this.checkKey(oldestKey).then(() => this.scheduleNextCheck()),
+      delay
+    );
     checkLog.debug(
       { key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
-      "Scheduled single key check."
+      "Scheduled next recurring check."
     );
   }

-  protected abstract checkKey(key: TKey): Promise<void>;
+  public async checkKey(key: TKey): Promise<void> {
+    if (key.isDisabled) {
+      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
+      this.scheduleNextCheck();
+      return;
+    }
+    this.log.debug({ key: key.hash }, "Checking key...");
+
+    try {
+      await this.testKeyOrFail(key);
+    } catch (error) {
+      this.updateKey(key.hash, {});
+      this.handleAxiosError(key, error as AxiosError);
+    }
+
+    this.lastCheck = Date.now();
+  }
+
+  protected abstract testKeyOrFail(key: TKey): Promise<void>;
+
   protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
 }
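The net effect of this refactor is a template method: `checkKey` is now concrete in the base class, owning the disabled-key guard, error routing, and scheduling, while subclasses supply only the probe itself. A minimal hypothetical subclass to show the new contract (the class and service name are illustrative, not from the commit):

```ts
import { AxiosError } from "axios";
import { KeyCheckerBase } from "./key-checker-base"; // path as in this repo
import type { Key } from "./index";

// Hypothetical checker: the only required pieces are the probe and the
// error handler. The base class calls testKeyOrFail, catches rejections,
// touches lastChecked via updateKey, and reschedules as configured.
class ExampleKeyChecker extends KeyCheckerBase<Key> {
  constructor(keys: Key[], updateKey: (hash: string, props: Partial<Key>) => void) {
    super(keys, {
      service: "example",
      keyCheckPeriod: 60 * 1000,
      minCheckInterval: 3 * 1000,
      recurringChecksEnabled: false, // initial checks only, like Azure/OpenAI
      updateKey,
    });
  }

  protected async testKeyOrFail(key: Key) {
    // A real implementation would issue a cheap probe request here and
    // throw on failure; success just records the result.
    this.updateKey(key.hash, {});
  }

  protected handleAxiosError(key: Key, error: AxiosError) {
    this.log.error({ key: key.hash, error: error.message }, "Check failed.");
  }
}
```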
@@ -11,6 +11,7 @@ import { GooglePalmKeyProvider } from "./palm/provider";
 import { AwsBedrockKeyProvider } from "./aws/provider";
 import { ModelFamily } from "../models";
 import { assertNever } from "../utils";
+import { AzureOpenAIKeyProvider } from "./azure/provider";

 type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;

@@ -25,6 +26,7 @@ export class KeyPool {
     this.keyProviders.push(new AnthropicKeyProvider());
     this.keyProviders.push(new GooglePalmKeyProvider());
     this.keyProviders.push(new AwsBedrockKeyProvider());
+    this.keyProviders.push(new AzureOpenAIKeyProvider());
   }

   public init() {
@@ -124,6 +126,8 @@ export class KeyPool {
       // AWS offers models from a few providers
       // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
       return "aws";
+    } else if (model.startsWith("azure")) {
+      return "azure";
     }
     throw new Error(`Unknown service for model '${model}'`);
   }
@@ -142,6 +146,11 @@ export class KeyPool {
       return "google-palm";
     case "aws-claude":
       return "aws";
+    case "azure-turbo":
+    case "azure-gpt4":
+    case "azure-gpt4-32k":
+    case "azure-gpt4-turbo":
+      return "azure";
     default:
       assertNever(modelFamily);
   }
@@ -27,65 +27,41 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;

 export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
   private readonly cloneKey: CloneFn;
-  private readonly updateKey: UpdateFn;

   constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
     super(keys, {
       service: "openai",
       keyCheckPeriod: KEY_CHECK_PERIOD,
       minCheckInterval: MIN_CHECK_INTERVAL,
+      recurringChecksEnabled: false,
+      updateKey,
     });
     this.cloneKey = cloneFn;
-    this.updateKey = updateKey;
   }

-  protected async checkKey(key: OpenAIKey) {
-    if (key.isDisabled) {
-      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
-      this.scheduleNextCheck();
-      return;
-    }
-
-    this.log.debug({ key: key.hash }, "Checking key...");
-    let isInitialCheck = !key.lastChecked;
-    try {
-      // We only need to check for provisioned models on the initial check.
-      if (isInitialCheck) {
-        const [provisionedModels, livenessTest] = await Promise.all([
-          this.getProvisionedModels(key),
-          this.testLiveness(key),
-          this.maybeCreateOrganizationClones(key),
-        ]);
-        const updates = {
-          modelFamilies: provisionedModels,
-          isTrial: livenessTest.rateLimit <= 250,
-        };
-        this.updateKey(key.hash, updates);
-      } else {
-        // No updates needed as models and trial status generally don't change.
-        const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
-        this.updateKey(key.hash, {});
-      }
-      this.log.info(
-        { key: key.hash, models: key.modelFamilies, trial: key.isTrial },
-        "Key check complete."
-      );
-    } catch (error) {
-      // touch the key so we don't check it again for a while
-      this.updateKey(key.hash, {});
-      this.handleAxiosError(key, error as AxiosError);
-    }
-
-    this.lastCheck = Date.now();
-    // Only enqueue the next check if this wasn't a startup check, since those
-    // are batched together elsewhere.
-    if (!isInitialCheck) {
-      this.log.info(
-        { key: key.hash },
-        "Recurring keychecks are disabled, no-op."
-      );
-      // this.scheduleNextCheck();
-    }
+  protected async testKeyOrFail(key: OpenAIKey) {
+    // We only need to check for provisioned models on the initial check.
+    const isInitialCheck = !key.lastChecked;
+    if (isInitialCheck) {
+      const [provisionedModels, livenessTest] = await Promise.all([
+        this.getProvisionedModels(key),
+        this.testLiveness(key),
+        this.maybeCreateOrganizationClones(key),
+      ]);
+      const updates = {
+        modelFamilies: provisionedModels,
+        isTrial: livenessTest.rateLimit <= 250,
+      };
+      this.updateKey(key.hash, updates);
+    } else {
+      // No updates needed as models and trial status generally don't change.
+      const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
+      this.updateKey(key.hash, {});
+    }
+    this.log.info(
+      { key: key.hash, models: key.modelFamilies, trial: key.isTrial },
+      "Checked key."
+    );
   }

   private async getProvisionedModels(
@@ -138,6 +114,17 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
       .filter(({ is_default }) => !is_default)
       .map(({ id }) => id);
     this.cloneKey(key.hash, ids);
+
+    // It's possible that the keychecker may be stopped if all non-cloned keys
+    // happened to be unusable, in which case this clone will never be checked
+    // unless we restart the keychecker.
+    if (!this.timeout) {
+      this.log.warn(
+        { parent: key.hash },
+        "Restarting key checker to check cloned keys."
+      );
+      this.scheduleNextCheck();
+    }
   }

   protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
@@ -217,17 +217,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
       return a.lastUsed - b.lastUsed;
     });

-    // logger.debug(
-    //   {
-    //     byPriority: keysByPriority.map((k) => ({
-    //       hash: k.hash,
-    //       isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
-    //       modelFamilies: k.modelFamilies,
-    //     })),
-    //   },
-    //   "Keys sorted by priority"
-    // );
-
     const selectedKey = keysByPriority[0];
     selectedKey.lastUsed = now;
     this.throttle(selectedKey.hash);
@@ -2,15 +2,25 @@

 import pino from "pino";

-export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
+export type OpenAIModelFamily =
+  | "turbo"
+  | "gpt4"
+  | "gpt4-32k"
+  | "gpt4-turbo"
+  | "dall-e";
 export type AnthropicModelFamily = "claude";
 export type GooglePalmModelFamily = "bison";
 export type AwsBedrockModelFamily = "aws-claude";
+export type AzureOpenAIModelFamily = `azure-${Exclude<
+  OpenAIModelFamily,
+  "dall-e"
+>}`;
 export type ModelFamily =
   | OpenAIModelFamily
   | AnthropicModelFamily
   | GooglePalmModelFamily
-  | AwsBedrockModelFamily;
+  | AwsBedrockModelFamily
+  | AzureOpenAIModelFamily;

 export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@@ -23,6 +33,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   "claude",
   "bison",
   "aws-claude",
+  "azure-turbo",
+  "azure-gpt4",
+  "azure-gpt4-32k",
+  "azure-gpt4-turbo",
 ] as const);

 export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
@@ -64,6 +78,24 @@ export function getAwsBedrockModelFamily(_model: string): ModelFamily {
   return "aws-claude";
 }

+export function getAzureOpenAIModelFamily(
+  model: string,
+  defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
+): AzureOpenAIModelFamily {
+  // Azure model names omit periods. addAzureKey also prepends "azure-" to the
+  // model name to route the request to the correct key provider, so we need
+  // to remove that as well.
+  const modified = model
+    .replace("gpt-35-turbo", "gpt-3.5-turbo")
+    .replace("azure-", "");
+  for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
+    if (modified.match(regex)) {
+      return `azure-${family}` as AzureOpenAIModelFamily;
+    }
+  }
+  return defaultFamily;
+}
+
 export function assertIsKnownModelFamily(
   modelFamily: string
 ): asserts modelFamily is ModelFamily {
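Assuming the regexes in `OPENAI_MODEL_FAMILY_MAP` match these names the way they do for regular OpenAI models, the normalization behaves like this (illustrative calls, not from the commit):

```ts
import { getAzureOpenAIModelFamily } from "./models";

// Azure deployment model names drop the period ("gpt-35-turbo"), and requests
// arrive with an "azure-" prefix; both are normalized before matching.
console.log(getAzureOpenAIModelFamily("azure-gpt-35-turbo")); // "azure-turbo"
console.log(getAzureOpenAIModelFamily("gpt-4-32k"));          // "azure-gpt4-32k"
console.log(getAzureOpenAIModelFamily("unknown-model"));      // "azure-gpt4" (default)
```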
@@ -12,6 +12,7 @@ import schedule from "node-schedule";
 import { v4 as uuid } from "uuid";
 import { config, getFirebaseApp } from "../../config";
 import {
+  getAzureOpenAIModelFamily,
   getClaudeModelFamily,
   getGooglePalmModelFamily,
   getOpenAIModelFamily,
@@ -34,6 +35,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
   claude: 0,
   bison: 0,
   "aws-claude": 0,
+  "azure-turbo": 0,
+  "azure-gpt4": 0,
+  "azure-gpt4-turbo": 0,
+  "azure-gpt4-32k": 0,
 };

 const users: Map<string, User> = new Map();
@@ -382,6 +387,9 @@ function getModelFamilyForQuotaUsage(
   model: string,
   api: APIFormat
 ): ModelFamily {
+  // TODO: this seems incorrect
+  if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
+
   switch (api) {
     case "openai":
     case "openai-text":