adds azure dall-e support
This commit is contained in:
parent
cab346787c
commit
cec39328a2
|
@ -40,10 +40,10 @@ NODE_ENV=production
|
|||
|
||||
# Which model types users are allowed to access.
|
||||
# The following model families are recognized:
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
|
||||
# By default, all models are allowed except for 'dall-e'.
|
||||
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' to
|
||||
# the list.
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-dall-e
|
||||
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
|
||||
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
|
||||
# 'azure-dall-e' to the list of allowed model families.
|
||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
|
||||
|
||||
# URLs from which requests will be blocked.
|
||||
|
|
|
@ -12,12 +12,12 @@ import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf";
|
|||
|
||||
const INFO_PAGE_TTL = 2000;
|
||||
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
|
||||
"turbo": "GPT-3.5 Turbo",
|
||||
"gpt4": "GPT-4",
|
||||
turbo: "GPT-3.5 Turbo",
|
||||
gpt4: "GPT-4",
|
||||
"gpt4-32k": "GPT-4 32k",
|
||||
"gpt4-turbo": "GPT-4 Turbo",
|
||||
"dall-e": "DALL-E",
|
||||
"claude": "Claude (Sonnet)",
|
||||
claude: "Claude (Sonnet)",
|
||||
"claude-opus": "Claude (Opus)",
|
||||
"gemini-pro": "Gemini Pro",
|
||||
"mistral-tiny": "Mistral 7B",
|
||||
|
@ -29,6 +29,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
|
|||
"azure-gpt4": "Azure GPT-4",
|
||||
"azure-gpt4-32k": "Azure GPT-4 32k",
|
||||
"azure-gpt4-turbo": "Azure GPT-4 Turbo",
|
||||
"azure-dall-e": "Azure DALL-E",
|
||||
};
|
||||
|
||||
const converter = new showdown.Converter();
|
||||
|
@ -125,7 +126,9 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
|
|||
|
||||
const wait = info[modelFamily]?.estimatedQueueTime;
|
||||
if (hasKeys && wait) {
|
||||
waits.push(`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`);
|
||||
waits.push(
|
||||
`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,9 +166,10 @@ function getServerTitle() {
|
|||
}
|
||||
|
||||
function buildRecentImageSection() {
|
||||
const dalleModels: ModelFamily[] = ["azure-dall-e", "dall-e"];
|
||||
if (
|
||||
!config.allowedModelFamilies.includes("dall-e") ||
|
||||
!config.showRecentImages
|
||||
!config.showRecentImages ||
|
||||
dalleModels.every((f) => !config.allowedModelFamilies.includes(f))
|
||||
) {
|
||||
return "";
|
||||
}
|
||||
|
@ -208,7 +212,11 @@ function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
|
|||
}
|
||||
}
|
||||
|
||||
function checkIfUnlocked(req: Request, res: Response, next: express.NextFunction) {
|
||||
function checkIfUnlocked(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: express.NextFunction
|
||||
) {
|
||||
if (config.serviceInfoPassword?.length && !req.session?.unlocked) {
|
||||
return res.redirect("/unlock-info");
|
||||
}
|
||||
|
@ -223,16 +231,13 @@ if (config.serviceInfoPassword?.length) {
|
|||
);
|
||||
infoPageRouter.use(withSession);
|
||||
infoPageRouter.use(injectCsrfToken, checkCsrfToken);
|
||||
infoPageRouter.post(
|
||||
"/unlock-info",
|
||||
(req, res) => {
|
||||
if (req.body.password !== config.serviceInfoPassword) {
|
||||
return res.status(403).send("Incorrect password");
|
||||
}
|
||||
req.session!.unlocked = true;
|
||||
res.redirect("/");
|
||||
},
|
||||
);
|
||||
infoPageRouter.post("/unlock-info", (req, res) => {
|
||||
if (req.body.password !== config.serviceInfoPassword) {
|
||||
return res.status(403).send("Incorrect password");
|
||||
}
|
||||
req.session!.unlocked = true;
|
||||
res.redirect("/");
|
||||
});
|
||||
infoPageRouter.get("/unlock-info", (_req, res) => {
|
||||
if (_req.session?.unlocked) return res.redirect("/");
|
||||
|
||||
|
|
|
@ -124,5 +124,15 @@ azureOpenAIRouter.post(
|
|||
}),
|
||||
azureOpenAIProxy
|
||||
);
|
||||
azureOpenAIRouter.post(
|
||||
"/v1/images/generations",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "openai-image",
|
||||
outApi: "openai-image",
|
||||
service: "azure",
|
||||
}),
|
||||
azureOpenAIProxy
|
||||
);
|
||||
|
||||
export const azure = azureOpenAIRouter;
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
import { AzureOpenAIKey, keyPool } from "../../../../shared/key-management";
|
||||
import {
|
||||
APIFormat,
|
||||
AzureOpenAIKey,
|
||||
keyPool,
|
||||
} from "../../../../shared/key-management";
|
||||
import { RequestPreprocessor } from "../index";
|
||||
|
||||
export const addAzureKey: RequestPreprocessor = (req) => {
|
||||
const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
|
||||
const validAPIs: APIFormat[] = ["openai", "openai-image"];
|
||||
const apisValid = [req.outboundApi, req.inboundApi].every((api) =>
|
||||
validAPIs.includes(api)
|
||||
);
|
||||
const serviceValid = req.service === "azure";
|
||||
if (!apisValid || !serviceValid) {
|
||||
throw new Error("addAzureKey called on invalid request");
|
||||
|
@ -18,7 +25,7 @@ export const addAzureKey: RequestPreprocessor = (req) => {
|
|||
|
||||
req.key = keyPool.get(model, "azure");
|
||||
req.body.model = model;
|
||||
|
||||
|
||||
// Handles the sole Azure API deviation from the OpenAI spec (that I know of)
|
||||
const notNullOrUndefined = (x: any) => x !== null && x !== undefined;
|
||||
if ([req.body.logprobs, req.body.top_logprobs].some(notNullOrUndefined)) {
|
||||
|
@ -28,7 +35,7 @@ export const addAzureKey: RequestPreprocessor = (req) => {
|
|||
// req.body.logprobs = req.body.top_logprobs || undefined;
|
||||
// delete req.body.top_logprobs
|
||||
// }
|
||||
|
||||
|
||||
// Temporarily just disabling logprobs for Azure because their model support
|
||||
// is random: `This model does not support the 'logprobs' parameter.`
|
||||
delete req.body.logprobs;
|
||||
|
@ -43,11 +50,16 @@ export const addAzureKey: RequestPreprocessor = (req) => {
|
|||
const cred = req.key as AzureOpenAIKey;
|
||||
const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);
|
||||
|
||||
const operation =
|
||||
req.outboundApi === "openai" ? "/chat/completions" : "/images/generations";
|
||||
const apiVersion =
|
||||
req.outboundApi === "openai" ? "2023-09-01-preview" : "2024-02-15-preview";
|
||||
|
||||
req.signedRequest = {
|
||||
method: "POST",
|
||||
protocol: "https:",
|
||||
hostname: `${resourceName}.openai.azure.com`,
|
||||
path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
|
||||
path: `/openai/deployments/${deploymentId}${operation}?api-version=${apiVersion}`,
|
||||
headers: {
|
||||
["host"]: `${resourceName}.openai.azure.com`,
|
||||
["content-type"]: "application/json",
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
import { ProxyResHandlerWithBody } from "./index";
|
||||
import { mirrorGeneratedImage, OpenAIImageGenerationResult } from "../../../shared/file-storage/mirror-generated-image";
|
||||
import {
|
||||
mirrorGeneratedImage,
|
||||
OpenAIImageGenerationResult,
|
||||
} from "../../../shared/file-storage/mirror-generated-image";
|
||||
|
||||
export const saveImage: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
_res,
|
||||
body,
|
||||
body
|
||||
) => {
|
||||
if (req.outboundApi !== "openai-image") {
|
||||
return;
|
||||
|
@ -18,10 +21,14 @@ export const saveImage: ProxyResHandlerWithBody = async (
|
|||
if (body.data) {
|
||||
const baseUrl = req.protocol + "://" + req.get("host");
|
||||
const prompt = body.data[0].revised_prompt ?? req.body.prompt;
|
||||
await mirrorGeneratedImage(
|
||||
const res = await mirrorGeneratedImage(
|
||||
baseUrl,
|
||||
prompt,
|
||||
body as OpenAIImageGenerationResult
|
||||
);
|
||||
req.log.info(
|
||||
{ urls: res.data.map((item) => item.url) },
|
||||
"Saved generated image to user_content"
|
||||
);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -13,6 +13,7 @@ import { keyPool } from "./shared/key-management";
|
|||
import { adminRouter } from "./admin/routes";
|
||||
import { proxyRouter } from "./proxy/routes";
|
||||
import { infoPageRouter } from "./info-page";
|
||||
import { IMAGE_GEN_MODELS } from "./shared/models";
|
||||
import { userRouter } from "./user/routes";
|
||||
import { logQueue } from "./shared/prompt-logging";
|
||||
import { start as startRequestQueue } from "./proxy/queue";
|
||||
|
@ -111,7 +112,7 @@ async function start() {
|
|||
|
||||
await initTokenizers();
|
||||
|
||||
if (config.allowedModelFamilies.includes("dall-e")) {
|
||||
if (config.allowedModelFamilies.some((f) => IMAGE_GEN_MODELS.includes(f))) {
|
||||
await setupAssetsDir();
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import { USER_ASSETS_DIR } from "../../config";
|
|||
import { addToImageHistory } from "./image-history";
|
||||
import { libSharp } from "./index";
|
||||
|
||||
|
||||
export type OpenAIImageGenerationResult = {
|
||||
created: number;
|
||||
data: {
|
||||
|
|
|
@ -4,7 +4,7 @@ import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
|
|||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
|
||||
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
|
||||
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
|
||||
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) =>
|
||||
`https://${AZURE_HOST.replace(
|
||||
|
@ -29,7 +29,7 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
|
|||
service: "azure",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
recurringChecksEnabled: false,
|
||||
recurringChecksEnabled: true,
|
||||
updateKey,
|
||||
});
|
||||
}
|
||||
|
@ -43,7 +43,6 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
|
|||
protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
|
||||
if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
|
||||
const data = error.response.data;
|
||||
const status = data.error.status;
|
||||
const errorType = data.error.code || data.error.type;
|
||||
switch (errorType) {
|
||||
case "DeploymentNotFound":
|
||||
|
@ -65,8 +64,9 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
|
|||
isRevoked: true,
|
||||
});
|
||||
case "429":
|
||||
const headers = error.response.headers;
|
||||
this.log.warn(
|
||||
{ key: key.hash, errorType, error: error.response.data },
|
||||
{ key: key.hash, errorType, error: error.response.data, headers },
|
||||
"Key is rate limited. Rechecking key in 1 minute."
|
||||
);
|
||||
this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||
|
@ -79,8 +79,9 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
|
|||
}, 1000 * 60);
|
||||
return;
|
||||
default:
|
||||
const { data: errorData, status: errorStatus } = error.response;
|
||||
this.log.error(
|
||||
{ key: key.hash, errorType, error: error.response.data, status },
|
||||
{ key: key.hash, errorType, errorData, errorStatus },
|
||||
"Unknown Azure API error while checking key. Please report this."
|
||||
);
|
||||
return this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||
|
@ -98,7 +99,7 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
|
|||
|
||||
const { headers, status, data } = response ?? {};
|
||||
this.log.error(
|
||||
{ key: key.hash, status, headers, data, error: error.message },
|
||||
{ key: key.hash, status, headers, data, error: error.stack },
|
||||
"Network error while checking key; trying this key again in a minute."
|
||||
);
|
||||
const oneMinute = 60 * 1000;
|
||||
|
@ -115,9 +116,25 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
|
|||
stream: false,
|
||||
messages: [{ role: "user", content: "" }],
|
||||
};
|
||||
const { data } = await axios.post(url, testRequest, {
|
||||
const response = await axios.post(url, testRequest, {
|
||||
headers: { "Content-Type": "application/json", "api-key": apiKey },
|
||||
validateStatus: (status) => status === 200 || status === 400,
|
||||
});
|
||||
const { data } = response;
|
||||
|
||||
// We allow one 400 condition, OperationNotSupported, which is returned when
|
||||
// we try to invoke /chat/completions on dall-e-3. This is expected and
|
||||
// indicates a DALL-E deployment.
|
||||
if (response.status === 400) {
|
||||
if (data.error.code === "OperationNotSupported") return "azure-dall-e";
|
||||
throw new AxiosError(
|
||||
`Unexpected error when testing deployment ${deploymentId}`,
|
||||
"AZURE_TEST_ERROR",
|
||||
response.config,
|
||||
response.request,
|
||||
response
|
||||
);
|
||||
}
|
||||
|
||||
const family = getAzureOpenAIModelFamily(data.model);
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { HttpError } from "../../errors";
|
||||
import { logger } from "../../../logger";
|
||||
import type { AzureOpenAIModelFamily } from "../../models";
|
||||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
import { OpenAIModel } from "../openai/provider";
|
||||
import { AzureOpenAIKeyChecker } from "./checker";
|
||||
import { HttpError } from "../../errors";
|
||||
|
||||
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
|
||||
export type AzureOpenAIModel = OpenAIModel;
|
||||
|
||||
type AzureOpenAIKeyUsage = {
|
||||
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
|
||||
|
@ -75,6 +75,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
|
|||
"azure-gpt4Tokens": 0,
|
||||
"azure-gpt4-32kTokens": 0,
|
||||
"azure-gpt4-turboTokens": 0,
|
||||
"azure-dall-eTokens": 0,
|
||||
};
|
||||
this.keys.push(newKey);
|
||||
}
|
||||
|
|
|
@ -30,10 +30,7 @@ export type MistralAIModelFamily =
|
|||
| "mistral-medium"
|
||||
| "mistral-large";
|
||||
export type AwsBedrockModelFamily = "aws-claude";
|
||||
export type AzureOpenAIModelFamily = `azure-${Exclude<
|
||||
OpenAIModelFamily,
|
||||
"dall-e"
|
||||
>}`;
|
||||
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
|
||||
export type ModelFamily =
|
||||
| OpenAIModelFamily
|
||||
| AnthropicModelFamily
|
||||
|
@ -62,6 +59,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
|||
"azure-gpt4",
|
||||
"azure-gpt4-32k",
|
||||
"azure-gpt4-turbo",
|
||||
"azure-dall-e",
|
||||
] as const);
|
||||
|
||||
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
||||
|
@ -103,6 +101,7 @@ export const MODEL_FAMILY_SERVICE: {
|
|||
"azure-gpt4": "azure",
|
||||
"azure-gpt4-32k": "azure",
|
||||
"azure-gpt4-turbo": "azure",
|
||||
"azure-dall-e": "azure",
|
||||
"gemini-pro": "google-ai",
|
||||
"mistral-tiny": "mistral-ai",
|
||||
"mistral-small": "mistral-ai",
|
||||
|
@ -110,6 +109,8 @@ export const MODEL_FAMILY_SERVICE: {
|
|||
"mistral-large": "mistral-ai",
|
||||
};
|
||||
|
||||
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
|
||||
|
||||
pino({ level: "debug" }).child({ module: "startup" });
|
||||
|
||||
export function getOpenAIModelFamily(
|
||||
|
|
|
@ -22,7 +22,7 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
|
|||
case "turbo":
|
||||
cost = 0.000001;
|
||||
break;
|
||||
case "dall-e":
|
||||
case "azure-dall-e":
|
||||
cost = 0.00001;
|
||||
break;
|
||||
case "aws-claude":
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
import { ZodType, z } from "zod";
|
||||
import type { ModelFamily } from "../models";
|
||||
import { MODEL_FAMILIES, ModelFamily } from "../models";
|
||||
import { makeOptionalPropsNullable } from "../utils";
|
||||
|
||||
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
|
||||
turbo: z.number().optional().default(0),
|
||||
gpt4: z.number().optional().default(0),
|
||||
"gpt4-32k": z.number().optional().default(0),
|
||||
"gpt4-turbo": z.number().optional().default(0),
|
||||
"dall-e": z.number().optional().default(0),
|
||||
claude: z.number().optional().default(0),
|
||||
"gemini-pro": z.number().optional().default(0),
|
||||
"aws-claude": z.number().optional().default(0),
|
||||
});
|
||||
// This just dynamically creates a Zod object type with a key for each model
|
||||
// family and an optional number value.
|
||||
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
|
||||
MODEL_FAMILIES.reduce(
|
||||
(acc, family) => ({ ...acc, [family]: z.number().optional().default(0) }),
|
||||
{} as Record<ModelFamily, ZodType<number>>
|
||||
)
|
||||
);
|
||||
|
||||
export const UserSchema = z
|
||||
.object({
|
||||
|
@ -66,7 +64,7 @@ export const UserPartialSchema = makeOptionalPropsNullable(UserSchema)
|
|||
.extend({ token: z.string() });
|
||||
|
||||
export type UserTokenCounts = {
|
||||
[K in ModelFamily]?: number;
|
||||
[K in ModelFamily]: number | undefined;
|
||||
};
|
||||
export type User = z.infer<typeof UserSchema>;
|
||||
export type UserUpdate = z.infer<typeof UserPartialSchema>;
|
||||
|
|
|
@ -28,25 +28,10 @@ import { assertNever } from "../utils";
|
|||
|
||||
const log = logger.child({ module: "users" });
|
||||
|
||||
const INITIAL_TOKENS: Required<UserTokenCounts> = {
|
||||
turbo: 0,
|
||||
gpt4: 0,
|
||||
"gpt4-32k": 0,
|
||||
"gpt4-turbo": 0,
|
||||
"dall-e": 0,
|
||||
claude: 0,
|
||||
"claude-opus": 0,
|
||||
"gemini-pro": 0,
|
||||
"mistral-tiny": 0,
|
||||
"mistral-small": 0,
|
||||
"mistral-medium": 0,
|
||||
"mistral-large": 0,
|
||||
"aws-claude": 0,
|
||||
"azure-turbo": 0,
|
||||
"azure-gpt4": 0,
|
||||
"azure-gpt4-turbo": 0,
|
||||
"azure-gpt4-32k": 0,
|
||||
};
|
||||
const INITIAL_TOKENS: Required<UserTokenCounts> = MODEL_FAMILIES.reduce(
|
||||
(acc, family) => ({ ...acc, [family]: 0 }),
|
||||
{} as Record<ModelFamily, number>
|
||||
);
|
||||
|
||||
const users: Map<string, User> = new Map();
|
||||
const usersToFlush = new Set<string>();
|
||||
|
|
Loading…
Reference in New Issue