Prefer user tokens as rate-limit/queue keys when available (khanon/oai-reverse-proxy!10)

This commit is contained in:
nai-degen 2023-05-19 04:33:20 +00:00
parent dfd8f0cc97
commit 2bad644772
4 changed files with 42 additions and 18 deletions

View File

@ -81,7 +81,7 @@ function cacheInfoPageHtml(host: string) {
...getQueueInformation(), ...getQueueInformation(),
keys: keyInfo, keys: keyInfo,
config: listConfig(), config: listConfig(),
build: process.env.COMMIT_SHA || "dev", build: process.env.BUILD_INFO || "dev",
}; };
const title = getServerTitle(); const title = getServerTitle();

View File

@ -31,16 +31,33 @@ const AGNAI_CONCURRENCY_LIMIT = 15;
/** Maximum number of queue slots for individual users. */ /** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1; const USER_CONCURRENCY_LIMIT = 1;
/**
 * Builds a matcher for the given incoming request. The returned function
 * reports whether a queued request has the same `ip` as the incoming one.
 */
const sameIpPredicate = (incoming: Request) => {
  return (queued: Request) => {
    const isSameIp = queued.ip === incoming.ip;
    return isSameIp;
  };
};
/**
 * Builds a matcher for the given incoming request. The returned function
 * reports whether a queued request carries the same identity as the incoming
 * one. A request's identity is its `user.token`; when the request has no
 * `user`, its `ip` is used in place of the token.
 */
