adds option to disable multimodal prompts

nai-degen 2024-03-23 14:30:08 -05:00
parent 8cb960e174
commit 34a673a80a
7 changed files with 73 additions and 12 deletions

View File

@@ -249,6 +249,14 @@ type Config = {
    * risk.
    */
   allowOpenAIToolUsage?: boolean;
+  /**
+   * Whether to allow prompts containing images, for use with multimodal models.
+   * Avoid enabling this for untrusted users, as they can submit illegal content.
+   *
+   * Applies to GPT-4 Vision and Claude Vision. Users with the `special` role
+   * are exempt from this restriction.
+   */
+  allowImagePrompts?: boolean;
   /**
    * Allows overriding the default proxy endpoint route. Defaults to /proxy.
    * A leading slash is required.
@@ -348,6 +356,7 @@ export const config: Config = {
   staticServiceInfo: getEnvWithDefault("STATIC_SERVICE_INFO", false),
   trustedProxies: getEnvWithDefault("TRUSTED_PROXIES", 1),
   allowOpenAIToolUsage: getEnvWithDefault("ALLOW_OPENAI_TOOL_USAGE", false),
+  allowImagePrompts: getEnvWithDefault("ALLOW_IMAGE_PROMPTS", false),
   proxyEndpointRoute: getEnvWithDefault("PROXY_ENDPOINT_ROUTE", "/proxy"),
 } as const;
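Image prompts stay disabled by default; deployers opt in via the `ALLOW_IMAGE_PROMPTS` environment variable. As a rough sketch of how a boolean env default like this is typically resolved (the repo's actual `getEnvWithDefault` is not part of this diff, so this implementation is an assumption):

```ts
// Hypothetical sketch of getEnvWithDefault for boolean flags; the real
// implementation lives elsewhere in the repo and may differ.
function getEnvWithDefault<T>(name: string, defaultValue: T): T {
  const raw = process.env[name];
  if (raw === undefined) return defaultValue;
  if (typeof defaultValue === "boolean") {
    // e.g. ALLOW_IMAGE_PROMPTS=true enables image prompts
    return (raw.toLowerCase() === "true") as unknown as T;
  }
  return raw as unknown as T;
}
```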

View File

@@ -11,16 +11,17 @@ export {
 // Express middleware (runs before http-proxy-middleware, can be async)
 export { addAzureKey } from "./preprocessors/add-azure-key";
 export { applyQuotaLimits } from "./preprocessors/apply-quota-limits";
-export { validateContextSize } from "./preprocessors/validate-context-size";
 export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
 export { languageFilter } from "./preprocessors/language-filter";
 export { setApiFormat } from "./preprocessors/set-api-format";
 export { signAwsRequest } from "./preprocessors/sign-aws-request";
 export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
+export { validateContextSize } from "./preprocessors/validate-context-size";
+export { validateVision } from "./preprocessors/validate-vision";
 // http-proxy-middleware callbacks (runs on onProxyReq, cannot be async)
-export { addKey, addKeyForEmbeddingsRequest } from "./onproxyreq/add-key";
 export { addAnthropicPreamble } from "./onproxyreq/add-anthropic-preamble";
+export { addKey, addKeyForEmbeddingsRequest } from "./onproxyreq/add-key";
 export { blockZoomerOrigins } from "./onproxyreq/block-zoomer-origins";
 export { checkModelFamily } from "./onproxyreq/check-model-family";
 export { finalizeBody } from "./onproxyreq/finalize-body";
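The comments in this index mark the architectural boundary the new preprocessor slots into: preprocessors run as ordinary Express middleware and may be async, while `onproxyreq` callbacks fire inside http-proxy-middleware and must stay synchronous. A sketch of what the two signatures plausibly look like (the real `RequestPreprocessor` and `HPMRequestCallback` types are defined in this directory but are not shown in the diff):

```ts
import type { Request } from "express";
import type { ClientRequest } from "http";

// Assumed shapes, inferred from how the diff uses these types; may differ
// from the actual definitions.
type RequestPreprocessor = (req: Request) => void | Promise<void>;
type HPMRequestCallback = (proxyReq: ClientRequest, req: Request) => void;
```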

View File

@@ -1,4 +1,5 @@
 import { AnthropicChatMessage } from "../../../../shared/api-schemas";
+import { containsImageContent } from "../../../../shared/api-schemas/anthropic";
 import { Key, OpenAIKey, keyPool } from "../../../../shared/key-management";
 import { isEmbeddingsRequest } from "../../common";
 import { HPMRequestCallback } from "../index";
@@ -22,7 +23,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
   let needsMultimodal = false;
   if (outboundApi === "anthropic-chat") {
-    needsMultimodal = needsMultimodalKey(
+    needsMultimodal = containsImageContent(
       body.messages as AnthropicChatMessage[]
     );
   }
@@ -122,10 +123,3 @@ export const addKeyForEmbeddingsRequest: HPMRequestCallback = (
     proxyReq.setHeader("OpenAI-Organization", key.organizationId);
   }
 };
-
-function needsMultimodalKey(messages: AnthropicChatMessage[]) {
-  return messages.some(
-    ({ content }) =>
-      typeof content !== "string" && content.some((c) => c.type === "image")
-  );
-}
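With the helper moved into the shared Anthropic schema module, `addKey` can flag anthropic-chat prompts that carry images so that key selection prefers a multimodal-capable key. For reference, the kind of payload that flips `needsMultimodal` to true follows Anthropic's Messages API shape:

```ts
// An anthropic-chat body whose content array includes an image block.
const body = {
  messages: [
    {
      role: "user",
      content: [
        { type: "text", text: "What is in this picture?" },
        {
          type: "image",
          source: { type: "base64", media_type: "image/png", data: "<base64>" },
        },
      ],
    },
  ],
};
// containsImageContent(body.messages) === true
```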

View File

@@ -4,11 +4,12 @@ import { initializeSseStream } from "../../../shared/streaming";
 import { classifyErrorAndSend } from "../common";
 import {
   RequestPreprocessor,
-  validateContextSize,
   countPromptTokens,
+  languageFilter,
   setApiFormat,
   transformOutboundPayload,
-  languageFilter,
+  validateContextSize,
+  validateVision,
 } from ".";
 
 type RequestPreprocessorOptions = {
@@ -50,6 +51,7 @@ export const createPreprocessorMiddleware = (
     languageFilter,
     ...(afterTransform ?? []),
     validateContextSize,
+    validateVision,
   ];
   return async (...args) => executePreprocessors(preprocessors, args);
 };
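`validateVision` is appended after `validateContextSize`, so it inspects the payload only after it has been transformed into the outbound API's format. `executePreprocessors` itself is not shown in this diff; a hedged sketch of what it plausibly does is to run each preprocessor in order and hand any failure to `classifyErrorAndSend`:

```ts
import type express from "express";

// Assumed behavior; the real executePreprocessors may differ.
type RequestPreprocessor = (req: express.Request) => void | Promise<void>;
declare function classifyErrorAndSend(
  err: Error,
  req: express.Request,
  res: express.Response
): void;

async function executePreprocessors(
  preprocessors: RequestPreprocessor[],
  [req, res, next]: Parameters<express.RequestHandler>
) {
  try {
    for (const preprocessor of preprocessors) {
      await preprocessor(req); // each step may mutate or reject the request
    }
    next();
  } catch (error) {
    classifyErrorAndSend(error as Error, req, res);
  }
}
```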

View File

@@ -0,0 +1,38 @@
+import { config } from "../../../../config";
+import { assertNever } from "../../../../shared/utils";
+import { RequestPreprocessor } from "../index";
+import { containsImageContent as containsImageContentOpenAI } from "../../../../shared/api-schemas/openai";
+import { containsImageContent as containsImageContentAnthropic } from "../../../../shared/api-schemas/anthropic";
+import { ForbiddenError } from "../../../../shared/errors";
+
+/**
+ * Rejects prompts containing images if multimodal prompts are disabled.
+ */
+export const validateVision: RequestPreprocessor = async (req) => {
+  if (config.allowImagePrompts) return;
+  if (req.user?.type === "special") return;
+
+  let hasImage = false;
+  switch (req.outboundApi) {
+    case "openai":
+      hasImage = containsImageContentOpenAI(req.body.messages);
+      break;
+    case "anthropic-chat":
+      hasImage = containsImageContentAnthropic(req.body.messages);
+      break;
+    case "anthropic-text":
+    case "google-ai":
+    case "mistral-ai":
+    case "openai-image":
+    case "openai-text":
+      return;
+    default:
+      assertNever(req.outboundApi);
+  }
+
+  if (hasImage) {
+    throw new ForbiddenError(
+      "Prompts containing images are not permitted. Disable 'Send Inline Images' in your client and try again."
+    );
+  }
+};
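From a client's point of view, an image-bearing prompt from a non-`special` user is rejected before the request ever reaches the upstream API, presumably as an HTTP 403 given the `ForbiddenError` name (the error-to-status mapping is not part of this diff). A minimal test-style sketch, assuming `config.allowImagePrompts` is false:

```ts
// Hypothetical Jest-style check; request object abbreviated with `as any`.
const req = {
  user: { type: "normal" },
  outboundApi: "openai",
  body: {
    messages: [
      {
        role: "user",
        content: [
          { type: "image_url", image_url: { url: "data:image/png;base64,..." } },
        ],
      },
    ],
  },
} as any;

await expect(validateVision(req)).rejects.toThrow(ForbiddenError);
```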

View File

@@ -438,3 +438,10 @@ function convertOpenAIContent(
     }
   });
 }
+
+export function containsImageContent(messages: AnthropicChatMessage[]) {
+  return messages.some(
+    ({ content }) =>
+      typeof content !== "string" && content.some((c) => c.type === "image")
+  );
+}

View File

@@ -131,3 +131,13 @@ export function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
     throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
   }
 }
+
+export function containsImageContent(
+  messages: OpenAIChatMessage[]
+): boolean {
+  return messages.some((m) =>
+    Array.isArray(m.content)
+      ? m.content.some((contentItem) => "image_url" in contentItem)
+      : false
+  );
+}
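The OpenAI variant matches the GPT-4 Vision request format, where array-form message content may include `image_url` parts; plain string content never matches:

```ts
containsImageContent([{ role: "user", content: "hi" }]); // false

containsImageContent([
  {
    role: "user",
    content: [
      { type: "text", text: "Describe this image." },
      { type: "image_url", image_url: { url: "https://example.com/cat.png" } },
    ],
  },
]); // true
```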