Sqlite backend with user event logging (khanon/oai-reverse-proxy!69)

This commit is contained in:
scrappyanon 2024-05-26 17:31:12 +00:00 committed by khanon
parent 6352df5d5a
commit 2d82e55d72
17 changed files with 634 additions and 6 deletions

34
package-lock.json generated
View File

@ -19,6 +19,7 @@
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.3.5",
"better-sqlite3": "^10.0.0",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
"copyfiles": "^2.4.1",
@ -50,6 +51,7 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
"@types/express": "^4.17.17",
@ -1498,6 +1500,15 @@
"tslib": "^2.4.0"
}
},
"node_modules/@types/better-sqlite3": {
"version": "7.6.10",
"resolved": "https://registry.npmjs.org/@types/better-sqlite3/-/better-sqlite3-7.6.10.tgz",
"integrity": "sha512-TZBjD+yOsyrUJGmcUj6OS3JADk3+UZcNv3NOBqGkM09bZdi28fNZw8ODqbMOLfKCu7RYCO62/ldq1iHbzxqoPw==",
"dev": true,
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@types/body-parser": {
"version": "1.19.2",
"resolved": "https://registry.npmjs.org/@types/body-parser/-/body-parser-1.19.2.tgz",
@ -1917,6 +1928,16 @@
}
]
},
"node_modules/better-sqlite3": {
"version": "10.0.0",
"resolved": "https://registry.npmjs.org/better-sqlite3/-/better-sqlite3-10.0.0.tgz",
"integrity": "sha512-rOz0JY8bt9oMgrFssP7GnvA5R3yln73y/NizzWqy3WlFth8Ux8+g4r/N9fjX97nn4X1YX6MTER2doNpTu5pqiA==",
"hasInstallScript": true,
"dependencies": {
"bindings": "^1.5.0",
"prebuild-install": "^7.1.1"
}
},
"node_modules/bignumber.js": {
"version": "9.1.1",
"resolved": "https://registry.npmjs.org/bignumber.js/-/bignumber.js-9.1.1.tgz",
@ -1934,6 +1955,14 @@
"node": ">=8"
}
},
"node_modules/bindings": {
"version": "1.5.0",
"resolved": "https://registry.npmjs.org/bindings/-/bindings-1.5.0.tgz",
"integrity": "sha512-p2q/t/mhvuOj/UeLlV6566GD/guowlr0hHxClI0W9m7MWYkL1F0hLo+0Aexs9HSPCtR1SXQ0TD3MMKrXZajbiQ==",
"dependencies": {
"file-uri-to-path": "1.0.0"
}
},
"node_modules/bl": {
"version": "4.1.0",
"resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz",
@ -3054,6 +3083,11 @@
"node": ">=0.8.0"
}
},
"node_modules/file-uri-to-path": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/file-uri-to-path/-/file-uri-to-path-1.0.0.tgz",
"integrity": "sha512-0Zt+s3L7Vf1biwWZ29aARiVYLx7iMGnEUl9x33fbB/j3jR81u/O2LbqK+Bm1CDSNDKVtJ/YjwY7TUd5SkeLQLw=="
},
"node_modules/filelist": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/filelist/-/filelist-1.0.4.tgz",

View File

