adds azure dall-e support

This commit is contained in:
nai-degen 2024-03-09 13:03:50 -06:00
parent cab346787c
commit cec39328a2
13 changed files with 112 additions and 76 deletions

View File

@ -40,10 +40,10 @@ NODE_ENV=production
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# By default, all models are allowed except for 'dall-e'.
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' to
# the list.
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-dall-e
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
# 'azure-dall-e' to the list of allowed model families.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# URLs from which requests will be blocked.

View File

@ -12,12 +12,12 @@ import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf";
const INFO_PAGE_TTL = 2000;
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"turbo": "GPT-3.5 Turbo",
"gpt4": "GPT-4",
turbo: "GPT-3.5 Turbo",
gpt4: "GPT-4",
"gpt4-32k": "GPT-4 32k",
"gpt4-turbo": "GPT-4 Turbo",
"dall-e": "DALL-E",
"claude": "Claude (Sonnet)",
claude: "Claude (Sonnet)",
"claude-opus": "Claude (Opus)",
"gemini-pro": "Gemini Pro",
"mistral-tiny": "Mistral 7B",
@ -29,6 +29,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"azure-gpt4": "Azure GPT-4",
"azure-gpt4-32k": "Azure GPT-4 32k",
"azure-gpt4-turbo": "Azure GPT-4 Turbo",
"azure-dall-e": "Azure DALL-E",
};
const converter = new showdown.Converter();
@ -125,7 +126,9 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
const wait = info[modelFamily]?.estimatedQueueTime;
if (hasKeys && wait) {
waits.push(`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`);
waits.push(
`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`
);
}
}
@ -163,9 +166,10 @@ function getServerTitle() {
}
function buildRecentImageSection() {
const dalleModels: ModelFamily[] = ["azure-dall-e", "dall-e"];
if (
!config.allowedModelFamilies.includes("dall-e") ||
!config.showRecentImages
!config.showRecentImages ||
dalleModels.every((f) => !config.allowedModelFamilies.includes(f))
) {
return "";
}
@ -208,7 +212,11 @@ function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
}
}
function checkIfUnlocked(req: Request, res: Response, next: express.NextFunction) {
function checkIfUnlocked(
req: Request,
res: Response,
next: express.NextFunction
) {
if (config.serviceInfoPassword?.length && !req.session?.unlocked) {
return res.redirect("/unlock-info");
}
@ -223,16 +231,13 @@ if (config.serviceInfoPassword?.length) {
);
infoPageRouter.use(withSession);
infoPageRouter.use(injectCsrfToken, checkCsrfToken);
infoPageRouter.post(
"/unlock-info",
(req, res) => {
if (req.body.password !== config.serviceInfoPassword) {
return res.status(403).send("Incorrect password");
}
req.session!.unlocked = true;
res.redirect("/");
},
);
infoPageRouter.post("/unlock-info", (req, res) => {
if (req.body.password !== config.serviceInfoPassword) {
return res.status(403).send("Incorrect password");
}
req.session!.unlocked = true;
res.redirect("/");
});
infoPageRouter.get("/unlock-info", (_req, res) => {
if (_req.session?.unlocked) return res.redirect("/");

View File

@ -124,5 +124,15 @@ azureOpenAIRouter.post(
}),
azureOpenAIProxy
);
azureOpenAIRouter.post(
"/v1/images/generations",
ipLimiter,
createPreprocessorMiddleware({
inApi: "openai-image",
outApi: "openai-image",
service: "azure",
}),
azureOpenAIProxy
);
export const azure = azureOpenAIRouter;

View File

@ -1,8 +1,15 @@
import { AzureOpenAIKey, keyPool } from "../../../../shared/key-management";
import {
APIFormat,
AzureOpenAIKey,
keyPool,
} from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
export const addAzureKey: RequestPreprocessor = (req) => {
const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
const validAPIs: APIFormat[] = ["openai", "openai-image"];
const apisValid = [req.outboundApi, req.inboundApi].every((api) =>
validAPIs.includes(api)
);
const serviceValid = req.service === "azure";
if (!apisValid || !serviceValid) {
throw new Error("addAzureKey called on invalid request");
@ -18,7 +25,7 @@ export const addAzureKey: RequestPreprocessor = (req) => {
req.key = keyPool.get(model, "azure");
req.body.model = model;
// Handles the sole Azure API deviation from the OpenAI spec (that I know of)
const notNullOrUndefined = (x: any) => x !== null && x !== undefined;
if ([req.body.logprobs, req.body.top_logprobs].some(notNullOrUndefined)) {
@ -28,7 +35,7 @@ export const addAzureKey: RequestPreprocessor = (req) => {
// req.body.logprobs = req.body.top_logprobs || undefined;
// delete req.body.top_logprobs
// }
// Temporarily just disabling logprobs for Azure because their model support
// is random: `This model does not support the 'logprobs' parameter.`
delete req.body.logprobs;
@ -43,11 +50,16 @@ export const addAzureKey: RequestPreprocessor = (req) => {
const cred = req.key as AzureOpenAIKey;
const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);
const operation =
req.outboundApi === "openai" ? "/chat/completions" : "/images/generations";
const apiVersion =
req.outboundApi === "openai" ? "2023-09-01-preview" : "2024-02-15-preview";
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: `${resourceName}.openai.azure.com`,
path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
path: `/openai/deployments/${deploymentId}${operation}?api-version=${apiVersion}`,
headers: {
["host"]: `${resourceName}.openai.azure.com`,
["content-type"]: "application/json",

View File

@ -1,11 +1,14 @@
import { ProxyResHandlerWithBody } from "./index";
import { mirrorGeneratedImage, OpenAIImageGenerationResult } from "../../../shared/file-storage/mirror-generated-image";
import {
mirrorGeneratedImage,
OpenAIImageGenerationResult,
} from "../../../shared/file-storage/mirror-generated-image";
export const saveImage: ProxyResHandlerWithBody = async (
_proxyRes,
req,
_res,
body,
body
) => {
if (req.outboundApi !== "openai-image") {
return;
@ -18,10 +21,14 @@ export const saveImage: ProxyResHandlerWithBody = async (
if (body.data) {
const baseUrl = req.protocol + "://" + req.get("host");
const prompt = body.data[0].revised_prompt ?? req.body.prompt;
await mirrorGeneratedImage(
const res = await mirrorGeneratedImage(
baseUrl,
prompt,
body as OpenAIImageGenerationResult
);
req.log.info(
{ urls: res.data.map((item) => item.url) },
"Saved generated image to user_content"
);
}
};

View File

@ -13,6 +13,7 @@ import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
import { infoPageRouter } from "./info-page";
import { IMAGE_GEN_MODELS } from "./shared/models";
import { userRouter } from "./user/routes";
import { logQueue } from "./shared/prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";
@ -111,7 +112,7 @@ async function start() {
await initTokenizers();
if (config.allowedModelFamilies.includes("dall-e")) {
if (config.allowedModelFamilies.some((f) => IMAGE_GEN_MODELS.includes(f))) {
await setupAssetsDir();
}

View File

@ -6,7 +6,6 @@ import { USER_ASSETS_DIR } from "../../config";
import { addToImageHistory } from "./image-history";
import { libSharp } from "./index";
export type OpenAIImageGenerationResult = {
created: number;
data: {

View File

@ -4,7 +4,7 @@ import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
import { getAzureOpenAIModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) =>
`https://${AZURE_HOST.replace(
@ -29,7 +29,7 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
service: "azure",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
recurringChecksEnabled: true,
updateKey,
});
}
@ -43,7 +43,6 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
const data = error.response.data;
const status = data.error.status;
const errorType = data.error.code || data.error.type;
switch (errorType) {
case "DeploymentNotFound":
@ -65,8 +64,9 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
isRevoked: true,
});
case "429":
const headers = error.response.headers;
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
{ key: key.hash, errorType, error: error.response.data, headers },
"Key is rate limited. Rechecking key in 1 minute."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
@ -79,8 +79,9 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
}, 1000 * 60);
return;
default:
const { data: errorData, status: errorStatus } = error.response;
this.log.error(
{ key: key.hash, errorType, error: error.response.data, status },
{ key: key.hash, errorType, errorData, errorStatus },
"Unknown Azure API error while checking key. Please report this."
);
return this.updateKey(key.hash, { lastChecked: Date.now() });
@ -98,7 +99,7 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
const { headers, status, data } = response ?? {};
this.log.error(
{ key: key.hash, status, headers, data, error: error.message },
{ key: key.hash, status, headers, data, error: error.stack },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
@ -115,9 +116,25 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
stream: false,
messages: [{ role: "user", content: "" }],
};
const { data } = await axios.post(url, testRequest, {
const response = await axios.post(url, testRequest, {
headers: { "Content-Type": "application/json", "api-key": apiKey },
validateStatus: (status) => status === 200 || status === 400,
});
const { data } = response;
// We allow one 400 condition, OperationNotSupported, which is returned when
// we try to invoke /chat/completions on dall-e-3. This is expected and
// indicates a DALL-E deployment.
if (response.status === 400) {
if (data.error.code === "OperationNotSupported") return "azure-dall-e";
throw new AxiosError(
`Unexpected error when testing deployment ${deploymentId}`,
"AZURE_TEST_ERROR",
response.config,
response.request,
response
);
}
const family = getAzureOpenAIModelFamily(data.model);

View File

@ -1,14 +1,14 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { HttpError } from "../../errors";
import { logger } from "../../../logger";
import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
import { HttpError } from "../../errors";
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
export type AzureOpenAIModel = OpenAIModel;
type AzureOpenAIKeyUsage = {
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
@ -75,6 +75,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
"azure-gpt4Tokens": 0,
"azure-gpt4-32kTokens": 0,
"azure-gpt4-turboTokens": 0,
"azure-dall-eTokens": 0,
};
this.keys.push(newKey);
}

View File

@ -30,10 +30,7 @@ export type MistralAIModelFamily =
| "mistral-medium"
| "mistral-large";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
OpenAIModelFamily,
"dall-e"
>}`;
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
@ -62,6 +59,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"azure-gpt4",
"azure-gpt4-32k",
"azure-gpt4-turbo",
"azure-dall-e",
] as const);
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
@ -103,6 +101,7 @@ export const MODEL_FAMILY_SERVICE: {
"azure-gpt4": "azure",
"azure-gpt4-32k": "azure",
"azure-gpt4-turbo": "azure",
"azure-dall-e": "azure",
"gemini-pro": "google-ai",
"mistral-tiny": "mistral-ai",
"mistral-small": "mistral-ai",
@ -110,6 +109,8 @@ export const MODEL_FAMILY_SERVICE: {
"mistral-large": "mistral-ai",
};
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
pino({ level: "debug" }).child({ module: "startup" });
export function getOpenAIModelFamily(

View File

@ -22,7 +22,7 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
case "turbo":
cost = 0.000001;
break;
case "dall-e":
case "azure-dall-e":
cost = 0.00001;
break;
case "aws-claude":

View File

@ -1,17 +1,15 @@
import { ZodType, z } from "zod";
import type { ModelFamily } from "../models";
import { MODEL_FAMILIES, ModelFamily } from "../models";
import { makeOptionalPropsNullable } from "../utils";
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
turbo: z.number().optional().default(0),
gpt4: z.number().optional().default(0),
"gpt4-32k": z.number().optional().default(0),
"gpt4-turbo": z.number().optional().default(0),
"dall-e": z.number().optional().default(0),
claude: z.number().optional().default(0),
"gemini-pro": z.number().optional().default(0),
"aws-claude": z.number().optional().default(0),
});
// This just dynamically creates a Zod object type with a key for each model
// family and an optional number value.
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
MODEL_FAMILIES.reduce(
(acc, family) => ({ ...acc, [family]: z.number().optional().default(0) }),
{} as Record<ModelFamily, ZodType<number>>
)
);
export const UserSchema = z
.object({
@ -66,7 +64,7 @@ export const UserPartialSchema = makeOptionalPropsNullable(UserSchema)
.extend({ token: z.string() });
export type UserTokenCounts = {
[K in ModelFamily]?: number;
[K in ModelFamily]: number | undefined;
};
export type User = z.infer<typeof UserSchema>;
export type UserUpdate = z.infer<typeof UserPartialSchema>;

View File

@ -28,25 +28,10 @@ import { assertNever } from "../utils";
const log = logger.child({ module: "users" });
const INITIAL_TOKENS: Required<UserTokenCounts> = {
turbo: 0,
gpt4: 0,
"gpt4-32k": 0,
"gpt4-turbo": 0,
"dall-e": 0,
claude: 0,
"claude-opus": 0,
"gemini-pro": 0,
"mistral-tiny": 0,
"mistral-small": 0,
"mistral-medium": 0,
"mistral-large": 0,
"aws-claude": 0,
"azure-turbo": 0,
"azure-gpt4": 0,
"azure-gpt4-turbo": 0,
"azure-gpt4-32k": 0,
};
const INITIAL_TOKENS: Required<UserTokenCounts> = MODEL_FAMILIES.reduce(
(acc, family) => ({ ...acc, [family]: 0 }),
{} as Record<ModelFamily, number>
);
const users: Map<string, User> = new Map();
const usersToFlush = new Set<string>();