diff --git a/.env.example b/.env.example
index 324a8f4..8357a18 100644
--- a/.env.example
+++ b/.env.example
@@ -43,8 +43,10 @@
 # Destination to redirect blocked requests to.
 # BLOCK_REDIRECT="https://roblox.com/"
 
-# Whether to reject requests containing disallowed content.
-# REJECT_DISALLOWED=false
+# Comma-separated list of phrases that will be rejected. Only whole words are matched.
+# Surround a phrase with double quotes if it contains commas.
+# Avoid short or common phrases as this tests the entire prompt.
+# REJECT_PHRASES='phrase one,phrase two,"phrase three, which has a comma",phrase four'
 
 # Message to show when requests are rejected.
 # REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
diff --git a/src/config.ts b/src/config.ts
index 3d4ccda..72f555b 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -95,10 +95,10 @@ type Config = {
   maxOutputTokensOpenAI: number;
   /** For Anthropic, the maximum number of sampled tokens a user can request. */
   maxOutputTokensAnthropic: number;
-  /** Whether requests containing disallowed characters should be rejected. */
-  rejectDisallowed?: boolean;
+  /** Phrases which cause a request to be rejected when found in its prompt. */
+  rejectPhrases: string[];
   /** Message to return when rejecting requests. */
-  rejectMessage?: string;
+  rejectMessage: string;
   /** Verbosity level of diagnostic logging. */
   logLevel: "trace" | "debug" | "info" | "warn" | "error";
   /**
@@ -203,7 +203,7 @@ export const config: Config = {
     "bison",
     "aws-claude",
   ]),
-  rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
+  rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
   rejectMessage: getEnvWithDefault(
     "REJECT_MESSAGE",
     "This content violates /aicg/'s acceptable use policy."
@@ -321,6 +321,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
   "proxyKey",
   "adminKey",
   "checkKeys",
+  "rejectPhrases",
   "showTokenCosts",
   "googleSheetsKey",
   "firebaseKey",
@@ -421,3 +422,17 @@ export function getFirebaseApp(): firebase.app.App {
   }
   return firebaseApp;
 }
+
+/**
+ * Splits a comma-separated string into trimmed, non-empty items. An item may
+ * be wrapped in double quotes so that it can contain literal commas.
+ */
+function parseCsv(val: string): string[] {
+  if (!val) return [];
+
+  const regex = /(".*?"|[^",]+)(?=\s*,|\s*$)/g;
+  const matches = val.match(regex) || [];
+  return matches
+    .map((item) => item.replace(/^"|"$/g, "").trim())
+    .filter((item) => item.length > 0);
+}
diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts
index 6365fda..98cc98b 100644
--- a/src/proxy/anthropic.ts
+++ b/src/proxy/anthropic.ts
@@ -12,7 +12,6 @@ import {
   blockZoomerOrigins,
   createPreprocessorMiddleware,
   finalizeBody,
-  languageFilter,
   stripHeaders,
   createOnProxyReqHandler,
 } from "./middleware/request";
@@ -142,7 +141,6 @@ const anthropicProxy = createQueueMiddleware({
       applyQuotaLimits,
       addKey,
       addAnthropicPreamble,
-      languageFilter,
       blockZoomerOrigins,
       stripHeaders,
       finalizeBody,
diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts
index e3318d6..c80aa6a 100644
--- a/src/proxy/aws.ts
+++ b/src/proxy/aws.ts
@@ -13,7 +13,6 @@ import {
   signAwsRequest,
   finalizeAwsRequest,
   createOnProxyReqHandler,
-  languageFilter,
   blockZoomerOrigins,
 } from "./middleware/request";
 import {
@@ -134,7 +133,6 @@ const awsProxy = createQueueMiddleware({
   proxyReq: createOnProxyReqHandler({
     pipeline: [
       applyQuotaLimits,
-      languageFilter,
       blockZoomerOrigins,
       stripHeaders,
       finalizeAwsRequest,
diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts
index 3afff87..99f8281 100644
--- a/src/proxy/middleware/common.ts
+++ b/src/proxy/middleware/common.ts
@@ -90,7 +90,7 @@ function classifyError(err: Error): {
 } & Record<string, any> {
   const defaultError = {
     status: 500,
-    userMessage: `Reverse proxy encountered an unexpected error. (${err.message})`,
+    userMessage: `Reverse proxy error: ${err.message}`,
     type: "proxy_internal_error",
     stack: err.stack,
   };
diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts
index 2b39ab7..565b9ac 100644
--- a/src/proxy/middleware/request/index.ts
+++ b/src/proxy/middleware/request/index.ts
@@ -12,6 +12,7 @@ export {
 export { applyQuotaLimits } from "./apply-quota-limits";
 export { validateContextSize } from "./validate-context-size";
 export { countPromptTokens } from "./count-prompt-tokens";
+export { languageFilter } from "./language-filter";
 export { setApiFormat } from "./set-api-format";
 export { signAwsRequest } from "./sign-aws-request";
 export { transformOutboundPayload } from "./transform-outbound-payload";
@@ -22,7 +23,6 @@ export { addAnthropicPreamble } from "./add-anthropic-preamble";
 export { blockZoomerOrigins } from "./block-zoomer-origins";
 export { finalizeBody } from "./finalize-body";
 export { finalizeAwsRequest } from "./finalize-aws-request";
-export { languageFilter } from "./language-filter";
 export { limitCompletions } from "./limit-completions";
 export { stripHeaders } from "./strip-headers";
 
diff --git a/src/proxy/middleware/request/language-filter.ts b/src/proxy/middleware/request/language-filter.ts
index 947985e..4699b9e 100644
--- a/src/proxy/middleware/request/language-filter.ts
+++ b/src/proxy/middleware/request/language-filter.ts
@@ -1,38 +1,82 @@
 import { Request } from "express";
 import { config } from "../../../config";
-import { logger } from "../../../logger";
 import { assertNever } from "../../../shared/utils";
-import { isCompletionRequest } from "../common";
-import { ProxyRequestMiddleware } from ".";
+import { RequestPreprocessor } from ".";
+import { UserInputError } from "../../../shared/errors";
 
-const DISALLOWED_REGEX =
-  /[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;
-
-// Our shitty free-tier VMs will fall over if we test every single character in
-// each 15k character request ten times a second. So we'll just sample 20% of
-// the characters and hope that's enough.
-const containsDisallowedCharacters = (text: string) => {
-  const sampleSize = Math.ceil(text.length * 0.2);
-  const sample = text
-    .split("")
-    .sort(() => 0.5 - Math.random())
-    .slice(0, sampleSize)
-    .join("");
-  return DISALLOWED_REGEX.test(sample);
-};
-
-/** Block requests containing too many disallowed characters. */
-export const languageFilter: ProxyRequestMiddleware = (_proxyReq, req) => {
-  if (!config.rejectDisallowed) {
-    return;
-  }
-
-  if (isCompletionRequest(req)) {
-    const combinedText = getPromptFromRequest(req);
-    if (containsDisallowedCharacters(combinedText)) {
-      logger.warn(`Blocked request containing bad characters`);
-      _proxyReq.destroy(new Error(config.rejectMessage));
-    }
-  }
+/** Longest time, in milliseconds, that a rejected request is stalled. */
+const MAX_DELAY_MS = 60000;
+/** How often, in milliseconds, the rejection counts are halved. */
+const DECAY_INTERVAL_MS = 30000;
+
+/** Recent rejection count per client IP address. */
+const rejectedClients = new Map<string, number>();
+
+/** Escapes regex metacharacters so that `text` is matched literally. */
+function escapeRegExp(text: string): string {
+  return text.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
+}
+
+/**
+ * Case-insensitive, whole-word patterns for the configured phrases, compiled
+ * once at startup. Blank phrases are skipped as they would match every prompt.
+ */
+const rejectPatterns = config.rejectPhrases
+  .map((phrase) => phrase.trim())
+  .filter((phrase) => phrase.length > 0)
+  .map((phrase) => ({
+    phrase,
+    regex: new RegExp(
+      `(?<![\\p{L}\\p{N}_])${escapeRegExp(phrase)}(?![\\p{L}\\p{N}_])`,
+      "iu"
+    ),
+  }));
+
+// Decay rejection counts so that throttled clients recover over time. The timer
+// is unref'd so that it never keeps the process alive on its own.
+setInterval(() => {
+  rejectedClients.forEach((count, ip) => {
+    const decayed = Math.floor(count / 2);
+    if (decayed > 0) {
+      rejectedClients.set(ip, decayed);
+    } else {
+      rejectedClients.delete(ip);
+    }
+  });
+}, DECAY_INTERVAL_MS).unref();
+
+/**
+ * Block requests containing blacklisted phrases. Repeated rejections from the
+ * same IP address will be throttled.
+ */
+export const languageFilter: RequestPreprocessor = async (req) => {
+  if (!rejectPatterns.length) return;
+
+  const prompt = getPromptFromRequest(req);
+  const match = rejectPatterns.find(({ regex }) => regex.test(prompt))?.phrase;
+  if (!match) return;
+
+  const ip = req.ip;
+  const rejections = (rejectedClients.get(ip) || 0) + 1;
+  const delay = Math.min(MAX_DELAY_MS, Math.pow(2, rejections - 1) * 1000);
+  rejectedClients.set(ip, rejections);
+  req.log.warn(
+    { match, ip, rejections, delay },
+    "Prompt contains rejected phrase"
+  );
+
+  // Stall the response to throttle repeat offenders. Stop waiting as soon as
+  // the client disconnects, and always release the timer and the listener.
+  await new Promise<void>((resolve) => {
+    const res = req.res!;
+    const finish = () => {
+      clearTimeout(timer);
+      res.off("close", finish);
+      resolve();
+    };
+    const timer = setTimeout(finish, delay);
+    res.once("close", finish);
+  });
+  throw new UserInputError(config.rejectMessage);
 };
 
@@ -44,8 +88,10 @@ function getPromptFromRequest(req: Request) {
       return body.prompt;
     case "openai":
       return body.messages
-        .map((m: { content: string }) => m.content)
-        .join("\n");
+        .map(
+          (m: { content: string; role: string }) => `${m.role}: ${m.content}`
+        )
+        .join("\n\n");
     case "openai-text":
       return body.prompt;
     case "google-palm":
diff --git a/src/proxy/middleware/request/preprocess.ts b/src/proxy/middleware/request/preprocess.ts
index 347bb14..be515a8 100644
--- a/src/proxy/middleware/request/preprocess.ts
+++ b/src/proxy/middleware/request/preprocess.ts
@@ -7,6 +7,7 @@ import {
   countPromptTokens,
   setApiFormat,
   transformOutboundPayload,
+  languageFilter,
 } from ".";
 import { ZodIssue } from "zod";
 
@@ -38,6 +39,7 @@ export const createPreprocessorMiddleware = (
     ...(beforeTransform ?? []),
     transformOutboundPayload,
     countPromptTokens,
+    languageFilter,
     ...(afterTransform ?? []),
     validateContextSize,
   ];
@@ -81,7 +83,11 @@ async function executePreprocessors(
     // stream yet as that is typically done later by the request queue. We'll
     // do that here and then call classifyErrorAndSend to use the streaming
     // error handler.
-    initializeSseStream(res);
+    const stream = req.body?.stream;
+    const isStreaming = stream === "true" || stream === true;
+    if (isStreaming && !res.headersSent) {
+      initializeSseStream(res);
+    }
     classifyErrorAndSend(error as Error, req, res);
   }
 }
diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts
index e3a45fa..68f78ce 100644
--- a/src/proxy/openai.ts
+++ b/src/proxy/openai.ts
@@ -21,7 +21,6 @@ import {
   createPreprocessorMiddleware,
   finalizeBody,
   forceModel,
-  languageFilter,
   limitCompletions,
   stripHeaders,
   createOnProxyReqHandler,
@@ -175,7 +174,6 @@ const openaiProxy = createQueueMiddleware({
     pipeline: [
       applyQuotaLimits,
       addKey,
-      languageFilter,
       limitCompletions,
       blockZoomerOrigins,
       stripHeaders,
diff --git a/src/proxy/palm.ts b/src/proxy/palm.ts
index 1f3887b..0c9d67b 100644
--- a/src/proxy/palm.ts
+++ b/src/proxy/palm.ts
@@ -15,7 +15,6 @@ import {
   createPreprocessorMiddleware,
   finalizeBody,
   forceModel,
-  languageFilter,
   stripHeaders,
 } from "./middleware/request";
 import {
@@ -155,7 +154,6 @@ const googlePalmProxy = createQueueMiddleware({
     pipeline: [
       applyQuotaLimits,
       addKey,
-      languageFilter,
       blockZoomerOrigins,
       stripHeaders,
       finalizeBody,