This commit is contained in:
khanon 2023-12-04 04:21:18 +00:00
parent cd1b9d0e0c
commit fbdea30264
31 changed files with 1237 additions and 216 deletions

View File

@ -34,10 +34,10 @@
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
# generation, uncomment the line below and add 'dall-e' to the list.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com
@ -114,6 +114,8 @@ OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# See `docs/aws-configuration.md` for more information; there may be additional steps required to set up AWS.
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
# See `docs/azure-configuration.md` for more information; there may be additional steps required to set up Azure.
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
# With proxy_key gatekeeper, the password users must provide to access the API.
# PROXY_KEY=your-secret-key

1
.gitignore vendored
View File

@ -5,3 +5,4 @@
build
greeting.md
node_modules
http-client.private.env.json

View File

@ -0,0 +1,25 @@
# Configuring the proxy for Azure
The proxy supports Azure OpenAI Service via the `/proxy/azure/openai` endpoint. The process of setting it up is slightly different from regular OpenAI.
- [Setting keys](#setting-keys)
- [Model assignment](#model-assignment)
## Setting keys
Use the `AZURE_CREDENTIALS` environment variable to set the Azure API keys.
As with other services, you can provide multiple credentials separated by commas. Each Azure credential, however, is a set of three values separated by colons (`:`): the resource name, the deployment ID, and the API key.
For example:
```
AZURE_CREDENTIALS=contoso-ml:gpt4-8k:0123456789abcdef0123456789abcdef,northwind-corp:testdeployment:0123456789abcdef0123456789abcdef
```
## Model assignment
Note that each Azure deployment is assigned a model when you create it in the Microsoft Cognitive Services portal. If you want to use a different model, you'll need to create a new deployment and add its credentials to the `AZURE_CREDENTIALS` environment variable as a new entry. Each credential grants access to only one model.
### Supported model IDs
Users can send normal OpenAI model IDs to the proxy to invoke the corresponding models; for the most part they work the same with Azure. One exception: Azure doesn't allow periods in model names, so GPT-3.5 Turbo's ID is "gpt-35-turbo", but the proxy should automatically convert "gpt-3.5-turbo" to the correct Azure ID.
As noted above, you can only use model IDs for which a deployment has been created and added to the proxy.
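For example, here is a minimal sketch of a chat completion request through the proxy. The host, proxy key, and availability of a GPT-4 deployment are assumptions for illustration:
```
// Sketch only: assumes the proxy runs locally on port 7860 with a GPT-4
// deployment configured and proxy_key gatekeeping enabled.
const response = await fetch(
  "http://localhost:7860/proxy/azure/openai/v1/chat/completions",
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer your-proxy-key", // omit if no gatekeeper is set
    },
    body: JSON.stringify({
      model: "gpt-4", // a normal OpenAI ID; routed to a matching deployment
      max_tokens: 30,
      messages: [{ role: "user", content: "Hello!" }],
    }),
  }
);
console.log(await response.json());
```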

9
http-client.env.json Normal file
View File

@ -0,0 +1,9 @@
{
"dev": {
"proxy-host": "http://localhost:7860",
"oai-key-1": "override in http-client.private.env.json",
"proxy-key": "override in http-client.private.env.json",
"azu-resource-name": "override in http-client.private.env.json",
"azu-deployment-id": "override in http-client.private.env.json"
}
}

View File

@ -0,0 +1,248 @@
# OAI Reverse Proxy
###
# @name OpenAI -- Chat Completions
POST https://api.openai.com/v1/chat/completions
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"max_tokens": 30,
"stream": false,
"messages": [
{
"role": "user",
"content": "This is a test prompt."
}
]
}
###
# @name OpenAI -- Text Completions
POST https://api.openai.com/v1/completions
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo-instruct",
"max_tokens": 30,
"stream": false,
"prompt": "This is a test prompt where"
}
###
# @name OpenAI -- Create Embedding
POST https://api.openai.com/v1/embeddings
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json
{
"model": "text-embedding-ada-002",
"input": "This is a test embedding input."
}
###
# @name OpenAI -- Get Organizations
GET https://api.openai.com/v1/organizations
Authorization: Bearer {{oai-key-1}}
###
# @name OpenAI -- Get Models
GET https://api.openai.com/v1/models
Authorization: Bearer {{oai-key-1}}
###
# @name Azure OpenAI -- Chat Completions
POST https://{{azu-resource-name}}.openai.azure.com/openai/deployments/{{azu-deployment-id}}/chat/completions?api-version=2023-09-01-preview
api-key: {{azu-key-1}}
Content-Type: application/json
{
"max_tokens": 1,
"stream": false,
"messages": [
{
"role": "user",
"content": "This is a test prompt."
}
]
}
###
# @name Proxy / OpenAI -- Get Models
GET {{proxy-host}}/proxy/openai/v1/models
Authorization: Bearer {{proxy-key}}
###
# @name Proxy / OpenAI -- Native Chat Completions
POST {{proxy-host}}/proxy/openai/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"max_tokens": 20,
"stream": true,
"temperature": 1,
"seed": 123,
"messages": [
{
"role": "user",
"content": "phrase one"
}
]
}
###
# @name Proxy / OpenAI -- Native Text Completions
POST {{proxy-host}}/proxy/openai/v1/turbo-instruct/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo-instruct",
"max_tokens": 20,
"temperature": 0,
"prompt": "Genshin Impact is a game about",
"stream": false
}
###
# @name Proxy / OpenAI -- Chat-to-Text API Translation
# Accepts a chat completion request and reformats it to work with the text completion API. `model` is ignored.
POST {{proxy-host}}/proxy/openai/turbo-instruct/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-4",
"max_tokens": 20,
"stream": true,
"messages": [
{
"role": "user",
"content": "What is the name of the fourth president of the united states?"
},
{
"role": "assistant",
"content": "That would be George Washington."
},
{
"role": "user",
"content": "I don't think that's right..."
}
]
}
###
# @name Proxy / OpenAI -- Create Embedding
POST {{proxy-host}}/proxy/openai/embeddings
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "text-embedding-ada-002",
"input": "This is a test embedding input."
}
###
# @name Proxy / Anthropic -- Native Completion (old API)
POST {{proxy-host}}/proxy/anthropic/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json
{
"model": "claude-v1.3",
"max_tokens_to_sample": 20,
"temperature": 0.2,
"stream": true,
"prompt": "What is genshin impact\n\n:Assistant:"
}
###
# @name Proxy / Anthropic -- Native Completion (2023-06-01 API)
POST {{proxy-host}}/proxy/anthropic/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-06-01
Content-Type: application/json
{
"model": "claude-v1.3",
"max_tokens_to_sample": 20,
"temperature": 0.2,
"stream": true,
"prompt": "What is genshin impact\n\n:Assistant:"
}
###
# @name Proxy / Anthropic -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/anthropic/v1/chat/completions
Authorization: Bearer {{proxy-key}}
#anthropic-version: 2023-06-01
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"max_tokens": 20,
"stream": false,
"temperature": 0,
"messages": [
{
"role": "user",
"content": "What is genshin impact"
}
]
}
###
# @name Proxy / AWS Claude -- Native Completion
POST {{proxy-host}}/proxy/aws/claude/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json
{
"model": "claude-v2",
"max_tokens_to_sample": 10,
"temperature": 0,
"stream": true,
"prompt": "What is genshin impact\n\n:Assistant:"
}
###
# @name Proxy / AWS Claude -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/aws/claude/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"max_tokens": 50,
"stream": true,
"messages": [
{
"role": "user",
"content": "What is genshin impact?"
}
]
}
###
# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation
POST {{proxy-host}}/proxy/google-palm/v1/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-4",
"max_tokens": 42,
"messages": [
{
"role": "user",
"content": "Hi what is the name of the fourth president of the united states?"
}
]
}

View File

@ -0,0 +1,40 @@
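# Ad-hoc load test: spins up $NumThreads runspaces, each POSTing one
# streaming completion to the local proxy's AWS Claude endpoint.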
$NumThreads = 10
$runspacePool = [runspacefactory]::CreateRunspacePool(1, $NumThreads)
$runspacePool.Open()
$runspaces = @()
$headers = @{
"Authorization" = "Bearer test"
"anthropic-version" = "2023-01-01"
"Content-Type" = "application/json"
}
$payload = @{
model = "claude-v2"
max_tokens_to_sample = 40
temperature = 0
stream = $true
prompt = "Test prompt, please reply with lorem ipsum`n`n:Assistant:"
} | ConvertTo-Json
for ($i = 1; $i -le $NumThreads; $i++) {
Write-Host "Starting thread $i"
$runspace = [powershell]::Create()
$runspace.AddScript({
param($i, $headers, $payload)
$response = Invoke-WebRequest -Uri "http://localhost:7860/proxy/aws/claude/v1/complete" -Method Post -Headers $headers -Body $payload
Write-Host "Response from server: $($response.StatusCode)"
}).AddArgument($i).AddArgument($headers).AddArgument($payload)
$runspace.RunspacePool = $runspacePool
$runspaces += [PSCustomObject]@{ Pipe = $runspace; Status = $runspace.BeginInvoke() }
}
$runspaces | ForEach-Object {
$_.Pipe.EndInvoke($_.Status)
$_.Pipe.Dispose()
}
$runspacePool.Close()
$runspacePool.Dispose()

View File

@ -33,6 +33,17 @@ type Config = {
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
*/
awsCredentials?: string;
/**
* Comma-delimited list of Azure OpenAI credentials. Each credential item
* should be a colon-delimited list of Azure resource name, deployment ID, and
* API key.
*
* The resource name is the subdomain in your Azure OpenAI endpoint URL,
* e.g. `https://resource-name.openai.azure.com`.
*
* @example `AZURE_CREDENTIALS=resource_name_1:deployment_id_1:api_key_1,resource_name_2:deployment_id_2:api_key_2`
*/
azureCredentials?: string;
/**
* The proxy key to require for requests. Only applicable if the user
* management mode is set to 'proxy_key', and required if so.
@ -188,6 +199,7 @@ export const config: Config = {
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
@ -219,6 +231,10 @@ export const config: Config = {
"claude",
"bison",
"aws-claude",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-turbo",
"azure-gpt4-32k",
]),
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
rejectMessage: getEnvWithDefault(
@ -352,6 +368,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"anthropicKey",
"googlePalmKey",
"awsCredentials",
"azureCredentials",
"proxyKey",
"adminKey",
"rejectPhrases",
@ -369,6 +386,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"useInsecureCookies",
"staticServiceInfo",
"checkKeys",
"allowedModelFamilies",
];
const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;
@ -417,6 +435,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
"ANTHROPIC_KEY",
"GOOGLE_PALM_KEY",
"AWS_CREDENTIALS",
"AZURE_CREDENTIALS",
].includes(String(env))
) {
return value as unknown as T;

View File

@ -1,3 +1,4 @@
/** This whole module really sucks */
import fs from "fs";
import { Request, Response } from "express";
import showdown from "showdown";
@ -5,11 +6,16 @@ import { config, listConfig } from "./config";
import {
AnthropicKey,
AwsBedrockKey,
AzureOpenAIKey,
GooglePalmKey,
keyPool,
OpenAIKey,
} from "./shared/key-management";
import { ModelFamily, OpenAIModelFamily } from "./shared/models";
import {
AzureOpenAIModelFamily,
ModelFamily,
OpenAIModelFamily,
} from "./shared/models";
import { getUniqueIps } from "./proxy/rate-limit";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { getTokenCostUsd, prettyTokens } from "./shared/stats";
@ -23,6 +29,8 @@ let infoPageLastUpdated = 0;
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
@ -48,6 +56,7 @@ type ServiceAggregates = {
anthropicKeys?: number;
palmKeys?: number;
awsKeys?: number;
azureKeys?: number;
proompts: number;
tokens: number;
tokenCost: number;
@ -62,17 +71,15 @@ const serviceStats = new Map<keyof ServiceAggregates, number>();
export const handleInfoPage = (req: Request, res: Response) => {
if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
res.send(infoPageHtml);
return;
return res.send(infoPageHtml);
}
// Sometimes huggingface doesn't send the host header and makes us guess.
const baseUrl =
process.env.SPACE_ID && !req.get("host")?.includes("hf.space")
? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
: req.protocol + "://" + req.get("host");
infoPageHtml = buildInfoPageHtml(baseUrl);
infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
infoPageLastUpdated = Date.now();
res.send(infoPageHtml);
@ -95,6 +102,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
const palmKeys = serviceStats.get("palmKeys") || 0;
const awsKeys = serviceStats.get("awsKeys") || 0;
const azureKeys = serviceStats.get("azureKeys") || 0;
const proompts = serviceStats.get("proompts") || 0;
const tokens = serviceStats.get("tokens") || 0;
const tokenCost = serviceStats.get("tokenCost") || 0;
@ -102,16 +110,15 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const allowDalle = config.allowedModelFamilies.includes("dall-e");
const endpoints = {
...(openaiKeys ? { openai: baseUrl + "/proxy/openai" } : {}),
...(openaiKeys
? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" }
: {}),
...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
...(openaiKeys && allowDalle
? { ["openai-image"]: baseUrl + "/proxy/openai-image" }
? { ["openai-image"]: baseUrl + "/openai-image" }
: {}),
...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}),
...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}),
...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
};
const stats = {
@ -120,13 +127,17 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
};
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys };
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
for (const key of Object.keys(keyInfo)) {
if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
}
const providerInfo = {
...(openaiKeys ? getOpenAIInfo() : {}),
...(anthropicKeys ? getAnthropicInfo() : {}),
...(palmKeys ? { "palm-bison": getPalmInfo() } : {}),
...(awsKeys ? { "aws-claude": getAwsInfo() } : {}),
...(palmKeys ? getPalmInfo() : {}),
...(awsKeys ? getAwsInfo() : {}),
...(azureKeys ? getAzureInfo() : {}),
};
if (hideFullInfo) {
@ -188,6 +199,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
let sumTokens = 0;
let sumCost = 0;
@ -201,17 +213,26 @@ function addKeyToAggregates(k: KeyPoolKey) {
Boolean(k.lastChecked) ? 0 : 1
);
// Technically this would not account for keys that have tokens recorded
// on models they aren't provisioned for, but that would be strange
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
});
break;
case "azure":
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
});
break;
case "anthropic": {
@ -381,11 +402,13 @@ function getPalmInfo() {
const cost = getTokenCostUsd("bison", tokens);
return {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: bisonInfo.active,
revokedKeys: bisonInfo.revoked,
proomptersInQueue: bisonInfo.queued,
estimatedQueueTime: bisonInfo.queueTime,
bison: {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: bisonInfo.active,
revokedKeys: bisonInfo.revoked,
proomptersInQueue: bisonInfo.queued,
estimatedQueueTime: bisonInfo.queueTime,
},
};
}
@ -408,15 +431,59 @@ function getAwsInfo() {
: `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
return {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: awsInfo.active,
revokedKeys: awsInfo.revoked,
proomptersInQueue: awsInfo.queued,
estimatedQueueTime: awsInfo.queueTime,
...(logged > 0 ? { privacy: logMsg } : {}),
"aws-claude": {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: awsInfo.active,
revokedKeys: awsInfo.revoked,
proomptersInQueue: awsInfo.queued,
estimatedQueueTime: awsInfo.queueTime,
...(logged > 0 ? { privacy: logMsg } : {}),
},
};
}
function getAzureInfo() {
const azureFamilies = [
"azure-turbo",
"azure-gpt4",
"azure-gpt4-turbo",
"azure-gpt4-32k",
] as const;
const azureInfo: {
[modelFamily in AzureOpenAIModelFamily]?: {
usage?: string;
activeKeys: number;
revokedKeys?: number;
proomptersInQueue?: number;
estimatedQueueTime?: string;
};
} = {};
for (const family of azureFamilies) {
const familyAllowed = config.allowedModelFamilies.includes(family);
const activeKeys = modelStats.get(`${family}__active`) || 0;
if (!familyAllowed || activeKeys === 0) continue;
azureInfo[family] = {
activeKeys,
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
};
const queue = getQueueInformation(family);
azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue;
azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime;
const tokens = modelStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString(
cost
)}`;
}
return azureInfo;
}
const customGreeting = fs.existsSync("greeting.md")
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
: "";
@ -430,10 +497,10 @@ function buildInfoPageHeader(converter: showdown.Converter, title: string) {
let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
# ${title}`;
if (config.promptLogging) {
infoBody += `\n## Prompt logging is enabled!
The server operator has enabled prompt logging. The prompts you send to this proxy and the AI responses you receive may be saved.
infoBody += `\n## Prompt Logging Enabled
This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
Logs are anonymous and do not contain IP addresses or timestamps. [You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/prompt-logging/index.ts).
[You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/shared/prompt-logging/index.ts).
**If you are uncomfortable with this, don't send prompts to this proxy!**`;
}
@ -570,8 +637,6 @@ function escapeHtml(unsafe: string) {
}
function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
// Huggingface broke their amazon elb config and no longer sends the
// x-forwarded-host header. This is a workaround.
try {
const [username, spacename] = spaceId.split("/");
return `https://${username}-${spacename.replace(/_/g, "-")}.hf.space`;

View File

@ -11,7 +11,7 @@ import {
createPreprocessorMiddleware,
stripHeaders,
signAwsRequest,
finalizeAwsRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
blockZoomerOrigins,
} from "./middleware/request";
@ -30,7 +30,11 @@ const getModelsResponse = () => {
if (!config.awsCredentials) return { object: "list", data: [] };
const variants = ["anthropic.claude-v1", "anthropic.claude-v2"];
const variants = [
"anthropic.claude-v1",
"anthropic.claude-v2",
"anthropic.claude-v2:1",
];
const models = variants.map((id) => ({
id,
@ -134,7 +138,7 @@ const awsProxy = createQueueMiddleware({
applyQuotaLimits,
blockZoomerOrigins,
stripHeaders,
finalizeAwsRequest,
finalizeSignedRequest,
],
}),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
@ -183,7 +187,7 @@ function maybeReassignModel(req: Request) {
req.body.model = "anthropic.claude-v1";
} else {
// User's client requested v2 or possibly some OpenAI model, default to v2
req.body.model = "anthropic.claude-v2";
req.body.model = "anthropic.claude-v2:1";
}
// TODO: Handle claude-instant
}

140
src/proxy/azure.ts Normal file
View File

@ -0,0 +1,140 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
ModelFamily,
AzureOpenAIModelFamily,
getAzureOpenAIModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { KNOWN_OPENAI_MODELS } from "./openai";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
applyQuotaLimits,
blockZoomerOrigins,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
limitCompletions,
stripHeaders,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { addAzureKey } from "./middleware/request/add-azure-key";
let modelsCache: any = null;
let modelsCacheTime = 0;
function getModelsResponse() {
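// Serve the cached model list if it was generated within the last minute.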
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
let available = new Set<AzureOpenAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "azure") continue;
key.modelFamilies.forEach((family) =>
available.add(family as AzureOpenAIModelFamily)
);
}
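// Advertise only model families that are both provisioned on an active key
// and permitted by the ALLOWED_MODEL_FAMILIES config.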
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));
const models = KNOWN_OPENAI_MODELS.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "azure",
permission: [
{
id: "modelperm-" + id,
object: "model_permission",
created: new Date().getTime(),
organization: "*",
group: null,
is_blocking: false,
},
],
root: id,
parent: null,
})).filter((model) => available.has(getAzureOpenAIModelFamily(model.id)));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
}
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
};
const azureOpenAIProxy = createQueueMiddleware({
beforeProxy: addAzureKey,
proxyMiddleware: createProxyMiddleware({
target: "will be set by router",
router: (req) => {
if (!req.signedRequest) throw new Error("signedRequest not set");
const { hostname, path } = req.signedRequest;
return `https://${hostname}${path}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
limitCompletions,
blockZoomerOrigins,
stripHeaders,
finalizeSignedRequest,
],
}),
proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
error: handleProxyError,
},
}),
});
const azureOpenAIRouter = Router();
azureOpenAIRouter.get("/v1/models", handleModelRequest);
azureOpenAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware({
inApi: "openai",
outApi: "openai",
service: "azure",
}),
azureOpenAIProxy
);
export const azure = azureOpenAIRouter;

View File

@ -59,7 +59,7 @@ export function writeErrorResponse(
res.write(`data: [DONE]\n\n`);
res.end();
} else {
if (req.tokenizerInfo && errorPayload.error) {
if (req.tokenizerInfo && typeof errorPayload.error === "object") {
errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
}
res.status(statusCode).json(errorPayload);

View File

@ -0,0 +1,50 @@
import { AzureOpenAIKey, keyPool } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";
export const addAzureKey: RequestPreprocessor = (req) => {
const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
const serviceValid = req.service === "azure";
if (!apisValid || !serviceValid) {
throw new Error("addAzureKey called on invalid request");
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
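// Azure model families are named with an "azure-" prefix; normalize the
// requested model so the key pool selects from the right family.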
const model = req.body.model.startsWith("azure-")
? req.body.model
: `azure-${req.body.model}`;
req.key = keyPool.get(model);
req.body.model = model;
req.log.info(
{ key: req.key.hash, model },
"Assigned Azure OpenAI key to request"
);
const cred = req.key as AzureOpenAIKey;
const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);
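// Pre-build the upstream request here; finalizeSignedRequest later copies
// this path, these headers, and this body onto the proxied request.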
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: `${resourceName}.openai.azure.com`,
path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
headers: {
["host"]: `${resourceName}.openai.azure.com`,
["content-type"]: "application/json",
["api-key"]: apiKey,
},
body: JSON.stringify(req.body),
};
};
function getCredentialsFromKey(key: AzureOpenAIKey) {
const [resourceName, deploymentId, apiKey] = key.key.split(":");
if (!resourceName || !deploymentId || !apiKey) {
throw new Error("Assigned Azure OpenAI key is not in the correct format.");
}
return { resourceName, deploymentId, apiKey };
}

View File

@ -80,6 +80,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
`?key=${assignedKey.key}`
);
break;
case "azure":
const azureKey = assignedKey.key;
proxyReq.setHeader("api-key", azureKey);
break;
case "aws":
throw new Error(
"add-key should not be used for AWS security credentials. Use sign-aws-request instead."

View File

@ -1,11 +1,11 @@
import type { ProxyRequestMiddleware } from ".";
/**
* For AWS requests, the body is signed earlier in the request pipeline, before
* the proxy middleware. This function just assigns the path and headers to the
* proxy request.
* For AWS/Azure requests, the body is signed earlier in the request pipeline,
* before the proxy middleware. This function just assigns the path and headers
* to the proxy request.
*/
export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => {
export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => {
if (!req.signedRequest) {
throw new Error("Expected req.signedRequest to be set");
}

View File

@ -22,7 +22,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
export { addAnthropicPreamble } from "./add-anthropic-preamble";
export { blockZoomerOrigins } from "./block-zoomer-origins";
export { finalizeBody } from "./finalize-body";
export { finalizeAwsRequest } from "./finalize-aws-request";
export { finalizeSignedRequest } from "./finalize-signed-request";
export { limitCompletions } from "./limit-completions";
export { stripHeaders } from "./strip-headers";

View File

@ -289,15 +289,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
switch (service) {
case "openai":
case "google-palm":
if (errorPayload.error?.code === "content_policy_violation") {
errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`;
case "azure":
const filteredCodes = ["content_policy_violation", "content_filter"];
if (filteredCodes.includes(errorPayload.error?.code)) {
errorPayload.proxy_note = `Request was filtered by the upstream API's content moderation system. Modify your prompt and try again.`;
refundLastAttempt(req);
} else if (errorPayload.error?.code === "billing_hard_limit_reached") {
// For some reason, some models return this 400 error instead of the
// same 429 billing error that other models return.
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
} else {
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
}
break;
case "anthropic":
@ -342,7 +344,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
handleAwsRateLimitError(req, errorPayload);
break;
case "google-palm":
throw new Error("Rate limit handling not implemented for PaLM");
case "azure":
errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
break;
default:
assertNever(service);
}
@ -369,6 +373,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
break;
case "azure":
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
break;
default:
assertNever(service);
}

View File

@ -28,6 +28,7 @@ type SSEMessageTransformerOptions = TransformOptions & {
export class SSEMessageTransformer extends Transform {
private lastPosition: number;
private msgCount: number;
private readonly inputFormat: APIFormat;
private readonly transformFn: StreamingCompletionTransformer;
private readonly log;
private readonly fallbackId: string;
@ -42,6 +43,7 @@ export class SSEMessageTransformer extends Transform {
options.inputFormat,
options.inputApiVersion
);
this.inputFormat = options.inputFormat;
this.fallbackId = options.requestId;
this.fallbackModel = options.requestedModel;
this.log.debug(
@ -67,6 +69,17 @@ export class SSEMessageTransformer extends Transform {
});
this.lastPosition = newPosition;
// Special case for Azure OpenAI, which is 99% the same as OpenAI but
// sometimes emits an extra event at the beginning of the stream with the
// content moderation system's response to the prompt. A lot of frontends
// don't expect this and neither does our event aggregator so we drop it.
if (this.inputFormat === "openai" && this.msgCount <= 1) {
if (originalMessage.includes("prompt_filter_results")) {
this.log.debug("Dropping Azure OpenAI content moderation SSE event");
return callback();
}
}
this.emit("originalMessage", originalMessage);
// Some events may not be transformed, e.g. ping events

View File

@ -24,7 +24,7 @@ import {
import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";
// https://platform.openai.com/docs/models/overview
const KNOWN_MODELS = [
export const KNOWN_OPENAI_MODELS = [
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
@ -46,7 +46,7 @@ const KNOWN_MODELS = [
let modelsCache: any = null;
let modelsCacheTime = 0;
export function generateModelList(models = KNOWN_MODELS) {
export function generateModelList(models = KNOWN_OPENAI_MODELS) {
let available = new Set<OpenAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "openai") continue;

View File

@ -26,6 +26,7 @@ import { assertNever } from "../shared/utils";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
import { handleProxyError } from "./middleware/common";
const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@ -34,7 +35,7 @@ const log = logger.child({ module: "request-queue" });
const AGNAI_CONCURRENCY_LIMIT = 5;
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1;
const MIN_HEARTBEAT_SIZE = 512;
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
const MAX_HEARTBEAT_SIZE =
1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
const HEARTBEAT_INTERVAL =
@ -358,12 +359,16 @@ export function createQueueMiddleware({
return (req, res, next) => {
req.proceed = async () => {
if (beforeProxy) {
// Hack to let us run asynchronous middleware before the
// http-proxy-middleware handler. This is used to sign AWS requests
// before they are proxied, as the signing is asynchronous.
// Unlike RequestPreprocessors, this runs every time the request is
// dequeued, not just the first time.
await beforeProxy(req);
try {
// Hack to let us run asynchronous middleware before the
// http-proxy-middleware handler. This is used to sign AWS requests
// before they are proxied, as the signing is asynchronous.
// Unlike RequestPreprocessors, this runs every time the request is
// dequeued, not just the first time.
await beforeProxy(req);
} catch (err) {
return handleProxyError(err, req, res);
}
}
proxyMiddleware(req, res, next);
};

View File

@ -6,6 +6,7 @@ import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { aws } from "./aws";
import { azure } from "./azure";
const proxyRouter = express.Router();
proxyRouter.use((req, _res, next) => {
@ -32,6 +33,7 @@ proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
proxyRouter.get("*", (req, res, next) => {
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");

View File

@ -26,46 +26,23 @@ type AnthropicAPIError = {
type UpdateFn = typeof AnthropicKeyProvider.prototype.update;
export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
private readonly updateKey: UpdateFn;
constructor(keys: AnthropicKey[], updateKey: UpdateFn) {
super(keys, {
service: "anthropic",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
updateKey,
});
this.updateKey = updateKey;
}
protected async checkKey(key: AnthropicKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
}
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
const updates = { isPozzed: pozzed };
this.updateKey(key.hash, updates);
this.log.info(
{ key: key.hash, models: key.modelFamilies },
"Key check complete."
);
} catch (error) {
// touch the key so we don't check it again for a while
this.updateKey(key.hash, {});
this.handleAxiosError(key, error as AxiosError);
}
this.lastCheck = Date.now();
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.scheduleNextCheck();
}
protected async testKeyOrFail(key: AnthropicKey) {
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
const updates = { isPozzed: pozzed };
this.updateKey(key.hash, updates);
this.log.info(
{ key: key.hash, models: key.modelFamilies },
"Checked key."
);
}
protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
@ -84,6 +61,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
{ key: key.hash, error: error.message },
"Key is rate limited. Rechecking in 10 seconds."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
this.updateKey(key.hash, { lastChecked: next });
break;

View File

@ -32,58 +32,36 @@ type GetLoggingConfigResponse = {
type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;
export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
private readonly updateKey: UpdateFn;
constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
super(keys, {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
updateKey,
});
this.updateKey = updateKey;
}
protected async checkKey(key: AwsBedrockKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
protected async testKeyOrFail(key: AwsBedrockKey) {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
const modelChecks: Promise<unknown>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
}
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
const modelChecks: Promise<unknown>[] = [];
if (isInitialCheck) {
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
}
await Promise.all(modelChecks);
await this.checkLoggingConfiguration(key);
await Promise.all(modelChecks);
await this.checkLoggingConfiguration(key);
this.log.info(
{
key: key.hash,
models: key.modelFamilies,
logged: key.awsLoggingStatus,
},
"Key check complete."
);
} catch (error) {
this.handleAxiosError(key, error as AxiosError);
}
this.updateKey(key.hash, {});
this.lastCheck = Date.now();
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.scheduleNextCheck();
}
this.log.info(
{
key: key.hash,
models: key.modelFamilies,
logged: key.awsLoggingStatus,
},
"Checked key."
);
}
protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {

View File

@ -0,0 +1,149 @@
import axios, { AxiosError } from "axios";
import { KeyCheckerBase } from "../key-checker-base";
import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
import { getAzureOpenAIModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
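// Hostname template for Azure endpoints; %RESOURCE_NAME% is replaced with
// each key's resource name. Can be overridden via AZURE_HOST.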
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) =>
`https://${AZURE_HOST.replace(
"%RESOURCE_NAME%",
resourceName
)}/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`;
type AzureError = {
error: {
message: string;
type: string | null;
param: string;
code: string;
status: number;
};
};
type UpdateFn = typeof AzureOpenAIKeyProvider.prototype.update;
export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
constructor(keys: AzureOpenAIKey[], updateKey: UpdateFn) {
super(keys, {
service: "azure",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
protected async testKeyOrFail(key: AzureOpenAIKey) {
const model = await this.testModel(key);
this.log.info(
{ key: key.hash, deploymentModel: model },
"Checked key."
);
this.updateKey(key.hash, { modelFamilies: [model] });
}
// provided api-key header isn't valid (401)
// {
// "error": {
// "code": "401",
// "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
// }
// }
// api key correct but deployment id is wrong (404)
// {
// "error": {
// "code": "DeploymentNotFound",
// "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
// }
// }
// resource name is wrong (node will throw ENOTFOUND)
// rate limited (429)
// TODO: try to reproduce this
protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
const data = error.response.data;
const status = data.error.status;
const errorType = data.error.code || data.error.type;
switch (errorType) {
case "DeploymentNotFound":
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is revoked or deployment ID is incorrect. Disabling key."
);
return this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
});
case "401":
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is disabled or incorrect. Disabling key."
);
return this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
});
default:
this.log.error(
{ key: key.hash, errorType, error: error.response.data, status },
"Unknown Azure API error while checking key. Please report this."
);
return this.updateKey(key.hash, { lastChecked: Date.now() });
}
}
const { response, code } = error;
if (code === "ENOTFOUND") {
this.log.warn(
{ key: key.hash, error: error.message },
"Resource name is probably incorrect. Disabling key."
);
return this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
}
const { headers, status, data } = response ?? {};
this.log.error(
{ key: key.hash, status, headers, data, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
private async testModel(key: AzureOpenAIKey) {
const { apiKey, deploymentId, resourceName } =
AzureOpenAIKeyChecker.getCredentialsFromKey(key);
const url = POST_CHAT_COMPLETIONS(resourceName, deploymentId);
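// Send a minimal one-token request; the model reported in the response
// identifies which model family this deployment serves.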
const testRequest = {
max_tokens: 1,
stream: false,
messages: [{ role: "user", content: "" }],
};
const { data } = await axios.post(url, testRequest, {
headers: { "Content-Type": "application/json", "api-key": apiKey },
});
return getAzureOpenAIModelFamily(data.model);
}
static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
const data = error.response?.data as any;
return Boolean(data?.error?.code || data?.error?.type);
}
static getCredentialsFromKey(key: AzureOpenAIKey) {
const [resourceName, deploymentId, apiKey] = key.key.split(":");
if (!resourceName || !deploymentId || !apiKey) {
throw new Error(
"Invalid Azure credential format. Refer to .env.example and ensure your credentials are in the format RESOURCE_NAME:DEPLOYMENT_ID:API_KEY with commas between each credential set."
);
}
return { resourceName, deploymentId, apiKey };
}
}

View File

@ -0,0 +1,212 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
type AzureOpenAIKeyUsage = {
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
};
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
readonly service: "azure";
readonly modelFamilies: AzureOpenAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
contentFiltering: boolean;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 4000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 250;
export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
readonly service = "azure";
private keys: AzureOpenAIKey[] = [];
private checker?: AzureOpenAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.azureCredentials;
if (!keyConfig) {
this.log.warn(
"AZURE_CREDENTIALS is not set. Azure OpenAI API will not be available."
);
return;
}
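// Deduplicate identical credential sets from the comma-separated env value.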
const bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: AzureOpenAIKey = {
key,
service: this.service,
modelFamilies: ["azure-gpt4"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
contentFiltering: false,
hash: `azu-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"azure-turboTokens": 0,
"azure-gpt4Tokens": 0,
"azure-gpt4-32kTokens": 0,
"azure-gpt4-turboTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Azure OpenAI keys.");
}
public init() {
if (config.checkKeys) {
this.checker = new AzureOpenAIKeyChecker(
this.keys,
this.update.bind(this)
);
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(model: AzureOpenAIModel) {
const neededFamily = getAzureOpenAIModelFamily(model);
const availableKeys = this.keys.filter(
(k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
);
if (availableKeys.length === 0) {
throw new Error(`No keys available for model family '${neededFamily}'.`);
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 2. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
public disable(key: AzureOpenAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<AzureOpenAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
}
// TODO: all of this shit is duplicate code
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
/**
* This is called when we receive a 429, which means the key is currently
* rate limited upstream. We don't have any information on when the in-flight
* requests will resolve, so all we can do is wait a bit and try again. The
* key is locked for RATE_LIMIT_LOCKOUT milliseconds after a 429 before
* retrying, to give the other requests a chance to finish.
*/
public markRateLimited(keyHash: string) {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {
this.keys.forEach(({ hash }) =>
this.update(hash, { lastChecked: 0, isDisabled: false })
);
}
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}

View File

@ -2,6 +2,7 @@ import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider";
import { AwsBedrockModel } from "./aws/provider";
import { AzureOpenAIModel } from "./azure/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "../models";
@ -13,12 +14,18 @@ export type APIFormat =
| "openai-text"
| "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
export type LLMService =
| "openai"
| "anthropic"
| "google-palm"
| "aws"
| "azure";
export type Model =
| OpenAIModel
| AnthropicModel
| GooglePalmModel
| AwsBedrockModel;
| AwsBedrockModel
| AzureOpenAIModel;
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
@ -72,3 +79,4 @@ export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { AwsBedrockKey } from "./aws/provider";
export { AzureOpenAIKey } from "./azure/provider";

View File

@ -3,14 +3,17 @@ import { logger } from "../../logger";
import { Key } from "./index";
import { AxiosError } from "axios";
type KeyCheckerOptions = {
type KeyCheckerOptions<TKey extends Key = Key> = {
service: string;
keyCheckPeriod: number;
minCheckInterval: number;
}
recurringChecksEnabled?: boolean;
updateKey: (hash: string, props: Partial<TKey>) => void;
};
export abstract class KeyCheckerBase<TKey extends Key> {
protected readonly service: string;
protected readonly RECURRING_CHECKS_ENABLED: boolean;
/** Minimum time in between any two key checks. */
protected readonly MIN_CHECK_INTERVAL: number;
/**
@ -19,16 +22,19 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* than this.
*/
protected readonly KEY_CHECK_PERIOD: number;
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
protected timeout?: NodeJS.Timeout;
protected lastCheck = 0;
protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
const { service, keyCheckPeriod, minCheckInterval } = opts;
this.keys = keys;
this.KEY_CHECK_PERIOD = keyCheckPeriod;
this.MIN_CHECK_INTERVAL = minCheckInterval;
this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
this.updateKey = opts.updateKey;
this.service = service;
this.log = logger.child({ module: "key-checker", service });
}
@ -52,31 +58,34 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* the minimum check interval.
*/
public scheduleNextCheck() {
// Gives each concurrent check a correlation ID to make logs less confusing.
const callId = Math.random().toString(36).slice(2, 8);
const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
const checkLog = this.log.child({ callId, timeoutId });
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
const numEnabled = enabledKeys.length;
const numUnchecked = uncheckedKeys.length;
clearTimeout(this.timeout);
this.timeout = undefined;
if (enabledKeys.length === 0) {
checkLog.warn("All keys are disabled. Key checker stopping.");
if (!numEnabled) {
checkLog.warn("All keys are disabled. Stopping.");
return;
}
// Perform startup checks for any keys that haven't been checked yet.
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
if (uncheckedKeys.length > 0) {
const keysToCheck = uncheckedKeys.slice(0, 12);
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
if (numUnchecked > 0) {
const keycheckBatch = uncheckedKeys.slice(0, 12);
this.timeout = setTimeout(async () => {
try {
await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
await Promise.all(keycheckBatch.map((key) => this.checkKey(key)));
} catch (error) {
this.log.error({ error }, "Error checking one or more keys.");
checkLog.error({ error }, "Error checking one or more keys.");
}
checkLog.info("Batch complete.");
this.scheduleNextCheck();
@ -84,11 +93,18 @@ export abstract class KeyCheckerBase<TKey extends Key> {
checkLog.info(
{
batch: keysToCheck.map((k) => k.hash),
remaining: uncheckedKeys.length - keysToCheck.length,
batch: keycheckBatch.map((k) => k.hash),
remaining: uncheckedKeys.length - keycheckBatch.length,
newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
},
"Scheduled batch check."
"Scheduled batch of initial checks."
);
return;
}
if (!this.RECURRING_CHECKS_ENABLED) {
checkLog.info(
"Initial checks complete and recurring checks are disabled for this service. Stopping."
);
return;
}
@ -106,14 +122,35 @@ export abstract class KeyCheckerBase<TKey extends Key> {
);
const delay = nextCheck - Date.now();
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
this.timeout = setTimeout(
() => this.checkKey(oldestKey).then(() => this.scheduleNextCheck()),
delay
);
checkLog.debug(
{ key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
"Scheduled single key check."
"Scheduled next recurring check."
);
}
protected abstract checkKey(key: TKey): Promise<void>;
public async checkKey(key: TKey): Promise<void> {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
}
this.log.debug({ key: key.hash }, "Checking key...");
try {
await this.testKeyOrFail(key);
} catch (error) {
this.updateKey(key.hash, {});
this.handleAxiosError(key, error as AxiosError);
}
this.lastCheck = Date.now();
}
protected abstract testKeyOrFail(key: TKey): Promise<void>;
protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
}
}
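The net effect of this hunk is a template-method refactor: `KeyCheckerBase` now owns the whole check lifecycle (skipping disabled keys, catching errors, touching `lastCheck`, rescheduling), while subclasses supply only the service-specific probe and error interpretation. A minimal sketch of a conforming subclass, with hypothetical module paths, key shape, endpoint, and service name (the real checkers are richer):

```ts
// Hypothetical imports; actual module paths and types differ in this repo.
import axios, { AxiosError } from "axios";
import { Key, KeyCheckerBase } from "./key-checker-base";

interface ExampleKey extends Key {
  secret: string; // sketch only; real keys carry more state
}
type UpdateFn = (hash: string, update: Partial<ExampleKey>) => void;

class ExampleKeyChecker extends KeyCheckerBase<ExampleKey> {
  private readonly update: UpdateFn;

  constructor(keys: ExampleKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "example", // hypothetical service id
      keyCheckPeriod: 60 * 60 * 1000, // aim to recheck each key hourly
      minCheckInterval: 10 * 1000, // but at most one check per 10 seconds
      recurringChecksEnabled: true,
      updateKey,
    });
    this.update = updateKey;
  }

  // Probe the upstream API; any rejection propagates to the base class's
  // checkKey(), which records the attempt and calls handleAxiosError.
  protected async testKeyOrFail(key: ExampleKey) {
    await axios.get("https://api.example.invalid/v1/models", {
      headers: { Authorization: `Bearer ${key.secret}` },
    });
  }

  // Interpret failures; here an HTTP 401 permanently disables the key.
  protected handleAxiosError(key: ExampleKey, error: AxiosError) {
    if (error.response?.status === 401) {
      this.update(key.hash, { isDisabled: true });
    }
  }
}
```

With this split, the scheduling loop and the disabled-key guard never need to be duplicated per service.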

View File

@ -11,6 +11,7 @@ import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models";
import { assertNever } from "../utils";
import { AzureOpenAIKeyProvider } from "./azure/provider";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@ -25,6 +26,7 @@ export class KeyPool {
this.keyProviders.push(new AnthropicKeyProvider());
this.keyProviders.push(new GooglePalmKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new AzureOpenAIKeyProvider());
}
public init() {
@ -124,6 +126,8 @@ export class KeyPool {
// AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
return "aws";
} else if (model.startsWith("azure")) {
return "azure";
}
throw new Error(`Unknown service for model '${model}'`);
}
@ -142,6 +146,11 @@ export class KeyPool {
return "google-palm";
case "aws-claude":
return "aws";
case "azure-turbo":
case "azure-gpt4":
case "azure-gpt4-32k":
case "azure-gpt4-turbo":
return "azure";
default:
assertNever(modelFamily);
}
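The `default: assertNever(modelFamily)` arm makes this switch exhaustive at compile time: once every family is handled, `modelFamily` narrows to `never`, so adding a new `ModelFamily` member without a matching `case` becomes a type error. The helper imported from `../utils` is conventionally written like this (a sketch; the project's own definition may vary):

```ts
// Accepts only `never`, so any union member left unhandled by the switch
// is a compile-time error; the throw is a runtime backstop for untyped callers.
export function assertNever(value: never): never {
  throw new Error(`Unexpected value: ${JSON.stringify(value)}`);
}
```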

View File

@ -27,65 +27,41 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
private readonly cloneKey: CloneFn;
private readonly updateKey: UpdateFn;
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
super(keys, {
service: "openai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
this.cloneKey = cloneFn;
this.updateKey = updateKey;
}
protected async checkKey(key: OpenAIKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
}
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
// We only need to check for provisioned models on the initial check.
if (isInitialCheck) {
const [provisionedModels, livenessTest] = await Promise.all([
this.getProvisionedModels(key),
this.testLiveness(key),
this.maybeCreateOrganizationClones(key),
]);
const updates = {
modelFamilies: provisionedModels,
isTrial: livenessTest.rateLimit <= 250,
};
this.updateKey(key.hash, updates);
} else {
// No updates needed as models and trial status generally don't change.
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
this.updateKey(key.hash, {});
}
this.log.info(
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
"Key check complete."
);
} catch (error) {
// touch the key so we don't check it again for a while
protected async testKeyOrFail(key: OpenAIKey) {
// We only need to check for provisioned models on the initial check.
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
const [provisionedModels, livenessTest] = await Promise.all([
this.getProvisionedModels(key),
this.testLiveness(key),
this.maybeCreateOrganizationClones(key),
]);
const updates = {
modelFamilies: provisionedModels,
isTrial: livenessTest.rateLimit <= 250,
};
this.updateKey(key.hash, updates);
} else {
// No updates needed as models and trial status generally don't change.
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
this.updateKey(key.hash, {});
this.handleAxiosError(key, error as AxiosError);
}
this.lastCheck = Date.now();
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.log.info(
{ key: key.hash },
"Recurring keychecks are disabled, no-op."
);
// this.scheduleNextCheck();
}
this.log.info(
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
"Checked key."
);
}
private async getProvisionedModels(
@ -138,6 +114,17 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
.filter(({ is_default }) => !is_default)
.map(({ id }) => id);
this.cloneKey(key.hash, ids);
// It's possible that the key checker has stopped because all non-cloned
// keys happened to be unusable, in which case this clone would never be
// checked unless we restart the key checker.
if (!this.timeout) {
this.log.warn(
{ parent: key.hash },
"Restarting key checker to check cloned keys."
);
this.scheduleNextCheck();
}
}
protected handleAxiosError(key: OpenAIKey, error: AxiosError) {

View File

@ -217,17 +217,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
return a.lastUsed - b.lastUsed;
});
// logger.debug(
// {
// byPriority: keysByPriority.map((k) => ({
// hash: k.hash,
// isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
// modelFamilies: k.modelFamilies,
// })),
// },
// "Keys sorted by priority"
// );
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
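Only the tail of the selection routine is visible in this hunk, but the comparator (`a.lastUsed - b.lastUsed`) and the removed debug block imply the scheme: prefer keys outside their rate-limit window, then rotate by least-recently-used. A standalone sketch of that pattern, with hypothetical field names and threshold:

```ts
interface PooledKey {
  hash: string;
  lastUsed: number; // epoch ms of the last request served by this key
  rateLimitedAt: number; // epoch ms when the key last hit a rate limit
}

// Prefer keys that are past their rate-limit window; break ties by
// least-recently-used so load rotates evenly across the pool.
function selectKey(keys: PooledKey[], now = Date.now()): PooledKey {
  const rateLimitThreshold = 60 * 1000; // hypothetical cool-down window
  const byPriority = [...keys].sort((a, b) => {
    const aLimited = now - a.rateLimitedAt < rateLimitThreshold;
    const bLimited = now - b.rateLimitedAt < rateLimitThreshold;
    if (aLimited !== bLimited) return aLimited ? 1 : -1;
    return a.lastUsed - b.lastUsed;
  });
  const selected = byPriority[0];
  selected.lastUsed = now;
  return selected;
}
```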

View File

@ -2,15 +2,25 @@
import pino from "pino";
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
export type OpenAIModelFamily =
| "turbo"
| "gpt4"
| "gpt4-32k"
| "gpt4-turbo"
| "dall-e";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
OpenAIModelFamily,
"dall-e"
>}`;
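// The template literal type above is equivalent to the union
//   "azure-turbo" | "azure-gpt4" | "azure-gpt4-32k" | "azure-gpt4-turbo"
// Exclude drops "dall-e" from OpenAIModelFamily, and the `azure-` prefix
// is applied to each remaining member.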
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
| GooglePalmModelFamily
| AwsBedrockModelFamily;
| AwsBedrockModelFamily
| AzureOpenAIModelFamily;
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@ -23,6 +33,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"claude",
"bison",
"aws-claude",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
"azure-gpt4-turbo",
] as const);
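// Note: the generic constraint on MODEL_FAMILIES makes the array exhaustive
// at compile time. If any ModelFamily member were missing from the literal,
// [ModelFamily] would no longer extend [A[number]], the intersection would
// collapse to `never`, and the call would fail to type-check, so adding a
// family to the union forces an update here as well.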
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
@ -64,6 +78,24 @@ export function getAwsBedrockModelFamily(_model: string): ModelFamily {
return "aws-claude";
}
export function getAzureOpenAIModelFamily(
model: string,
defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
): AzureOpenAIModelFamily {
// Azure model names omit periods. addAzureKey also prepends "azure-" to the
// model name to route the request to the correct key provider, so we need
// to remove that prefix as well.
const modified = model
.replace("gpt-35-turbo", "gpt-3.5-turbo")
.replace("azure-", "");
for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
if (modified.match(regex)) {
return `azure-${family}` as AzureOpenAIModelFamily;
}
}
return defaultFamily;
}
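// Hypothetical examples, assuming OPENAI_MODEL_FAMILY_MAP pairs patterns
// like gpt-3\.5-turbo with "turbo" and gpt-4-32k with "gpt4-32k":
//   getAzureOpenAIModelFamily("azure-gpt-35-turbo") // -> "azure-turbo"
//   getAzureOpenAIModelFamily("azure-gpt-4-32k")    // -> "azure-gpt4-32k"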
export function assertIsKnownModelFamily(
modelFamily: string
): asserts modelFamily is ModelFamily {

View File

@ -12,6 +12,7 @@ import schedule from "node-schedule";
import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import {
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGooglePalmModelFamily,
getOpenAIModelFamily,
@ -34,6 +35,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
claude: 0,
bison: 0,
"aws-claude": 0,
"azure-turbo": 0,
"azure-gpt4": 0,
"azure-gpt4-turbo": 0,
"azure-gpt4-32k": 0,
};
const users: Map<string, User> = new Map();
@ -382,6 +387,9 @@ function getModelFamilyForQuotaUsage(
model: string,
api: APIFormat
): ModelFamily {
// TODO: this seems incorrect; keying off a substring of the model name
// bypasses the API-format switch below.
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
switch (api) {
case "openai":
case "openai-text":