@ -4,6 +4,7 @@
"description": "Reverse proxy for the OpenAI API",
"scripts": {
"build": "tsc && copyfiles -u 1 src/**/*.ejs build",
"database:migrate": "ts-node scripts/migrate.ts",
"prepare": "husky install",
"start": "node build/server.js",
"start:dev": "nodemon --watch src --exec ts-node --transpile-only src/server.ts",
@ -27,6 +28,7 @@
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.3.5",
"better-sqlite3": "^10.0.0",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
"copyfiles": "^2.4.1",
@ -58,6 +60,7 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
"@types/express": "^4.17.17",

39
scripts/migrate.ts Normal file
View File

@ -0,0 +1,39 @@
import Database from "better-sqlite3";
import { DATABASE_VERSION, migrateDatabase } from "../src/shared/database";
import { logger } from "../src/logger";
import { config } from "../src/config";
const log = logger.child({ module: "scripts/migrate" });
/**
 * Migrates the sqlite database to a target schema version.
 * The version comes from argv[2]; if absent, the user is prompted on stdin,
 * and empty/invalid input falls back to the latest known version.
 */
async function runMigration() {
  let targetVersion: number | undefined = Number(process.argv[2]) || undefined;
  if (!targetVersion) {
    log.info("Enter target version or leave empty to use the latest version.");
    process.stdin.resume();
    process.stdin.setEncoding("utf8");
    const input = await new Promise<string>((resolve) =>
      process.stdin.on("data", (text) => resolve((String(text) || "").trim()))
    );
    process.stdin.pause();
    // Number("") and Number("garbage") are both falsy -> use the latest.
    targetVersion = Number(input) || DATABASE_VERSION;
  }
  const db = new Database(config.sqliteDataPath, {
    verbose: (msg, ...args) => log.debug({ args }, String(msg)),
  });
  const currentVersion = db.pragma("user_version", { simple: true });
  log.info({ currentVersion, targetVersion }, "Running migrations.");
  migrateDatabase(targetVersion, db);
}
// Entry point. Exits with a non-zero status on failure so callers of
// `npm run database:migrate` can detect a failed migration.
runMigration().catch((error) => {
  log.error(error, "Migration failed.");
  process.exit(1);
});

100
scripts/seed-events.ts Normal file
View File

@ -0,0 +1,100 @@
import Database from "better-sqlite3";
import { v4 as uuidv4 } from "uuid";
import { config } from "../src/config";
/**
 * Returns a random dotted-quad IPv4 address for seed data.
 * Each octet is drawn uniformly from 0-255. (The previous implementation
 * used Math.random() * 255, an off-by-one that could never produce 255.)
 */
function generateRandomIP() {
  const octet = () => Math.floor(Math.random() * 256);
  return `${octet()}.${octet()}.${octet()}.${octet()}`;
}
/** Returns an ISO-8601 timestamp drawn uniformly from the last 90 days. */
function generateRandomDate() {
  const now = new Date();
  const windowStart = new Date(now);
  windowStart.setDate(now.getDate() - 90);
  const span = now.getTime() - windowStart.getTime();
  return new Date(windowStart.getTime() + Math.random() * span).toISOString();
}
/** Returns a random 64-character lowercase-hex string shaped like a SHA-256 digest. */
function generateMockSHA256() {
  const characters = 'abcdef0123456789';
  return Array.from(
    { length: 64 },
    () => characters[Math.floor(Math.random() * characters.length)]
  ).join('');
}
/** Picks one model family name at random from the set the proxy recognizes. */
function getRandomModelFamily() {
  const modelFamilies = [
    "turbo",
    "gpt4",
    "gpt4-32k",
    "gpt4-turbo",
    "claude",
    "claude-opus",
    "gemini-pro",
    "mistral-tiny",
    "mistral-small",
    "mistral-medium",
    "mistral-large",
    "aws-claude",
    "aws-claude-opus",
    "azure-turbo",
    "azure-gpt4",
    "azure-gpt4-32k",
    "azure-gpt4-turbo",
    "dall-e",
    "azure-dall-e",
  ];
  const index = Math.floor(Math.random() * modelFamilies.length);
  return modelFamilies[index];
}
// Seeds the events table with random rows for local testing of the
// admin events API and dashboards.
(async () => {
  const db = new Database(config.sqliteDataPath);
  const numRows = 100;
  // Positional placeholders; argument order below must match this column list.
  const insertStatement = db.prepare(`
    INSERT INTO events (type, ip, date, model, family, hashes, userToken, inputTokens, outputTokens)
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
  `);
  // Ten fake user tokens; every event is attributed to one of them at random.
  const users = Array.from({ length: 10 }, () => uuidv4());
  const getRandomUser = () => users[Math.floor(Math.random() * users.length)];
  // 0-9 mock prompt hashes stored as a single comma-joined string.
  const randomHashes = () =>
    Array.from(
      { length: Math.floor(Math.random() * 10) },
      generateMockSHA256
    ).join(",");
  // Single transaction so the seed insert is all-or-nothing and fast.
  const transaction = db.transaction(() => {
    for (let row = 0; row < numRows; row++) {
      insertStatement.run(
        "chat_completion",
        generateRandomIP(),
        generateRandomDate(),
        `${getRandomModelFamily()}-${Math.floor(Math.random() * 100)}`,
        getRandomModelFamily(),
        randomHashes(),
        getRandomUser(),
        Math.floor(Math.random() * 500),
        Math.floor(Math.random() * 6000)
      );
    }
  });
  transaction();
  console.log(`Inserted ${numRows} rows into the events table.`);
  db.close();
})();

49
src/admin/api/events.ts Normal file
View File

@ -0,0 +1,49 @@
import { Router } from "express";
import { z } from "zod";
import { encodeCursor, decodeCursor } from "../../shared/utils";
import { eventsRepo } from "../../shared/database/repos/event";
const router = Router();
/**
* Returns events for the given user token.
* GET /admin/events/:token
* @query first - The number of events to return.
* @query after - The cursor to start returning events from (exclusive).
*/
router.get("/:token", (req, res) => {
  // Validates path + query params together. The previous schema also
  // declared a `sort` key that was never read; it has been removed.
  const schema = z.object({
    token: z.string(),
    // Page size, clamped to 200; defaults to 25.
    first: z.coerce.number().int().positive().max(200).default(25),
    // Opaque pagination cursor; malformed cursors are treated as absent
    // rather than rejected with a 400.
    after: z
      .string()
      .optional()
      .transform((v) => {
        try {
          return decodeCursor(v);
        } catch {
          return null;
        }
      })
      .nullable(),
  });
  const args = schema.safeParse({ ...req.params, ...req.query });
  if (!args.success) {
    return res.status(400).json({ error: args.error });
  }
  // Relay-style connection shape: each event becomes { node, cursor } so the
  // client can resume from `endCursor`.
  const data = eventsRepo
    .getUserEvents(args.data.token, {
      limit: args.data.first,
      cursor: args.data.after,
    })
    .map((e) => ({ node: e, cursor: encodeCursor(e.date) }));
  res.json({
    data,
    endCursor: data[data.length - 1]?.cursor,
  });
});
export { router as eventsApiRouter };

View File

@ -9,7 +9,8 @@ import { renderPage } from "../info-page";
import { buildInfo } from "../service-info";
import { authorize } from "./auth";
import { loginRouter } from "./login";
import { usersApiRouter as apiRouter } from "./api/users";
import { eventsApiRouter } from "./api/events";
import { usersApiRouter } from "./api/users";
import { usersWebRouter as webRouter } from "./web/manage";
import { logger } from "../logger";
@ -32,7 +33,8 @@ adminRouter.use(
adminRouter.use(withSession);
adminRouter.use(injectCsrfToken);
adminRouter.use("/users", authorize({ via: "header" }), apiRouter);
adminRouter.use("/users", authorize({ via: "header" }), usersApiRouter);
adminRouter.use("/events", authorize({ via: "header" }), eventsApiRouter);
adminRouter.use(checkCsrfToken);
adminRouter.use(injectLocals);

View File

@ -208,6 +208,32 @@ type Config = {
* key and can't attach the policy, you can set this to true.
*/
allowAwsLogging?: boolean;
/**
* Path to the SQLite database file for storing data such as event logs. By
* default, the database will be stored at `data/database.sqlite`.
*
* Ensure target is writable by the server process, and be careful not to
* select a path that is served publicly. The default path is safe.
*/
sqliteDataPath?: string;
/**
* Whether to log events, such as generated completions, to the database.
* Events are associated with IP+user token pairs. If user_token mode is
* disabled, no events will be logged.
*
* Currently there is no pruning mechanism for the events table, so it will
* grow indefinitely. You may want to periodically prune the table manually.
*/
eventLogging?: boolean;
/**
* When hashing prompt histories, how many messages to trim from the end.
* If zero, only the full prompt hash will be stored.
* If greater than zero, for each number N, a hash of the prompt with the
* last N messages removed will be stored.
*
* Experimental feature; its config may change in future versions.
*/
eventLoggingTrim?: number;
/** Whether prompts and responses should be logged to persistent storage. */
promptLogging?: boolean;
/** Which prompt logging backend to use. */
@ -356,6 +382,12 @@ export const config: Config = {
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
serviceInfoPassword: getEnvWithDefault("SERVICE_INFO_PASSWORD", ""),
sqliteDataPath: getEnvWithDefault(
"SQLITE_DATA_PATH",
path.join(DATA_DIR, "database.sqlite")
),
eventLogging: getEnvWithDefault("EVENT_LOGGING", false),
eventLoggingTrim: getEnvWithDefault("EVENT_LOGGING_TRIM", 5),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"),
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
@ -605,6 +637,9 @@ export const OMITTED_KEYS = [
"googleSheetsKey",
"firebaseKey",
"firebaseRtdbUrl",
"sqliteDataPath",
"eventLogging",
"eventLoggingTrim",
"gatekeeperStore",
"maxIpsPerUser",
"blockedOrigins",

View File

@ -22,6 +22,7 @@ import {
import { handleBlockingResponse } from "./handle-blocking-response";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
import { logEvent } from "./log-event";
import { saveImage } from "./save-image";
/**
@ -84,7 +85,8 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
trackKeyRateLimit,
countResponseTokens,
incrementUsage,
logPrompt
logPrompt,
logEvent
);
} else {
middlewareStack.push(
@ -96,6 +98,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
copyHttpHeaders,
saveImage,
logPrompt,
logEvent,
...apiMiddleware
);
}

View File

@ -0,0 +1,81 @@
import { createHash } from "crypto";
import { config } from "../../../config";
import { eventLogger } from "../../../shared/prompt-logging";
import { getModelFromBody, isTextGenerationRequest } from "../common";
import { ProxyResHandlerWithBody } from ".";
import {
OpenAIChatMessage,
AnthropicChatMessage,
} from "../../../shared/api-schemas";
/** If event logging is enabled, logs a chat completion event. */
/** If event logging is enabled, logs a chat completion event. */
export const logEvent: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  _res,
  responseBody
) => {
  if (!config.eventLogging) {
    return;
  }
  if (typeof responseBody !== "object") {
    throw new Error("Expected body to be an object");
  }
  // Only chat APIs are supported, and events require an authenticated user.
  const isChatApi = ["openai", "anthropic-chat"].includes(req.outboundApi);
  if (!isChatApi || !req.user || !isTextGenerationRequest(req)) {
    return;
  }
  const messages = req.body.messages as
    | OpenAIChatMessage[]
    | AnthropicChatMessage[];
  // First hash covers the full prompt; each subsequent hash drops one more
  // trailing message (up to eventLoggingTrim) so truncated histories can
  // still be matched against earlier events.
  const hashes = [hashMessages(messages)];
  const trimCount = Math.min(config.eventLoggingTrim!, messages.length);
  for (let removed = 1; removed <= trimCount; removed++) {
    hashes.push(hashMessages(messages.slice(0, -removed)));
  }
  eventLogger.logEvent({
    ip: req.ip,
    type: "chat_completion",
    model: getModelFromBody(req, responseBody),
    family: req.modelFamily!,
    hashes,
    userToken: req.user.token,
    inputTokens: req.promptTokens ?? 0,
    outputTokens: req.outputTokens ?? 0,
  });
};
/**
 * Hashes a chat history to a hex SHA-256 digest.
 * Only system/user/assistant messages contribute; other roles are skipped.
 * String contents are used verbatim. For array contents, only the FIRST
 * part is considered, and only when it is a text part.
 * NOTE(review): text in later parts of a multi-part message is silently
 * ignored — confirm this is intended before relying on hash stability.
 * Collected texts are joined with "<|im_sep|>" and then hashed.
 */
const hashMessages = (
  messages: OpenAIChatMessage[] | AnthropicChatMessage[]
): string => {
  let hasher = createHash("sha256");
  let messageTexts = [];
  for (const msg of messages) {
    // Roles outside this set (e.g. tool results) do not affect the hash.
    if (!["system", "user", "assistant"].includes(msg.role)) continue;
    if (typeof msg.content === "string") {
      messageTexts.push(msg.content);
    } else if (Array.isArray(msg.content)) {
      if (msg.content[0].type === "text") {
        messageTexts.push(msg.content[0].text);
      }
    }
  }
  hasher.update(messageTexts.join("<|im_sep|>"));
  return hasher.digest("hex");
};

View File

@ -22,6 +22,7 @@ import { init as initUserStore } from "./shared/users/user-store";
import { init as initTokenizers } from "./shared/tokenization";
import { checkOrigin } from "./proxy/check-origin";
import { sendErrorToClient } from "./proxy/middleware/response/error-generator";
import { initializeDatabase, getDatabase } from "./shared/database";
const PORT = config.port;
const BIND_ADDRESS = config.bindAddress;
@ -70,7 +71,10 @@ app.set("views", [
app.use("/user_content", express.static(USER_ASSETS_DIR, { maxAge: "2h" }));
app.use(
"/res",
express.static(path.join(__dirname, "..", "public"), { maxAge: "2h", etag: false })
express.static(path.join(__dirname, "..", "public"), {
maxAge: "2h",
etag: false,
})
);
app.get("/health", (_req, res) => res.sendStatus(200));
@ -139,6 +143,8 @@ async function start() {
await logQueue.start();
}
await initializeDatabase();
logger.info("Starting request queue...");
startRequestQueue();
@ -160,6 +166,23 @@ async function start() {
);
}
/**
 * Best-effort shutdown hook: closes the sqlite database if event logging
 * opened one. Runs from the process "exit" handler, so it must be fully
 * synchronous.
 */
function cleanup() {
  console.log("Shutting down...");
  if (config.eventLogging) {
    try {
      const db = getDatabase();
      db.close();
      console.log("Closed sqlite database.");
    } catch (error) {
      // Best-effort: the database may never have been initialized.
      console.error("Failed to close sqlite database.", error);
    }
  }
  // Deliberately no process.exit() here. Calling process.exit(0) from
  // within the "exit" handler would clobber the real exit code (e.g. 130
  // for SIGINT, 143 for SIGTERM set by the signal handlers) with 0.
}
process.on("exit", () => cleanup());
process.on("SIGHUP", () => process.exit(128 + 1));
process.on("SIGINT", () => process.exit(128 + 2));
process.on("SIGTERM", () => process.exit(128 + 15));
function registerUncaughtExceptionHandler() {
process.on("uncaughtException", (err: any) => {
logger.error(

View File

@ -0,0 +1,89 @@
import type sqlite3 from "better-sqlite3";
import { config } from "../../config";
import { logger } from "../../logger";
import { migrations } from "./migrations";
export const DATABASE_VERSION = 3;
let database: sqlite3.Database | undefined;
let log = logger.child({ module: "database" });
/** Returns the open sqlite handle; throws if initializeDatabase has not run. */
export function getDatabase(): sqlite3.Database {
  if (database) return database;
  throw new Error("Sqlite database not initialized.");
}
/**
 * Opens the sqlite database at the configured path and brings its schema up
 * to date. Skipped entirely when event logging is disabled.
 */
export async function initializeDatabase() {
  if (!config.eventLogging) return;
  log.info("Initializing database...");
  // Dynamic import: the native better-sqlite3 module is only loaded when
  // event logging is actually enabled.
  const sqlite3 = await import("better-sqlite3");
  database = sqlite3.default(config.sqliteDataPath);
  migrateDatabase();
  database.pragma("journal_mode = WAL");
  log.info("Database initialized.");
}
/**
 * Migrates the database schema from its current `user_version` pragma to
 * `targetVersion` (latest by default). Upgrades apply migrations in
 * ascending order; downgrades revert them in descending order.
 * `user_version` is updated after each individual step, so a failure
 * part-way through leaves the recorded version consistent with the
 * migrations that actually ran.
 */
export function migrateDatabase(
  targetVersion = DATABASE_VERSION,
  targetDb?: sqlite3.Database
) {
  const db = targetDb || getDatabase();
  const currentVersion = db.pragma("user_version", { simple: true });
  assertNumber(currentVersion);
  if (currentVersion === targetVersion) {
    log.info("No migrations to run.");
    return;
  }
  const direction = currentVersion < targetVersion ? "up" : "down";
  // Sort ascending for upgrades, descending for downgrades, then keep only
  // the migrations whose version lies in the half-open range between the
  // two versions.
  const pending = migrations
    .slice()
    .sort((a, b) =>
      direction === "up" ? a.version - b.version : b.version - a.version
    )
    .filter((m) =>
      direction === "up"
        ? m.version > currentVersion && m.version <= targetVersion
        : m.version > targetVersion && m.version <= currentVersion
    );
  if (pending.length === 0) {
    log.warn("No pending migrations found.");
    return;
  }
  for (const migration of pending) {
    const { version, name, up, down } = migration;
    // This condition is already guaranteed by the filter above; it acts as a
    // defensive re-check before touching the schema.
    if (
      (direction === "up" && version > currentVersion) ||
      (direction === "down" && version <= currentVersion)
    ) {
      if (direction === "up") {
        log.info({ name }, "Applying migration.");
        up(db);
        db.pragma("user_version = " + version);
      } else {
        log.info({ name }, "Reverting migration.");
        down(db);
        // Reverting migration N leaves the schema at version N-1.
        db.pragma("user_version = " + (version - 1));
      }
    }
  }
  log.info("Migrations applied.");
}
function assertNumber(value: unknown): asserts value is number {
if (typeof value !== "number") {
throw new Error("Expected number");
}
}
export { EventLogEntry } from "./repos/event";

View File

@ -0,0 +1,61 @@
import type sqlite3 from "better-sqlite3";
/** A single reversible schema migration step. */
type Migration = {
  // Human-readable name used in migration log messages.
  name: string;
  // Schema version the database is at after `up` has run.
  version: number;
  // Applies the migration.
  up: (db: sqlite3.Database) => void;
  // Reverts the migration (schema returns to version - 1).
  down: (db: sqlite3.Database) => void;
};
/**
 * All known migrations. `migrateDatabase` selects and orders the subset to
 * run based on the current and target `user_version`.
 */
export const migrations = [
  {
    // Placeholder so that a fresh database starts its history at version 1.
    name: "create db",
    version: 1,
    up: () => {},
    down: () => {},
  },
  {
    name: "add events table",
    version: 2,
    up: (db) => {
      db.exec(
        `CREATE TABLE IF NOT EXISTS events
         (
             id           INTEGER PRIMARY KEY AUTOINCREMENT,
             type         TEXT    NOT NULL,
             ip           TEXT    NOT NULL,
             date         TEXT    NOT NULL,
             model        TEXT    NOT NULL,
             family       TEXT    NOT NULL,
             hashes       TEXT    NOT NULL,
             userToken    TEXT    NOT NULL,
             inputTokens  INTEGER NOT NULL,
             outputTokens INTEGER NOT NULL
         )`
      );
    },
    down: (db) => db.exec("DROP TABLE events"),
  },
  {
    name: "add events indexes",
    version: 3,
    up: (db) => {
      // language=SQLite
      db.exec(
        `BEGIN;
        CREATE INDEX IF NOT EXISTS idx_events_userToken ON events (userToken);
        CREATE INDEX IF NOT EXISTS idx_events_ip ON events (ip);
        COMMIT;`
      );
    },
    down: (db) => {
      // language=SQLite
      db.exec(
        `BEGIN;
        DROP INDEX idx_events_userToken;
        DROP INDEX idx_events_ip;
        COMMIT;`
      );
    },
  },
] satisfies Migration[];

View File

@ -0,0 +1,85 @@
import { getDatabase } from "../index";
/** A single logged completion event, one row in the `events` table. */
export interface EventLogEntry {
  // ISO-8601 timestamp assigned when the event is recorded.
  date: string;
  // Client IP the request came from.
  ip: string;
  type: "chat_completion";
  // Concrete model name derived from the request/response body.
  model: string;
  // Model family the request resolved to (e.g. "gpt4", "claude").
  family: string;
  /**
   * Prompt hashes are SHA256.
   * Message texts are joined by <|im_sep|> and then hashed
   * (see hashMessages in log-event.ts for the exact construction).
   * First hash: full prompt.
   * Next {trim} hashes: hashes with the last 1-{trim} messages removed.
   */
  hashes: string[];
  // Token of the user the request was attributed to.
  userToken: string;
  inputTokens: number;
  outputTokens: number;
}
/** Data-access interface for the sqlite `events` table. */
export interface EventsRepo {
  /**
   * Returns up to `limit` events for the given user token, newest first.
   * `cursor` is an exclusive upper bound on the event date, used for paging.
   */
  getUserEvents: (
    userToken: string,
    { limit, cursor }: { limit: number; cursor?: string }
  ) => EventLogEntry[];
  /** Inserts a single event row. */
  logEvent: (payload: EventLogEntry) => void;
}
/** Sqlite-backed implementation of EventsRepo. */
export const eventsRepo: EventsRepo = {
  getUserEvents: (userToken, { limit, cursor }) => {
    const db = getDatabase();
    const params = [];
    let sql = `
    SELECT *
    FROM events
    WHERE userToken = ?
  `;
    params.push(userToken);
    if (cursor) {
      // Cursor is an exclusive upper bound on `date` (newest-first paging).
      // NOTE(review): strict `<` skips rows whose date equals the cursor
      // exactly — confirm event dates are unique enough in practice.
      sql += ` AND date < ?`;
      params.push(cursor);
    }
    sql += ` ORDER BY date DESC LIMIT ?`;
    params.push(limit);
    return db.prepare(sql).all(params).map(marshalEventLogEntry);
  },
  logEvent: (payload) => {
    const db = getDatabase();
    // Hashes are flattened to a single comma-joined TEXT column.
    db.prepare(
      `
        INSERT INTO events(date, ip, type, model, family, hashes, userToken, inputTokens, outputTokens)
        VALUES (:date, :ip, :type, :model, :family, :hashes, :userToken, :inputTokens, :outputTokens)
    `
    ).run({
      date: payload.date,
      ip: payload.ip,
      type: payload.type,
      model: payload.model,
      family: payload.family,
      hashes: payload.hashes.join(","),
      userToken: payload.userToken,
      inputTokens: payload.inputTokens,
      outputTokens: payload.outputTokens,
    });
  },
};
/** Converts a raw sqlite row into a typed EventLogEntry. */
function marshalEventLogEntry(row: any): EventLogEntry {
  const { date, ip, type, model, family, userToken } = row;
  return {
    date,
    ip,
    type,
    model,
    family,
    // Stored as a comma-joined string; split back into the hash list.
    hashes: row.hashes.split(","),
    userToken,
    inputTokens: parseInt(row.inputTokens),
    outputTokens: parseInt(row.outputTokens),
  };
}

View File

@ -0,0 +1,10 @@
import { config } from "../../config";
import type { EventLogEntry } from "../database";
import { eventsRepo } from "../database/repos/event";
/**
 * Records an event, stamping it with the current time. Does nothing when
 * event logging is disabled in config.
 */
export const logEvent = (payload: Omit<EventLogEntry, "date">) => {
  if (!config.eventLogging) return;
  const date = new Date().toISOString();
  eventsRepo.logEvent({ ...payload, date });
};

View File

@ -23,3 +23,4 @@ export interface LogBackend {
}
export * as logQueue from "./log-queue";
export * as eventLogger from "./event-logger";

View File

@ -57,7 +57,7 @@ export function makeOptionalPropsNullable<Schema extends z.AnyZodObject>(
) {
const entries = Object.entries(schema.shape) as [
keyof Schema["shape"],
z.ZodTypeAny
z.ZodTypeAny,
][];
const newProps = entries.reduce(
(acc, [key, value]) => {
@ -84,3 +84,12 @@ export function redactIp(ip: string) {
export function assertNever(x: never): never {
throw new Error(`Called assertNever with argument ${x}.`);
}
/** Serializes a cursor value as base64-encoded JSON. */
export function encodeCursor(v: string) {
  const json = JSON.stringify(v);
  return Buffer.from(json).toString("base64");
}
/** Inverse of encodeCursor; returns null when no cursor was supplied. */
export function decodeCursor(cursor?: string) {
  if (!cursor) return null;
  const json = Buffer.from(cursor, "base64").toString("utf-8");
  return JSON.parse(json);
}

View File

@ -204,7 +204,11 @@
if (solution) {
return;
}
workers.forEach((w) => w.postMessage({ type: "stop" }));
workers.forEach((w, i) => {
w.postMessage({ type: "stop" });
setTimeout(() => w.terminate(), 1000 + i * 100)
});
workers = [];
active = false;
solution = e.data.nonce;
document.getElementById("captcha-result").textContent =