diff --git a/src/config.ts b/src/config.ts index 00c07a6..e97fe8a 100644 --- a/src/config.ts +++ b/src/config.ts @@ -65,6 +65,11 @@ type Config = { * management mode is set to 'user_token'. */ adminKey?: string; + /** + * The password required to view the service info/status page. If not set, the + * info page will be publicly accessible. + */ + serviceInfoPassword?: string; /** * Which user management mode to use. * - `none`: No user management. Proxy is open to all requests with basic @@ -259,6 +264,7 @@ export const config: Config = { azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""), adminKey: getEnvWithDefault("ADMIN_KEY", ""), + serviceInfoPassword: getEnvWithDefault("SERVICE_INFO_PASSWORD", ""), gatekeeper: getEnvWithDefault("GATEKEEPER", "none"), gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"), maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0), @@ -435,6 +441,7 @@ export const OMITTED_KEYS = [ "azureCredentials", "proxyKey", "adminKey", + "serviceInfoPassword", "rejectPhrases", "rejectMessage", "showTokenCosts", diff --git a/src/info-page.ts b/src/info-page.ts index 3116992..c91ec57 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -1,12 +1,14 @@ /** This whole module kinda sucks */ import fs from "fs"; -import { Request, Response } from "express"; +import express, { Router, Request, Response } from "express"; import showdown from "showdown"; import { config } from "./config"; import { buildInfo, ServiceInfo } from "./service-info"; import { getLastNImages } from "./shared/file-storage/image-history"; import { keyPool } from "./shared/key-management"; import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models"; +import { withSession } from "./shared/with-session"; +import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf"; const INFO_PAGE_TTL = 2000; const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { @@ -203,3 +205,48 @@ function getExternalUrlForHuggingfaceSpaceId(spaceId: string) { return ""; } } + +function checkIfUnlocked(req: Request, res: Response, next: express.NextFunction) { + if (config.serviceInfoPassword?.length && !req.session?.unlocked) { + return res.redirect("/unlock-info"); + } + next(); +} + +const infoPageRouter = Router(); +if (config.serviceInfoPassword?.length) { + infoPageRouter.use( + express.json({ limit: "1mb" }), + express.urlencoded({ extended: true, limit: "1mb" }) + ); + infoPageRouter.use(withSession); + infoPageRouter.use(injectCsrfToken, checkCsrfToken); + infoPageRouter.post( + "/unlock-info", + (req, res) => { + if (req.body.password !== config.serviceInfoPassword) { + return res.status(403).send("Incorrect password"); + } + req.session!.unlocked = true; + res.redirect("/"); + }, + ); + infoPageRouter.get("/unlock-info", (_req, res) => { + if (_req.session?.unlocked) return res.redirect("/"); + + res.send(` +
+ `); + }); + infoPageRouter.use(checkIfUnlocked); +} +infoPageRouter.get("/", handleInfoPage); +infoPageRouter.get("/status", (req, res) => { + res.json(buildInfo(req.protocol + "://" + req.get("host"), false)); +}); +export { infoPageRouter }; diff --git a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts index 3350851..fad4f90 100644 --- a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts @@ -1,14 +1,14 @@ -import { - isImageGenerationRequest, - isTextGenerationRequest, -} from "../../common"; -import { RequestPreprocessor } from "../index"; import { openAIToAnthropic } from "../../../../shared/api-schemas/anthropic"; import { openAIToOpenAIText } from "../../../../shared/api-schemas/openai-text"; import { openAIToOpenAIImage } from "../../../../shared/api-schemas/openai-image"; import { openAIToGoogleAI } from "../../../../shared/api-schemas/google-ai"; import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai"; import { API_SCHEMA_VALIDATORS } from "../../../../shared/api-schemas"; +import { + isImageGenerationRequest, + isTextGenerationRequest, +} from "../../common"; +import { RequestPreprocessor } from "../index"; /** Transforms an incoming request body to one that matches the target API. */ export const transformOutboundPayload: RequestPreprocessor = async (req) => { diff --git a/src/server.ts b/src/server.ts index 1f8fdbf..9096de9 100644 --- a/src/server.ts +++ b/src/server.ts @@ -12,14 +12,14 @@ import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir"; import { keyPool } from "./shared/key-management"; import { adminRouter } from "./admin/routes"; import { proxyRouter } from "./proxy/routes"; -import { handleInfoPage } from "./info-page"; +import { infoPageRouter } from "./info-page"; +import { userRouter } from "./user/routes"; import { buildInfo } from "./service-info"; import { logQueue } from "./shared/prompt-logging"; import { start as startRequestQueue } from "./proxy/queue"; import { init as initUserStore } from "./shared/users/user-store"; import { init as initTokenizers } from "./shared/tokenization"; import { checkOrigin } from "./proxy/check-origin"; -import { userRouter } from "./user/routes"; const PORT = config.port; const BIND_ADDRESS = config.bindAddress; @@ -69,11 +69,8 @@ app.use(checkOrigin); if (config.staticServiceInfo) { app.get("/", (_req, res) => res.sendStatus(200)); } else { - app.get("/", handleInfoPage); + app.use("/", infoPageRouter); } -app.get("/status", (req, res) => { - res.json(buildInfo(req.protocol + "://" + req.get("host"), false)); -}); app.use("/admin", adminRouter); app.use("/proxy", proxyRouter); app.use("/user", userRouter); diff --git a/src/shared/custom.d.ts b/src/shared/custom.d.ts index 1f98986..8f91644 100644 --- a/src/shared/custom.d.ts +++ b/src/shared/custom.d.ts @@ -41,5 +41,6 @@ declare module "express-session" { userToken?: string; csrf?: string; flash?: { type: string; message: string }; + unlocked?: boolean; } }