adds REJECT_PHRASES configuration setting
This commit is contained in:
parent
79e1fe09e4
commit
e9110611fa
|
@ -43,8 +43,10 @@
|
|||
# Destination to redirect blocked requests to.
|
||||
# BLOCK_REDIRECT="https://roblox.com/"
|
||||
|
||||
# Whether to reject requests containing disallowed content.
|
||||
# REJECT_DISALLOWED=false
|
||||
# Comma-separated list of phrases that will be rejected. Only whole words are matched.
|
||||
# Surround phrases with quotes if they contain commas.
|
||||
# Avoid short or common phrases as this tests the entire prompt.
|
||||
# REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
|
||||
# Message to show when requests are rejected.
|
||||
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
|
||||
|
||||
|
|
|
@ -95,10 +95,10 @@ type Config = {
|
|||
maxOutputTokensOpenAI: number;
|
||||
/** For Anthropic, the maximum number of sampled tokens a user can request. */
|
||||
maxOutputTokensAnthropic: number;
|
||||
/** Whether requests containing disallowed characters should be rejected. */
|
||||
rejectDisallowed?: boolean;
|
||||
/** Phrases that cause a request to be rejected when found in its prompt. */
|
||||
rejectPhrases: string[];
|
||||
/** Message to return when rejecting requests. */
|
||||
rejectMessage?: string;
|
||||
rejectMessage: string;
|
||||
/** Verbosity level of diagnostic logging. */
|
||||
logLevel: "trace" | "debug" | "info" | "warn" | "error";
|
||||
/**
|
||||
|
@ -203,7 +203,7 @@ export const config: Config = {
|
|||
"bison",
|
||||
"aws-claude",
|
||||
]),
|
||||
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
|
||||
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
|
||||
rejectMessage: getEnvWithDefault(
|
||||
"REJECT_MESSAGE",
|
||||
"This content violates /aicg/'s acceptable use policy."
|
||||
|
@ -321,6 +321,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
|
|||
"proxyKey",
|
||||
"adminKey",
|
||||
"checkKeys",
|
||||
"rejectPhrases",
|
||||
"showTokenCosts",
|
||||
"googleSheetsKey",
|
||||
"firebaseKey",
|
||||
|
@ -421,3 +422,11 @@ export function getFirebaseApp(): firebase.app.App {
|
|||
}
|
||||
return firebaseApp;
|
||||
}
|
||||
|
||||
/**
 * Splits a comma-separated configuration value into its individual entries.
 *
 * An entry may be wrapped in double quotes so that it can contain commas; the
 * surrounding quotes are removed. Every entry is trimmed, and entries that are
 * empty after unquoting/trimming (e.g. from `"a, ,b"`, a trailing `", "` or a
 * literal `""`) are dropped, since an empty phrase would match any text.
 *
 * @param val Raw comma-separated string, typically read from an environment
 * variable.
 * @returns The non-empty entries in their original order; an empty array when
 * `val` is empty.
 */
function parseCsv(val: string): string[] {
  if (!val) return [];

  // Either a double-quoted run or a run free of quotes and commas, which must
  // be followed by a comma or the end of the input (optionally after spaces).
  const regex = /(".*?"|[^",]+)(?=\s*,|\s*$)/g;
  const matches: string[] = val.match(regex) ?? [];

  return matches
    .map((item) => item.replace(/^"|"$/g, "").trim())
    .filter((item) => item.length > 0);
}
|
||||
|
|
|
@ -12,7 +12,6 @@ import {
|
|||
blockZoomerOrigins,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
languageFilter,
|
||||
stripHeaders,
|
||||
createOnProxyReqHandler,
|
||||
} from "./middleware/request";
|
||||
|
@ -142,7 +141,6 @@ const anthropicProxy = createQueueMiddleware({
|
|||
applyQuotaLimits,
|
||||
addKey,
|
||||
addAnthropicPreamble,
|
||||
languageFilter,
|
||||
blockZoomerOrigins,
|
||||
stripHeaders,
|
||||
finalizeBody,
|
||||
|
|
|
@ -13,7 +13,6 @@ import {
|
|||
signAwsRequest,
|
||||
finalizeAwsRequest,
|
||||
createOnProxyReqHandler,
|
||||
languageFilter,
|
||||
blockZoomerOrigins,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
|
@ -134,7 +133,6 @@ const awsProxy = createQueueMiddleware({
|
|||
proxyReq: createOnProxyReqHandler({
|
||||
pipeline: [
|
||||
applyQuotaLimits,
|
||||
languageFilter,
|
||||
blockZoomerOrigins,
|
||||
stripHeaders,
|
||||
finalizeAwsRequest,
|
||||
|
|
|
@ -90,7 +90,7 @@ function classifyError(err: Error): {
|
|||
} & Record<string, any> {
|
||||
const defaultError = {
|
||||
status: 500,
|
||||
userMessage: `Reverse proxy encountered an unexpected error. (${err.message})`,
|
||||
userMessage: `Reverse proxy error: ${err.message}`,
|
||||
type: "proxy_internal_error",
|
||||
stack: err.stack,
|
||||
};
|
||||
|
|
|
@ -12,6 +12,7 @@ export {
|
|||
export { applyQuotaLimits } from "./apply-quota-limits";
|
||||
export { validateContextSize } from "./validate-context-size";
|
||||
export { countPromptTokens } from "./count-prompt-tokens";
|
||||
export { languageFilter } from "./language-filter";
|
||||
export { setApiFormat } from "./set-api-format";
|
||||
export { signAwsRequest } from "./sign-aws-request";
|
||||
export { transformOutboundPayload } from "./transform-outbound-payload";
|
||||
|
@ -22,7 +23,6 @@ export { addAnthropicPreamble } from "./add-anthropic-preamble";
|
|||
export { blockZoomerOrigins } from "./block-zoomer-origins";
|
||||
export { finalizeBody } from "./finalize-body";
|
||||
export { finalizeAwsRequest } from "./finalize-aws-request";
|
||||
export { languageFilter } from "./language-filter";
|
||||
export { limitCompletions } from "./limit-completions";
|
||||
export { stripHeaders } from "./strip-headers";
|
||||
|
||||
|
|
|
@ -1,38 +1,49 @@
|
|||
import { Request } from "express";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
import { RequestPreprocessor } from ".";
|
||||
import { UserInputError } from "../../../shared/errors";
|
||||
|
||||
const DISALLOWED_REGEX =
|
||||
/[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;
|
||||
const rejectedClients = new Map<string, number>();
|
||||
|
||||
// Testing every character of each ~15k character request, ten times a second,
// is too expensive for the free-tier VMs this runs on. Instead, only a random
// 20% sample of the characters is tested, on the hope that this is enough.
/**
 * Reports whether a random sample of `text` contains a character matched by
 * DISALLOWED_REGEX.
 *
 * The sample holds 20% of the text's characters (rounded up), picked by
 * shuffling with a random comparator, so the result for a text that mixes
 * allowed and disallowed characters can differ between calls.
 */
const containsDisallowedCharacters = (text: string) => {
  const sampleSize = Math.ceil(text.length * 0.2);
  const shuffled = text.split("").sort(() => 0.5 - Math.random());
  const sample = shuffled.slice(0, sampleSize).join("");
  return DISALLOWED_REGEX.test(sample);
};
|
||||
console.log(config.rejectPhrases);
|
||||
|
||||
/** Block requests containing too many disallowed characters. */
|
||||
export const languageFilter: ProxyRequestMiddleware = (_proxyReq, req) => {
|
||||
if (!config.rejectDisallowed) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isCompletionRequest(req)) {
|
||||
const combinedText = getPromptFromRequest(req);
|
||||
if (containsDisallowedCharacters(combinedText)) {
|
||||
logger.warn(`Blocked request containing bad characters`);
|
||||
_proxyReq.destroy(new Error(config.rejectMessage));
|
||||
// Every 30 seconds, halve (rounding down) the rejection count recorded for
// each client. A client whose count has already reached zero is removed on
// the following pass.
setInterval(() => {
  for (const [ip, count] of rejectedClients) {
    if (count <= 0) {
      rejectedClients.delete(ip);
      continue;
    }
    rejectedClients.set(ip, Math.floor(count / 2));
  }
}, 30000);
|
||||
|
||||
/** Compiled whole-word patterns keyed by phrase, so each is built only once. */
const rejectPhrasePatterns = new Map<string, RegExp>();

/**
 * Returns a case-insensitive pattern that matches `phrase` literally and only
 * as a whole word/phrase, i.e. not directly preceded or followed by a letter,
 * digit or underscore.
 *
 * Regex metacharacters in the phrase are escaped, so phrases such as `c++` or
 * `(` cannot produce an invalid pattern.
 */
function getRejectPhrasePattern(phrase: string): RegExp {
  let pattern = rejectPhrasePatterns.get(phrase);
  if (!pattern) {
    const escaped = phrase.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
    // Lookarounds are used instead of `\b` so that phrases which start or end
    // with a non-word character are still bounded correctly.
    pattern = new RegExp(
      `(?<![\\p{L}\\p{N}_])${escaped}(?![\\p{L}\\p{N}_])`,
      "iu"
    );
    rejectPhrasePatterns.set(phrase, pattern);
  }
  return pattern;
}

/**
 * Block requests containing blacklisted phrases. Repeated rejections from the
 * same IP address will be throttled.
 *
 * Does nothing when no phrases are configured. Otherwise the prompt is tested
 * against each configured phrase (case-insensitive, whole words only). On the
 * first hit the client's rejection count is incremented and the rejection is
 * delayed by 1s, 2s, 4s, ... capped at 60s, or until the response closes,
 * whichever happens first.
 *
 * @throws UserInputError carrying `config.rejectMessage` when the prompt
 * contains a rejected phrase.
 */
export const languageFilter: RequestPreprocessor = async (req) => {
  if (!config.rejectPhrases.length) return;

  const prompt = getPromptFromRequest(req);
  // Empty phrases are skipped: they would match any prompt and hide the
  // phrases listed after them.
  const match = config.rejectPhrases.find(
    (phrase) => phrase.length > 0 && getRejectPhrasePattern(phrase).test(prompt)
  );
  if (match === undefined) return;

  const ip = req.ip;
  const rejections = (rejectedClients.get(ip) || 0) + 1;
  const delay = Math.min(60000, Math.pow(2, rejections - 1) * 1000);
  rejectedClients.set(ip, rejections);
  req.log.warn(
    { match, ip, rejections, delay },
    "Prompt contains rejected phrase"
  );

  // Hold the request for `delay` ms, but stop waiting as soon as the client
  // disconnects. Whichever event fires first cleans up the other so neither a
  // pending timer nor a stale listener is left behind.
  await new Promise<void>((resolve) => {
    const res = req.res!;
    const timer = setTimeout(() => {
      res.off("close", onClose);
      resolve();
    }, delay);
    function onClose() {
      clearTimeout(timer);
      resolve();
    }
    res.once("close", onClose);
  });

  throw new UserInputError(config.rejectMessage);
};
|
||||
|
||||
|
@ -44,8 +55,10 @@ function getPromptFromRequest(req: Request) {
|
|||
return body.prompt;
|
||||
case "openai":
|
||||
return body.messages
|
||||
.map((m: { content: string }) => m.content)
|
||||
.join("\n");
|
||||
.map(
|
||||
(m: { content: string; role: string }) => `${m.role}: ${m.content}`
|
||||
)
|
||||
.join("\n\n");
|
||||
case "openai-text":
|
||||
return body.prompt;
|
||||
case "google-palm":
|
||||
|
|
|
@ -7,6 +7,7 @@ import {
|
|||
countPromptTokens,
|
||||
setApiFormat,
|
||||
transformOutboundPayload,
|
||||
languageFilter,
|
||||
} from ".";
|
||||
import { ZodIssue } from "zod";
|
||||
|
||||
|
@ -38,6 +39,7 @@ export const createPreprocessorMiddleware = (
|
|||
...(beforeTransform ?? []),
|
||||
transformOutboundPayload,
|
||||
countPromptTokens,
|
||||
languageFilter,
|
||||
...(afterTransform ?? []),
|
||||
validateContextSize,
|
||||
];
|
||||
|
@ -81,7 +83,11 @@ async function executePreprocessors(
|
|||
// stream yet as that is typically done later by the request queue. We'll
|
||||
// do that here and then call classifyErrorAndSend to use the streaming
|
||||
// error handler.
|
||||
initializeSseStream(res);
|
||||
const { stream } = req.body;
|
||||
const isStreaming = stream === "true" || stream === true;
|
||||
if (isStreaming && !res.headersSent) {
|
||||
initializeSseStream(res);
|
||||
}
|
||||
classifyErrorAndSend(error as Error, req, res);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import {
|
|||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
forceModel,
|
||||
languageFilter,
|
||||
limitCompletions,
|
||||
stripHeaders,
|
||||
createOnProxyReqHandler,
|
||||
|
@ -175,7 +174,6 @@ const openaiProxy = createQueueMiddleware({
|
|||
pipeline: [
|
||||
applyQuotaLimits,
|
||||
addKey,
|
||||
languageFilter,
|
||||
limitCompletions,
|
||||
blockZoomerOrigins,
|
||||
stripHeaders,
|
||||
|
|
|
@ -15,7 +15,6 @@ import {
|
|||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
forceModel,
|
||||
languageFilter,
|
||||
stripHeaders,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
|
@ -155,7 +154,6 @@ const googlePalmProxy = createQueueMiddleware({
|
|||
pipeline: [
|
||||
applyQuotaLimits,
|
||||
addKey,
|
||||
languageFilter,
|
||||
blockZoomerOrigins,
|
||||
stripHeaders,
|
||||
finalizeBody,
|
||||
|
|
Loading…
Reference in New Issue