adds SERVICE_INFO_PASSWORD to gate infopage behind a password

This commit is contained in:
nai-degen 2024-02-04 14:04:46 -06:00
parent 235510e588
commit fe429a7610
5 changed files with 64 additions and 12 deletions

View File

@ -65,6 +65,11 @@ type Config = {
* management mode is set to 'user_token'. * management mode is set to 'user_token'.
*/ */
adminKey?: string; adminKey?: string;
/**
* The password required to view the service info/status page. If not set, the
* info page will be publicly accessible.
*/
serviceInfoPassword?: string;
/** /**
* Which user management mode to use. * Which user management mode to use.
* - `none`: No user management. Proxy is open to all requests with basic * - `none`: No user management. Proxy is open to all requests with basic
@ -259,6 +264,7 @@ export const config: Config = {
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""), azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""), adminKey: getEnvWithDefault("ADMIN_KEY", ""),
serviceInfoPassword: getEnvWithDefault("SERVICE_INFO_PASSWORD", ""),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"), gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"), gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"),
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0), maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
@ -435,6 +441,7 @@ export const OMITTED_KEYS = [
"azureCredentials", "azureCredentials",
"proxyKey", "proxyKey",
"adminKey", "adminKey",
"serviceInfoPassword",
"rejectPhrases", "rejectPhrases",
"rejectMessage", "rejectMessage",
"showTokenCosts", "showTokenCosts",

View File

@ -1,12 +1,14 @@
/** This whole module kinda sucks */ /** This whole module kinda sucks */
import fs from "fs"; import fs from "fs";
import { Request, Response } from "express"; import express, { Router, Request, Response } from "express";
import showdown from "showdown"; import showdown from "showdown";
import { config } from "./config"; import { config } from "./config";
import { buildInfo, ServiceInfo } from "./service-info"; import { buildInfo, ServiceInfo } from "./service-info";
import { getLastNImages } from "./shared/file-storage/image-history"; import { getLastNImages } from "./shared/file-storage/image-history";
import { keyPool } from "./shared/key-management"; import { keyPool } from "./shared/key-management";
import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models"; import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models";
import { withSession } from "./shared/with-session";
import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf";
const INFO_PAGE_TTL = 2000; const INFO_PAGE_TTL = 2000;
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
@ -203,3 +205,48 @@ function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
return ""; return "";
} }
} }
function checkIfUnlocked(req: Request, res: Response, next: express.NextFunction) {
if (config.serviceInfoPassword?.length && !req.session?.unlocked) {
return res.redirect("/unlock-info");
}
next();
}
const infoPageRouter = Router();
if (config.serviceInfoPassword?.length) {
infoPageRouter.use(
express.json({ limit: "1mb" }),
express.urlencoded({ extended: true, limit: "1mb" })
);
infoPageRouter.use(withSession);
infoPageRouter.use(injectCsrfToken, checkCsrfToken);
infoPageRouter.post(
"/unlock-info",
(req, res) => {
if (req.body.password !== config.serviceInfoPassword) {
return res.status(403).send("Incorrect password");
}
req.session!.unlocked = true;
res.redirect("/");
},
);
infoPageRouter.get("/unlock-info", (_req, res) => {
if (_req.session?.unlocked) return res.redirect("/");
res.send(`
<form method="post" action="/unlock-info">
<h1>Unlock Service Info</h1>
<input type="hidden" name="_csrf" value="${res.locals.csrfToken}" />
<input type="password" name="password" placeholder="Password" />
<button type="submit">Unlock</button>
</form>
`);
});
infoPageRouter.use(checkIfUnlocked);
}
infoPageRouter.get("/", handleInfoPage);
infoPageRouter.get("/status", (req, res) => {
res.json(buildInfo(req.protocol + "://" + req.get("host"), false));
});
export { infoPageRouter };

View File

@ -1,14 +1,14 @@
import {
isImageGenerationRequest,
isTextGenerationRequest,
} from "../../common";
import { RequestPreprocessor } from "../index";
import { openAIToAnthropic } from "../../../../shared/api-schemas/anthropic"; import { openAIToAnthropic } from "../../../../shared/api-schemas/anthropic";
import { openAIToOpenAIText } from "../../../../shared/api-schemas/openai-text"; import { openAIToOpenAIText } from "../../../../shared/api-schemas/openai-text";
import { openAIToOpenAIImage } from "../../../../shared/api-schemas/openai-image"; import { openAIToOpenAIImage } from "../../../../shared/api-schemas/openai-image";
import { openAIToGoogleAI } from "../../../../shared/api-schemas/google-ai"; import { openAIToGoogleAI } from "../../../../shared/api-schemas/google-ai";
import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai"; import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai";
import { API_SCHEMA_VALIDATORS } from "../../../../shared/api-schemas"; import { API_SCHEMA_VALIDATORS } from "../../../../shared/api-schemas";
import {
isImageGenerationRequest,
isTextGenerationRequest,
} from "../../common";
import { RequestPreprocessor } from "../index";
/** Transforms an incoming request body to one that matches the target API. */ /** Transforms an incoming request body to one that matches the target API. */
export const transformOutboundPayload: RequestPreprocessor = async (req) => { export const transformOutboundPayload: RequestPreprocessor = async (req) => {

View File

@ -12,14 +12,14 @@ import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir";
import { keyPool } from "./shared/key-management"; import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes"; import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes"; import { proxyRouter } from "./proxy/routes";
import { handleInfoPage } from "./info-page"; import { infoPageRouter } from "./info-page";
import { userRouter } from "./user/routes";
import { buildInfo } from "./service-info"; import { buildInfo } from "./service-info";
import { logQueue } from "./shared/prompt-logging"; import { logQueue } from "./shared/prompt-logging";
import { start as startRequestQueue } from "./proxy/queue"; import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./shared/users/user-store"; import { init as initUserStore } from "./shared/users/user-store";
import { init as initTokenizers } from "./shared/tokenization"; import { init as initTokenizers } from "./shared/tokenization";
import { checkOrigin } from "./proxy/check-origin"; import { checkOrigin } from "./proxy/check-origin";
import { userRouter } from "./user/routes";
const PORT = config.port; const PORT = config.port;
const BIND_ADDRESS = config.bindAddress; const BIND_ADDRESS = config.bindAddress;
@ -69,11 +69,8 @@ app.use(checkOrigin);
if (config.staticServiceInfo) { if (config.staticServiceInfo) {
app.get("/", (_req, res) => res.sendStatus(200)); app.get("/", (_req, res) => res.sendStatus(200));
} else { } else {
app.get("/", handleInfoPage); app.use("/", infoPageRouter);
} }
app.get("/status", (req, res) => {
res.json(buildInfo(req.protocol + "://" + req.get("host"), false));
});
app.use("/admin", adminRouter); app.use("/admin", adminRouter);
app.use("/proxy", proxyRouter); app.use("/proxy", proxyRouter);
app.use("/user", userRouter); app.use("/user", userRouter);

View File

@ -41,5 +41,6 @@ declare module "express-session" {
userToken?: string; userToken?: string;
csrf?: string; csrf?: string;
flash?: { type: string; message: string }; flash?: { type: string; message: string };
unlocked?: boolean;
} }
} }