oai-reverse-proxy/src/proxy/middleware/request/language-filter.ts

69 lines
1.8 KiB
TypeScript

import { Request } from "express";
import { config } from "../../../config";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from ".";
import { UserInputError } from "../../../shared/errors";
/**
 * Number of recent filter rejections per client IP address. `languageFilter`
 * increments an entry on every rejected request and derives its throttling
 * delay from it.
 */
const rejectedClients = new Map<string, number>();

/** How often, in milliseconds, the rejection counts are decayed. */
const REJECTION_DECAY_INTERVAL_MS = 30000;

/**
 * Halves every client's rejection count (rounding down) and forgets clients
 * whose count has decayed to zero, so the map does not grow without bound.
 */
function decayRejectedClients(): void {
  // Updating or deleting the entry currently being visited is safe while
  // iterating a Map.
  for (const [ip, count] of rejectedClients) {
    const halved = Math.floor(count / 2);
    if (halved > 0) {
      rejectedClients.set(ip, halved);
    } else {
      rejectedClients.delete(ip);
    }
  }
}

// The timer is unref'd so that this housekeeping task alone never keeps the
// Node.js process alive (e.g. during shutdown or when imported by a script).
setInterval(decayRejectedClients, REJECTION_DECAY_INTERVAL_MS).unref();
/**
 * Block requests containing blacklisted phrases. Repeated rejections from the
 * same IP address will be throttled.
 *
 * Each entry of `config.rejectPhrases` is compiled as a case-insensitive
 * regular expression and tested against the request's prompt. On a match the
 * response is held back for an exponentially growing delay (1s, 2s, 4s, ...
 * capped at 60s) based on the client's current rejection count, or until the
 * response is closed, whichever comes first.
 *
 * @throws UserInputError with `config.rejectMessage` when a phrase matches.
 */
export const languageFilter: RequestPreprocessor = async (req) => {
  if (!config.rejectPhrases.length) return;

  // Normalise the prompt to a string: a missing prompt has nothing to scan,
  // and a non-string prompt (e.g. an array) is serialised so it is still
  // checked instead of crashing on `.match`.
  const rawPrompt: unknown = getPromptFromRequest(req);
  let prompt = "";
  if (typeof rawPrompt === "string") {
    prompt = rawPrompt;
  } else if (rawPrompt != null) {
    prompt = JSON.stringify(rawPrompt) ?? "";
  }

  const match = config.rejectPhrases.find((phrase) =>
    new RegExp(phrase, "i").test(prompt)
  );
  if (!match) return;

  const ip = req.ip;
  const rejections = (rejectedClients.get(ip) || 0) + 1;
  const delay = Math.min(60000, Math.pow(2, rejections - 1) * 1000);
  rejectedClients.set(ip, rejections);
  req.log.warn(
    { match, ip, rejections, delay },
    "Prompt contains rejected phrase"
  );

  // Stall the client for `delay` ms, but stop waiting as soon as the response
  // closes. Whichever event fires first cancels the other, so neither a
  // pending timer nor a stale listener outlives the wait.
  await new Promise<void>((resolve) => {
    const onClose = () => {
      clearTimeout(timer);
      resolve();
    };
    const timer = setTimeout(() => {
      req.res?.removeListener("close", onClose);
      resolve();
    }, delay);
    req.res?.once("close", onClose);
  });

  throw new UserInputError(config.rejectMessage);
};
/**
 * Converts the `content` of one OpenAI chat message into plain text.
 *
 * String content is returned unchanged. Array content (a list of content
 * parts) is reduced to its textual parts joined by newlines; parts without a
 * string `text` field are skipped. Any other value is coerced with `String`,
 * which matches what template-literal interpolation would produce.
 */
function flattenMessageContent(content: unknown): string {
  if (typeof content === "string") return content;
  if (!Array.isArray(content)) return String(content);

  const texts: string[] = [];
  for (const part of content as unknown[]) {
    if (typeof part === "string") {
      texts.push(part);
    } else if (typeof part === "object" && part !== null) {
      const text = (part as { text?: unknown }).text;
      if (typeof text === "string") texts.push(text);
    }
  }
  return texts.join("\n");
}

/**
 * Extracts the prompt to be scanned from the request body, according to the
 * request's `outboundApi`:
 *
 * - `anthropic`, `openai-text`, `openai-image`: `body.prompt` as-is.
 * - `openai`: every message rendered as `"<role>: <content>"`, joined by
 *   blank lines.
 * - `google-palm`: `body.prompt.text`.
 *
 * Any other service value is handed to `assertNever`.
 */
function getPromptFromRequest(req: Request) {
  const service = req.outboundApi;
  const body = req.body;
  switch (service) {
    case "anthropic":
    case "openai-text":
    case "openai-image":
      return body.prompt;
    case "openai":
      return body.messages
        .map(
          (m: { content: unknown; role: string }) =>
            `${m.role}: ${flattenMessageContent(m.content)}`
        )
        .join("\n\n");
    case "google-palm":
      return body.prompt.text;
    default:
      assertNever(service);
  }
}