adds REJECT_PHRASES configuration setting

This commit is contained in:
nai-degen 2023-11-09 16:24:49 -06:00
parent 79e1fe09e4
commit e9110611fa
10 changed files with 69 additions and 47 deletions

View File

@ -43,8 +43,10 @@
# Destination to redirect blocked requests to.
# BLOCK_REDIRECT="https://roblox.com/"
# Whether to reject requests containing disallowed content.
# REJECT_DISALLOWED=false
# Comma-separated list of phrases that will be rejected. Only whole words are matched.
# Surround phrases with quotes if they contain commas.
# Avoid short or common phrases as this tests the entire prompt.
# REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
# Message to show when requests are rejected.
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."

View File

@ -95,10 +95,10 @@ type Config = {
maxOutputTokensOpenAI: number;
/** For Anthropic, the maximum number of sampled tokens a user can request. */
maxOutputTokensAnthropic: number;
/** Whether requests containing disallowed characters should be rejected. */
rejectDisallowed?: boolean;
/** Whether requests containing the following phrases should be rejected. */
rejectPhrases: string[];
/** Message to return when rejecting requests. */
rejectMessage?: string;
rejectMessage: string;
/** Verbosity level of diagnostic logging. */
logLevel: "trace" | "debug" | "info" | "warn" | "error";
/**
@ -203,7 +203,7 @@ export const config: Config = {
"bison",
"aws-claude",
]),
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
rejectMessage: getEnvWithDefault(
"REJECT_MESSAGE",
"This content violates /aicg/'s acceptable use policy."
@ -321,6 +321,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"proxyKey",
"adminKey",
"checkKeys",
"rejectPhrases",
"showTokenCosts",
"googleSheetsKey",
"firebaseKey",
@ -421,3 +422,11 @@ export function getFirebaseApp(): firebase.app.App {
}
return firebaseApp;
}
function parseCsv(val: string): string[] {
if (!val) return [];
const regex = /(".*?"|[^",]+)(?=\s*,|\s*$)/g;
const matches = val.match(regex) || [];
return matches.map(item => item.replace(/^"|"$/g, '').trim());
}

View File

@ -12,7 +12,6 @@ import {
blockZoomerOrigins,
createPreprocessorMiddleware,
finalizeBody,
languageFilter,
stripHeaders,
createOnProxyReqHandler,
} from "./middleware/request";
@ -142,7 +141,6 @@ const anthropicProxy = createQueueMiddleware({
applyQuotaLimits,
addKey,
addAnthropicPreamble,
languageFilter,
blockZoomerOrigins,
stripHeaders,
finalizeBody,

View File

@ -13,7 +13,6 @@ import {
signAwsRequest,
finalizeAwsRequest,
createOnProxyReqHandler,
languageFilter,
blockZoomerOrigins,
} from "./middleware/request";
import {
@ -134,7 +133,6 @@ const awsProxy = createQueueMiddleware({
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
languageFilter,
blockZoomerOrigins,
stripHeaders,
finalizeAwsRequest,

View File

@ -90,7 +90,7 @@ function classifyError(err: Error): {
} & Record<string, any> {
const defaultError = {
status: 500,
userMessage: `Reverse proxy encountered an unexpected error. (${err.message})`,
userMessage: `Reverse proxy error: ${err.message}`,
type: "proxy_internal_error",
stack: err.stack,
};

View File

@ -12,6 +12,7 @@ export {
export { applyQuotaLimits } from "./apply-quota-limits";
export { validateContextSize } from "./validate-context-size";
export { countPromptTokens } from "./count-prompt-tokens";
export { languageFilter } from "./language-filter";
export { setApiFormat } from "./set-api-format";
export { signAwsRequest } from "./sign-aws-request";
export { transformOutboundPayload } from "./transform-outbound-payload";
@ -22,7 +23,6 @@ export { addAnthropicPreamble } from "./add-anthropic-preamble";
export { blockZoomerOrigins } from "./block-zoomer-origins";
export { finalizeBody } from "./finalize-body";
export { finalizeAwsRequest } from "./finalize-aws-request";
export { languageFilter } from "./language-filter";
export { limitCompletions } from "./limit-completions";
export { stripHeaders } from "./strip-headers";

View File

@ -1,38 +1,49 @@
import { Request } from "express";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { assertNever } from "../../../shared/utils";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
import { RequestPreprocessor } from ".";
import { UserInputError } from "../../../shared/errors";
const DISALLOWED_REGEX =
/[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;
const rejectedClients = new Map<string, number>();
// Our shitty free-tier VMs will fall over if we test every single character in
// each 15k character request ten times a second. So we'll just sample 20% of
// the characters and hope that's enough.
const containsDisallowedCharacters = (text: string) => {
const sampleSize = Math.ceil(text.length * 0.2);
const sample = text
.split("")
.sort(() => 0.5 - Math.random())
.slice(0, sampleSize)
.join("");
return DISALLOWED_REGEX.test(sample);
};
console.log(config.rejectPhrases);
/** Block requests containing too many disallowed characters. */
export const languageFilter: ProxyRequestMiddleware = (_proxyReq, req) => {
if (!config.rejectDisallowed) {
return;
}
if (isCompletionRequest(req)) {
const combinedText = getPromptFromRequest(req);
if (containsDisallowedCharacters(combinedText)) {
logger.warn(`Blocked request containing bad characters`);
_proxyReq.destroy(new Error(config.rejectMessage));
setInterval(() => {
rejectedClients.forEach((count, ip) => {
if (count > 0) {
rejectedClients.set(ip, Math.floor(count / 2));
} else {
rejectedClients.delete(ip);
}
});
}, 30000);
/**
* Block requests containing blacklisted phrases. Repeated rejections from the
* same IP address will be throttled.
*/
export const languageFilter: RequestPreprocessor = async (req) => {
if (!config.rejectPhrases.length) return;
const prompt = getPromptFromRequest(req);
const match = config.rejectPhrases.find((phrase) =>
prompt.match(new RegExp(phrase, "i"))
);
if (match) {
const ip = req.ip;
const rejections = (rejectedClients.get(req.ip) || 0) + 1;
const delay = Math.min(60000, Math.pow(2, rejections - 1) * 1000);
rejectedClients.set(ip, rejections);
req.log.warn(
{ match, ip, rejections, delay },
"Prompt contains rejected phrase"
);
await new Promise((resolve) => {
req.res!.once("close", resolve);
setTimeout(resolve, delay);
});
throw new UserInputError(config.rejectMessage);
}
};
@ -44,8 +55,10 @@ function getPromptFromRequest(req: Request) {
return body.prompt;
case "openai":
return body.messages
.map((m: { content: string }) => m.content)
.join("\n");
.map(
(m: { content: string; role: string }) => `${m.role}: ${m.content}`
)
.join("\n\n");
case "openai-text":
return body.prompt;
case "google-palm":

View File

@ -7,6 +7,7 @@ import {
countPromptTokens,
setApiFormat,
transformOutboundPayload,
languageFilter,
} from ".";
import { ZodIssue } from "zod";
@ -38,6 +39,7 @@ export const createPreprocessorMiddleware = (
...(beforeTransform ?? []),
transformOutboundPayload,
countPromptTokens,
languageFilter,
...(afterTransform ?? []),
validateContextSize,
];
@ -81,7 +83,11 @@ async function executePreprocessors(
// stream yet as that is typically done later by the request queue. We'll
// do that here and then call classifyErrorAndSend to use the streaming
// error handler.
initializeSseStream(res);
const { stream } = req.body;
const isStreaming = stream === "true" || stream === true;
if (isStreaming && !res.headersSent) {
initializeSseStream(res);
}
classifyErrorAndSend(error as Error, req, res);
}
}

View File

@ -21,7 +21,6 @@ import {
createPreprocessorMiddleware,
finalizeBody,
forceModel,
languageFilter,
limitCompletions,
stripHeaders,
createOnProxyReqHandler,
@ -175,7 +174,6 @@ const openaiProxy = createQueueMiddleware({
pipeline: [
applyQuotaLimits,
addKey,
languageFilter,
limitCompletions,
blockZoomerOrigins,
stripHeaders,

View File

@ -15,7 +15,6 @@ import {
createPreprocessorMiddleware,
finalizeBody,
forceModel,
languageFilter,
stripHeaders,
} from "./middleware/request";
import {
@ -155,7 +154,6 @@ const googlePalmProxy = createQueueMiddleware({
pipeline: [
applyQuotaLimits,
addKey,
languageFilter,
blockZoomerOrigins,
stripHeaders,
finalizeBody,