const sameUserPredicate = (incoming: Request) => {
  return (queued: Request) => {
    const incomingKey =
      incoming.user != null ? incoming.user.token : incoming.ip;
    const queuedKey = queued.user != null ? queued.user.token : queued.ip;
    return queuedKey === incomingKey;
  };
};
export function enqueue(req: Request) { export function enqueue(req: Request) {
// All agnai.chat requests come from the same IP, so we allow them to have let enqueuedRequestCount = 0;
let isGuest = req.user?.token === undefined;
if (isGuest) {
enqueuedRequestCount = queue.filter(sameIpPredicate(req)).length;
} else {
enqueuedRequestCount = queue.filter(sameUserPredicate(req)).length;
}
// All Agnai.chat requests come from the same IP, so we allow them to have
// more spots in the queue. Can't make it unlimited because people will // more spots in the queue. Can't make it unlimited because people will
// intentionally abuse it. // intentionally abuse it.
// Authenticated users always get a single spot in the queue.
const maxConcurrentQueuedRequests = const maxConcurrentQueuedRequests =
req.ip === AGNAI_DOT_CHAT_IP isGuest && req.ip === AGNAI_DOT_CHAT_IP
? AGNAI_CONCURRENCY_LIMIT ? AGNAI_CONCURRENCY_LIMIT
: USER_CONCURRENCY_LIMIT; : USER_CONCURRENCY_LIMIT;
const reqCount = queue.filter((r) => r.ip === req.ip).length; if (enqueuedRequestCount >= maxConcurrentQueuedRequests) {
if (reqCount >= maxConcurrentQueuedRequests) {
if (req.ip === AGNAI_DOT_CHAT_IP) { if (req.ip === AGNAI_DOT_CHAT_IP) {
// Re-enqueued requests are not counted towards the limit since they // Re-enqueued requests are not counted towards the limit since they
// already made it through the queue once. // already made it through the queue once.
@ -48,7 +65,7 @@ export function enqueue(req: Request) {
throw new Error("Too many agnai.chat requests are already queued"); throw new Error("Too many agnai.chat requests are already queued");
} }
} else { } else {
throw new Error("Request is already queued for this IP"); throw new Error("Your IP or token already has a request in the queue");
} }
} }

View File

@ -66,12 +66,16 @@ export const ipLimiter = (req: Request, res: Response, next: NextFunction) => {
return; return;
} }
const { remaining, reset } = getStatus(req.ip); // If user is authenticated, key rate limiting by their token. Otherwise, key
// rate limiting by their IP address. Mitigates key sharing.
const rateLimitKey = req.user?.token || req.ip;
const { remaining, reset } = getStatus(rateLimitKey);
res.set("X-RateLimit-Limit", config.modelRateLimit.toString()); res.set("X-RateLimit-Limit", config.modelRateLimit.toString());
res.set("X-RateLimit-Remaining", remaining.toString()); res.set("X-RateLimit-Remaining", remaining.toString());
res.set("X-RateLimit-Reset", reset.toString()); res.set("X-RateLimit-Reset", reset.toString());
const tryAgainInMs = getTryAgainInMs(req.ip); const tryAgainInMs = getTryAgainInMs(rateLimitKey);
if (tryAgainInMs > 0) { if (tryAgainInMs > 0) {
res.set("Retry-After", tryAgainInMs.toString()); res.set("Retry-After", tryAgainInMs.toString());
res.status(429).json({ res.status(429).json({

View File

@ -35,6 +35,8 @@ app.use(
"req.headers.authorization", "req.headers.authorization",
'req.headers["x-forwarded-for"]', 'req.headers["x-forwarded-for"]',
'req.headers["x-real-ip"]', 'req.headers["x-real-ip"]',
'req.headers["true-client-ip"]',
'req.headers["cf-connecting-ip"]',
], ],
censor: "********", censor: "********",
}, },
@ -85,7 +87,7 @@ app.use((_req: unknown, res: express.Response) => {
async function start() { async function start() {
logger.info("Server starting up..."); logger.info("Server starting up...");
setGitSha(); setBuildInfo();
logger.info("Checking configs and external dependencies..."); logger.info("Checking configs and external dependencies...");
await assertConfigIsValid(); await assertConfigIsValid();
@ -112,7 +114,7 @@ async function start() {
}); });
logger.info( logger.info(
{ sha: process.env.COMMIT_SHA, nodeEnv: process.env.NODE_ENV }, { build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV },
"Startup complete." "Startup complete."
); );
} }
@ -132,15 +134,16 @@ function registerUncaughtExceptionHandler() {
}); });
} }
function setGitSha() { function setBuildInfo() {
// On Render, the .git directory isn't available in the docker build context // On Render, the .git directory isn't available in the docker build context
// so we can't get the SHA directly, but they expose it as an env variable. // so we can't get the SHA directly, but they expose it as an env variable.
if (process.env.RENDER) { if (process.env.RENDER) {
const shaString = `${process.env.RENDER_GIT_COMMIT?.slice(0, 7)} (${ const sha = process.env.RENDER_GIT_COMMIT?.slice(0, 7) || "unknown SHA";
process.env.RENDER_GIT_REPO_SLUG const branch = process.env.RENDER_GIT_BRANCH || "unknown branch";
})`; const repo = process.env.RENDER_GIT_REPO_SLUG || "unknown repo";
process.env.COMMIT_SHA = shaString; const buildInfo = `${sha} (${branch}@${repo})`;
logger.info({ sha: shaString }, "Got commit SHA via Render config."); process.env.BUILD_INFO = buildInfo;
logger.info({ build: buildInfo }, "Got build info from Render config.");
return; return;
} }
@ -171,7 +174,7 @@ function setGitSha() {
logger.info({ sha, status, changes }, "Got commit SHA and status."); logger.info({ sha, status, changes }, "Got commit SHA and status.");
process.env.COMMIT_SHA = `${sha}${changes ? " (modified)" : ""}`; process.env.BUILD_INFO = `${sha}${changes ? " (modified)" : ""}`;
} catch (error: any) { } catch (error: any) {
logger.error( logger.error(
{ {
@ -182,7 +185,7 @@ function setGitSha() {
"Failed to get commit SHA.", "Failed to get commit SHA.",
error error
); );
process.env.COMMIT_SHA = "unknown"; process.env.BUILD_INFO = "unknown";
} }
} }