From 20c064394a54cb5efd9ac1d77bff544c05e24a3d Mon Sep 17 00:00:00 2001
From: khanon
Date: Tue, 14 Nov 2023 05:41:19 +0000
Subject: [PATCH] OpenAI DALL-E Image Generation (khanon/oai-reverse-proxy!52)

---
 .env.example                                  |  23 +-
 data/.gitignore                               |   2 +
 data/user-files/.gitkeep                      |   0
 docker/huggingface/Dockerfile                 |   2 +
 docker/render/Dockerfile                      |   2 +
 docs/dall-e-configuration.md                  |  71 +++
 docs/deploy-huggingface.md                    |   2 +
 package-lock.json                             | 406 +++++++++++++++++-
 package.json                                  |   2 +
 src/config.ts                                 |  52 ++-
 src/info-page.ts                              |  49 ++-
 src/proxy/anthropic.ts                        |   5 +-
 src/proxy/aws.ts                              |   5 +-
 src/proxy/middleware/common.ts                |  22 +-
 .../request/add-anthropic-preamble.ts         |   4 +-
 src/proxy/middleware/request/add-key.ts       |  17 +-
 .../middleware/request/apply-quota-limits.ts  |  17 +-
 .../request/block-zoomer-origins.ts           |   5 -
 .../middleware/request/count-prompt-tokens.ts |  12 +-
 src/proxy/middleware/request/finalize-body.ts |   5 +
 .../middleware/request/language-filter.ts     |   1 +
 .../middleware/request/limit-completions.ts   |   6 +-
 src/proxy/middleware/request/rewrite.ts       |   1 +
 .../request/transform-outbound-payload.ts     |  73 +++-
 .../request/validate-context-size.ts          |  10 +-
 src/proxy/middleware/response/index.ts        |  41 +-
 src/proxy/middleware/response/log-prompt.ts   |  45 +-
 src/proxy/middleware/response/save-image.ts   |  27 ++
 .../response/streaming/event-aggregator.ts    |   5 +-
 .../streaming/sse-message-transformer.ts      |   3 +-
 src/proxy/openai-image.ts                     | 153 +++++++
 src/proxy/openai.ts                           |  77 ++--
 src/proxy/palm.ts                             |   5 +-
 src/proxy/queue.ts                            |  46 +-
 src/proxy/rate-limit.ts                       |  46 +-
 src/proxy/routes.ts                           |   2 +
 src/server.ts                                 |  19 +-
 src/shared/file-storage/image-history.ts      |  35 ++
 .../file-storage/mirror-generated-image.ts    |  75 ++++
 src/shared/file-storage/setup-assets-dir.ts   |  20 +
 .../key-management/anthropic/provider.ts      |  16 +-
 src/shared/key-management/aws/provider.ts     |  12 +-
 src/shared/key-management/index.ts            |  31 +-
 src/shared/key-management/key-pool.ts         |  39 +-
 src/shared/key-management/openai/checker.ts   |   7 +-
 src/shared/key-management/openai/provider.ts  |  19 +-
 src/shared/key-management/palm/provider.ts    |   8 +-
 src/shared/models.ts                          |  16 +-
 src/shared/prompt-logging/backends/sheets.ts  |   2 +-
 src/shared/stats.ts                           |   3 +
 src/shared/streaming.ts                       |   5 +-
 src/shared/tokenization/openai.ts             |  60 +++
 src/shared/tokenization/tokenizer.ts          |  12 +
 src/shared/users/schema.ts                    |   1 +
 src/shared/users/user-store.ts                |  79 ++--
 src/types/custom.d.ts                         |   3 +-
 56 files changed, 1401 insertions(+), 305 deletions(-)
 create mode 100644 data/.gitignore
 create mode 100644 data/user-files/.gitkeep
 create mode 100644 docs/dall-e-configuration.md
 create mode 100644 src/proxy/middleware/response/save-image.ts
 create mode 100644 src/proxy/openai-image.ts
 create mode 100644 src/shared/file-storage/image-history.ts
 create mode 100644 src/shared/file-storage/mirror-generated-image.ts
 create mode 100644 src/shared/file-storage/setup-assets-dir.ts

diff --git a/.env.example b/.env.example
index 8357a18..a1e9cd2 100644
--- a/.env.example
+++ b/.env.example
@@ -11,8 +11,10 @@
 # The title displayed on the info page.
 # SERVER_TITLE=Coom Tunnel
 
-# Model requests allowed per minute per user.
-# MODEL_RATE_LIMIT=4
+# Text model requests allowed per minute per user.
+# TEXT_MODEL_RATE_LIMIT=4
+# Image model requests allowed per minute per user.
+# IMAGE_MODEL_RATE_LIMIT=2
 
 # Max number of context tokens a user can request at once.
 # Increase this if your proxy allows GPT 32k or 128k context
 
@@ -31,10 +33,11 @@
 # CHECK_KEYS=true
 
 # Which model types users are allowed to access.
-# If you want to restrict access to certain models, uncomment the line below and list only the models you want to allow,
-# separated by commas. By default, all models are allowed. The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | claude | bison | aws-claude
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4-turbo,aws-claude
+# The following model families are recognized:
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude
+# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
+# generation, uncomment the line below and add 'dall-e' to the list.
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude
 
 # URLs from which requests will be blocked.
 # BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -82,10 +85,18 @@
 # ALLOW_NICKNAME_CHANGES=true
 
 # Default token quotas for each model family. (0 for unlimited)
+# DALL-E "tokens" are counted at a rate of 100000 tokens per US$1.00 generated,
+# which is similar to the cost of GPT-4 Turbo.
+# DALL-E 3 costs around US$0.10 per image (10000 tokens).
+# See `docs/dall-e-configuration.md` for more information.
 # TOKEN_QUOTA_TURBO=0
 # TOKEN_QUOTA_GPT4=0
 # TOKEN_QUOTA_GPT4_32K=0
+# TOKEN_QUOTA_GPT4_TURBO=0
+# TOKEN_QUOTA_DALL_E=0
 # TOKEN_QUOTA_CLAUDE=0
+# TOKEN_QUOTA_BISON=0
+# TOKEN_QUOTA_AWS_CLAUDE=0
 
 # How often to refresh token quotas. (hourly | daily)
 # Leave unset to never automatically refresh quotas.
diff --git a/data/.gitignore b/data/.gitignore
new file mode 100644
index 0000000..377ccd3
--- /dev/null
+++ b/data/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitkeep
diff --git a/data/user-files/.gitkeep b/data/user-files/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/docker/huggingface/Dockerfile b/docker/huggingface/Dockerfile
index eef259f..7ab1c36 100644
--- a/docker/huggingface/Dockerfile
+++ b/docker/huggingface/Dockerfile
@@ -3,6 +3,8 @@ RUN apt-get update && \
     apt-get install -y git
 RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
 WORKDIR /app
+RUN chown -R 1000:1000 /app
+USER 1000
 RUN npm install
 COPY Dockerfile greeting.md* .env* ./
 RUN npm run build
diff --git a/docker/render/Dockerfile b/docker/render/Dockerfile
index 67731c1..3bc7ed8 100644
--- a/docker/render/Dockerfile
+++ b/docker/render/Dockerfile
@@ -17,6 +17,8 @@ ARG GREETING_URL
 RUN if [ -n "$GREETING_URL" ]; then \
     curl -sL "$GREETING_URL" > greeting.md; \
     fi
+RUN chown -R 1000:1000 /app
+USER 1000
 COPY package*.json greeting.md* ./
 RUN npm install
 COPY . .
diff --git a/docs/dall-e-configuration.md b/docs/dall-e-configuration.md
new file mode 100644
index 0000000..3377e32
--- /dev/null
+++ b/docs/dall-e-configuration.md
@@ -0,0 +1,71 @@
# Configuring the proxy for DALL-E

The proxy supports DALL-E 2 and DALL-E 3 image generation via the `/proxy/openai-image` endpoint. It is disabled by default, as image generation is relatively expensive and potentially more open to abuse than text generation.

- [Updating your Dockerfile](#updating-your-dockerfile)
- [Enabling DALL-E](#enabling-dall-e)
- [Setting quotas](#setting-quotas)
- [Rate limiting](#rate-limiting)

## Updating your Dockerfile
If you are using an older version of the Dockerfile supplied with the proxy, it won't have the permissions the proxy needs to save temporary files.

You can replace it entirely with the new Dockerfile at [./docker/huggingface/Dockerfile](../docker/huggingface/Dockerfile) (or the equivalent for Render deployments).

You can also modify your existing Dockerfile; just add the following lines after the `WORKDIR` line:

```Dockerfile
# Existing
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app

# Take ownership of the app directory and switch to the non-root user
RUN chown -R 1000:1000 /app
USER 1000

# Existing
RUN npm install
```

## Enabling DALL-E
Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL-E. For example:

```
# GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, and DALL-E
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e

# All models as of this writing
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
```

Refer to [.env.example](../.env.example) for the full list of supported model families; adding `dall-e` to that full list enables every model, as in the second example above.

## Setting quotas
Unlike text generation models, DALL-E is not billed by token. Instead, each generated image has a fixed cost that depends on the model, image size, and selected quality.

The proxy still uses tokens to set quotas for users. The cost of each generated image is converted to "tokens" at a rate of 100000 tokens per US$1.00. This works out to a similar cost per token as GPT-4 Turbo, so you can use similar token quotas for both.

Use `TOKEN_QUOTA_DALL_E` to set the default quota for image generation. Otherwise it works the same as token quotas for other models.

```
# ~50 standard-quality DALL-E 3 images per refresh period, or US$2.00
TOKEN_QUOTA_DALL_E=200000
```

Refer to [https://openai.com/pricing](https://openai.com/pricing) for the latest pricing information. As of this writing, the cheapest DALL-E 3 image costs US$0.04 per generation, which works out to 4000 tokens. Higher resolution and quality settings can cost up to US$0.12 per image, or 12000 tokens.

## Rate limiting
The old `MODEL_RATE_LIMIT` setting has been split into `TEXT_MODEL_RATE_LIMIT` and `IMAGE_MODEL_RATE_LIMIT`. Whatever value you previously set for `MODEL_RATE_LIMIT` will be used for text models.

If you still use the legacy `MODEL_RATE_LIMIT` without an explicit `IMAGE_MODEL_RATE_LIMIT`, the image limit defaults to half of the text limit, with a minimum of 1 image per minute.

```
# 4 text generations per minute, 2 images per minute
TEXT_MODEL_RATE_LIMIT=4
IMAGE_MODEL_RATE_LIMIT=2
```

If a prompt is rejected by OpenAI's content filter, the request doesn't count towards the rate limit.

## Hiding recent images
By default, the proxy's info page shows the 12 most recently generated images. You can hide this section by setting `SHOW_RECENT_IMAGES` to `false`.
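
## Testing the endpoint
To sanity-check your configuration, you can call the image endpoint directly. The sketch below is illustrative only; the base URL, port, and user token are placeholders, so substitute the values for your own deployment (Node 18+ is assumed for the global `fetch`):

```ts
// Request a single DALL-E image through the proxy and print the result URL.
const BASE_URL = "http://127.0.0.1:7860"; // placeholder; use your deployment's URL

const res = await fetch(`${BASE_URL}/proxy/openai-image/v1/images/generations`, {
  method: "POST",
  headers: {
    "Content-Type": "application/json",
    // Only needed if user_token gatekeeping is enabled.
    Authorization: "Bearer YOUR_USER_TOKEN",
  },
  body: JSON.stringify({
    model: "dall-e-3",
    prompt: "a watercolor painting of a fox",
    size: "1024x1024",
    quality: "standard",
  }),
});
const json = await res.json();
console.log(json.data?.[0]?.url);
```

Chat-completion frontends can also generate images through this endpoint: the proxy takes the last user message, looks for an `Image:` marker, and uses everything after the marker as the image prompt.
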
diff --git a/docs/deploy-huggingface.md b/docs/deploy-huggingface.md index 4a112c0..968b1c9 100644 --- a/docs/deploy-huggingface.md +++ b/docs/deploy-huggingface.md @@ -25,6 +25,8 @@ RUN apt-get update && \ apt-get install -y git RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app WORKDIR /app +RUN chown -R 1000:1000 /app +USER 1000 RUN npm install COPY Dockerfile greeting.md* .env* ./ RUN npm run build diff --git a/package-lock.json b/package-lock.json index 72a8979..d87d555 100644 --- a/package-lock.json +++ b/package-lock.json @@ -15,6 +15,7 @@ "@smithy/signature-v4": "^2.0.10", "@smithy/types": "^2.3.4", "axios": "^1.3.5", + "check-disk-space": "^3.4.0", "cookie-parser": "^1.4.6", "copyfiles": "^2.4.1", "cors": "^2.8.5", @@ -33,6 +34,7 @@ "pino": "^8.11.0", "pino-http": "^8.3.3", "sanitize-html": "^2.11.0", + "sharp": "^0.32.6", "showdown": "^2.1.0", "tiktoken": "^1.0.10", "uuid": "^9.0.0", @@ -1373,15 +1375,20 @@ } }, "node_modules/axios": { - "version": "1.3.5", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.3.5.tgz", - "integrity": "sha512-glL/PvG/E+xCWwV8S6nCHcrfg1exGx7vxyUIivIA1iL7BIh6bePylCfVHwp6k13ao7SATxB6imau2kqY+I67kw==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", + "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", "dependencies": { "follow-redirects": "^1.15.0", "form-data": "^4.0.0", "proxy-from-env": "^1.1.0" } }, + "node_modules/b4a": { + "version": "1.6.4", + "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz", + "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw==" + }, "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -1423,6 +1430,52 @@ "node": ">=8" } }, + "node_modules/bl": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz", + "integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==", + "dependencies": { + "buffer": "^5.5.0", + "inherits": "^2.0.4", + "readable-stream": "^3.4.0" + } + }, + "node_modules/bl/node_modules/buffer": { + "version": "5.7.1", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz", + "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.1.13" + } + }, + "node_modules/bl/node_modules/readable-stream": { + "version": "3.6.2", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz", + "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==", + "dependencies": { + "inherits": "^2.0.3", + "string_decoder": "^1.1.1", + "util-deprecate": "^1.0.1" + }, + "engines": { + "node": ">= 6" + } + }, "node_modules/bluebird": { "version": "3.7.2", "resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.7.2.tgz", @@ -1582,6 +1635,14 @@ "node": ">=8" } }, + "node_modules/check-disk-space": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/check-disk-space/-/check-disk-space-3.4.0.tgz", + "integrity": 
"sha512-drVkSqfwA+TvuEhFipiR1OC9boEGZL5RrWvVsOthdcvQNXyCCuKkEiTOTXZ7qxSf/GLwq4GvzfrQD/Wz325hgw==", + "engines": { + "node": ">=16" + } + }, "node_modules/chokidar": { "version": "3.5.3", "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", @@ -1609,6 +1670,11 @@ "fsevents": "~2.3.2" } }, + "node_modules/chownr": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz", + "integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==" + }, "node_modules/cliui": { "version": "8.0.1", "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", @@ -1623,6 +1689,18 @@ "node": ">=12" } }, + "node_modules/color": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/color/-/color-4.2.3.tgz", + "integrity": "sha512-1rXeuUUiGGrykh+CeBdu5Ie7OJwinCgQY0bc7GCRxy5xVHy+moaqkpL/jqQq0MtQOeYcrqEz4abc5f0KtU7W4A==", + "dependencies": { + "color-convert": "^2.0.1", + "color-string": "^1.9.0" + }, + "engines": { + "node": ">=12.5.0" + } + }, "node_modules/color-convert": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", @@ -1639,6 +1717,15 @@ "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" }, + "node_modules/color-string": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/color-string/-/color-string-1.9.1.tgz", + "integrity": "sha512-shrVawQFojnZv6xM40anx4CkoDP+fZsw/ZerEMsW/pyzsRbElpsL/DBVW7q3ExxwusdNXI3lXpuhEZkzs8p5Eg==", + "dependencies": { + "color-name": "^1.0.0", + "simple-swizzle": "^0.2.2" + } + }, "node_modules/colorette": { "version": "2.0.20", "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz", @@ -2000,6 +2087,28 @@ "ms": "2.0.0" } }, + "node_modules/decompress-response": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz", + "integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==", + "dependencies": { + "mimic-response": "^3.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/deep-extend": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz", + "integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==", + "engines": { + "node": ">=4.0.0" + } + }, "node_modules/deep-is": { "version": "0.1.4", "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", @@ -2039,6 +2148,14 @@ "npm": "1.2.8000 || >= 1.4.16" } }, + "node_modules/detect-libc": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz", + "integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==", + "engines": { + "node": ">=8" + } + }, "node_modules/diff": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz", @@ -2188,7 +2305,6 @@ "version": "1.4.4", "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz", "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==", - "devOptional": true, "dependencies": { "once": "^1.4.0" } @@ -2473,6 +2589,14 @@ "node": ">=0.8.x" } }, + 
"node_modules/expand-template": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz", + "integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==", + "engines": { + "node": ">=6" + } + }, "node_modules/express": { "version": "4.18.2", "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz", @@ -2557,6 +2681,11 @@ "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", "optional": true }, + "node_modules/fast-fifo": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", + "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==" + }, "node_modules/fast-levenshtein": { "version": "2.0.6", "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", @@ -2718,6 +2847,11 @@ "node": ">= 0.6" } }, + "node_modules/fs-constants": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz", + "integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==" + }, "node_modules/fs.realpath": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", @@ -2795,6 +2929,11 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/github-from-package": { + "version": "0.0.0", + "resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz", + "integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==" + }, "node_modules/glob": { "version": "8.1.0", "resolved": "https://registry.npmjs.org/glob/-/glob-8.1.0.tgz", @@ -3246,6 +3385,11 @@ "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==" }, + "node_modules/ini": { + "version": "1.3.8", + "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz", + "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==" + }, "node_modules/ipaddr.js": { "version": "1.9.1", "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", @@ -3254,6 +3398,11 @@ "node": ">= 0.10" } }, + "node_modules/is-arrayish": { + "version": "0.3.2", + "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.3.2.tgz", + "integrity": "sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ==" + }, "node_modules/is-binary-path": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", @@ -3780,6 +3929,17 @@ "node": ">= 0.6" } }, + "node_modules/mimic-response": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz", + "integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/minimatch": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", @@ -3810,6 +3970,11 @@ "node": ">=10" } }, + "node_modules/mkdirp-classic": { + "version": "0.5.3", + "resolved": 
"https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz", + "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==" + }, "node_modules/ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -3860,6 +4025,11 @@ "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" } }, + "node_modules/napi-build-utils": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-1.0.2.tgz", + "integrity": "sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==" + }, "node_modules/negotiator": { "version": "0.6.3", "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz", @@ -3868,6 +4038,22 @@ "node": ">= 0.6" } }, + "node_modules/node-abi": { + "version": "3.51.0", + "resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.51.0.tgz", + "integrity": "sha512-SQkEP4hmNWjlniS5zdnfIXTk1x7Ome85RDzHlTbBtzE97Gfwz/Ipw4v/Ryk20DWIy3yCNVLVlGKApCnmvYoJbA==", + "dependencies": { + "semver": "^7.3.5" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/node-addon-api": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz", + "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA==" + }, "node_modules/node-fetch": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.9.tgz", @@ -4217,6 +4403,70 @@ "node": "^10 || ^12 || >=14" } }, + "node_modules/prebuild-install": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz", + "integrity": "sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==", + "dependencies": { + "detect-libc": "^2.0.0", + "expand-template": "^2.0.3", + "github-from-package": "0.0.0", + "minimist": "^1.2.3", + "mkdirp-classic": "^0.5.3", + "napi-build-utils": "^1.0.1", + "node-abi": "^3.3.0", + "pump": "^3.0.0", + "rc": "^1.2.7", + "simple-get": "^4.0.0", + "tar-fs": "^2.0.0", + "tunnel-agent": "^0.6.0" + }, + "bin": { + "prebuild-install": "bin.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/prebuild-install/node_modules/readable-stream": { + "version": "3.6.2", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz", + "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==", + "dependencies": { + "inherits": "^2.0.3", + "string_decoder": "^1.1.1", + "util-deprecate": "^1.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/prebuild-install/node_modules/tar-fs": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz", + "integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==", + "dependencies": { + "chownr": "^1.1.1", + "mkdirp-classic": "^0.5.2", + "pump": "^3.0.0", + "tar-stream": "^2.1.4" + } + }, + "node_modules/prebuild-install/node_modules/tar-stream": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz", + "integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==", + "dependencies": { + "bl": "^4.0.3", + "end-of-stream": "^1.4.1", + "fs-constants": "^1.0.0", + "inherits": "^2.0.3", + "readable-stream": "^3.1.1" + }, + "engines": { + "node": ">=6" 
+ } + }, "node_modules/prettier": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz", @@ -4352,7 +4602,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz", "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==", - "dev": true, "dependencies": { "end-of-stream": "^1.1.0", "once": "^1.3.1" @@ -4372,6 +4621,11 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/queue-tick": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz", + "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag==" + }, "node_modules/quick-format-unescaped": { "version": "4.0.4", "resolved": "https://registry.npmjs.org/quick-format-unescaped/-/quick-format-unescaped-4.0.4.tgz", @@ -4407,6 +4661,28 @@ "node": ">= 0.8" } }, + "node_modules/rc": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", + "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==", + "dependencies": { + "deep-extend": "^0.6.0", + "ini": "~1.3.0", + "minimist": "^1.2.0", + "strip-json-comments": "~2.0.1" + }, + "bin": { + "rc": "cli.js" + } + }, + "node_modules/rc/node_modules/strip-json-comments": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz", + "integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/readable-stream": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.3.0.tgz", @@ -4615,9 +4891,9 @@ "dev": true }, "node_modules/semver": { - "version": "7.5.3", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.3.tgz", - "integrity": "sha512-QBlUtyVk/5EeHbi7X0fw6liDZc7BBmEaSYn01fMU1OUYbf6GPsbTtd8WmnqbI20SeycoHSeiybkE/q1Q+qlThQ==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dependencies": { "lru-cache": "^6.0.0" }, @@ -4675,6 +4951,28 @@ "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==" }, + "node_modules/sharp": { + "version": "0.32.6", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", + "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", + "hasInstallScript": true, + "dependencies": { + "color": "^4.2.3", + "detect-libc": "^2.0.2", + "node-addon-api": "^6.1.0", + "prebuild-install": "^7.1.1", + "semver": "^7.5.4", + "simple-get": "^4.0.1", + "tar-fs": "^3.0.4", + "tunnel-agent": "^0.6.0" + }, + "engines": { + "node": ">=14.15.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, "node_modules/shell-quote": { "version": "1.8.1", "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz", @@ -4712,6 +5010,57 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/simple-concat": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz", + "integrity": 
"sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] + }, + "node_modules/simple-get": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz", + "integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "decompress-response": "^6.0.0", + "once": "^1.3.1", + "simple-concat": "^1.0.0" + } + }, + "node_modules/simple-swizzle": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/simple-swizzle/-/simple-swizzle-0.2.2.tgz", + "integrity": "sha512-JA//kQgZtbuY83m+xT+tXJkmJncGMTFT+C+g2h2R9uxkYIrE2yy9sgmcLhCnw57/WSD+Eh3J97FPEDFnbXnDUg==", + "dependencies": { + "is-arrayish": "^0.3.1" + } + }, "node_modules/simple-update-notifier": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/simple-update-notifier/-/simple-update-notifier-2.0.0.tgz", @@ -4809,11 +5158,19 @@ "node": ">=10.0.0" } }, + "node_modules/streamx": { + "version": "2.15.4", + "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.4.tgz", + "integrity": "sha512-uSXKl88bibiUCQ1eMpItRljCzDENcDx18rsfDmV79r0e/ThfrAwxG4Y2FarQZ2G4/21xcOKmFFd1Hue+ZIDwHw==", + "dependencies": { + "fast-fifo": "^1.1.0", + "queue-tick": "^1.0.1" + } + }, "node_modules/string_decoder": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", - "devOptional": true, "dependencies": { "safe-buffer": "~5.2.0" } @@ -4872,6 +5229,26 @@ "node": ">=4" } }, + "node_modules/tar-fs": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz", + "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==", + "dependencies": { + "mkdirp-classic": "^0.5.2", + "pump": "^3.0.0", + "tar-stream": "^3.1.5" + } + }, + "node_modules/tar-stream": { + "version": "3.1.6", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz", + "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==", + "dependencies": { + "b4a": "^1.6.4", + "fast-fifo": "^1.2.0", + "streamx": "^2.15.0" + } + }, "node_modules/teeny-request": { "version": "8.0.3", "resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-8.0.3.tgz", @@ -5047,6 +5424,17 @@ "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" }, + "node_modules/tunnel-agent": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz", + "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==", + "dependencies": { + "safe-buffer": "^5.0.1" + }, + "engines": { + "node": "*" + } + }, "node_modules/type-is": { "version": "1.6.18", "resolved": 
"https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz", diff --git a/package.json b/package.json index e765c18..6fe95f7 100644 --- a/package.json +++ b/package.json @@ -23,6 +23,7 @@ "@smithy/signature-v4": "^2.0.10", "@smithy/types": "^2.3.4", "axios": "^1.3.5", + "check-disk-space": "^3.4.0", "cookie-parser": "^1.4.6", "copyfiles": "^2.4.1", "cors": "^2.8.5", @@ -41,6 +42,7 @@ "pino": "^8.11.0", "pino-http": "^8.3.3", "sanitize-html": "^2.11.0", + "sharp": "^0.32.6", "showdown": "^2.1.0", "tiktoken": "^1.0.10", "uuid": "^9.0.0", diff --git a/src/config.ts b/src/config.ts index 72f555b..916d78d 100644 --- a/src/config.ts +++ b/src/config.ts @@ -1,12 +1,17 @@ import dotenv from "dotenv"; import type firebase from "firebase-admin"; +import path from "path"; import pino from "pino"; import type { ModelFamily } from "./shared/models"; +import { MODEL_FAMILIES } from "./shared/models"; dotenv.config(); const startupLogger = pino({ level: "debug" }).child({ module: "startup" }); const isDev = process.env.NODE_ENV !== "production"; +export const DATA_DIR = path.join(__dirname, "..", "data"); +export const USER_ASSETS_DIR = path.join(DATA_DIR, "user-files"); + type Config = { /** The port the proxy server will listen on. */ port: number; @@ -75,8 +80,10 @@ type Config = { * `maxIpsPerUser` limit, or if only connections from new IPs are be rejected. */ maxIpsAutoBan: boolean; - /** Per-IP limit for requests per minute to OpenAI's completions endpoint. */ - modelRateLimit: number; + /** Per-IP limit for requests per minute to text and chat models. */ + textModelRateLimit: number; + /** Per-IP limit for requests per minute to image generation models. */ + imageModelRateLimit: number; /** * For OpenAI, the maximum number of context tokens (prompt + max output) a * user can request before their request is rejected. @@ -157,6 +164,8 @@ type Config = { quotaRefreshPeriod?: "hourly" | "daily" | string; /** Whether to allow users to change their own nicknames via the UI. */ allowNicknameChanges: boolean; + /** Whether to show recent DALL-E image generations on the homepage. */ + showRecentImages: boolean; /** * If true, cookies will be set without the `Secure` attribute, allowing * the admin UI to used over HTTP. @@ -180,7 +189,8 @@ export const config: Config = { maxIpsAutoBan: getEnvWithDefault("MAX_IPS_AUTO_BAN", true), firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined), firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined), - modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4), + textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4), + imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4), maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384), maxContextTokensAnthropic: getEnvWithDefault( "MAX_CONTEXT_TOKENS_ANTHROPIC", @@ -225,17 +235,19 @@ export const config: Config = { "You must be over the age of majority in your country to use this service." 
), blockRedirect: getEnvWithDefault("BLOCK_REDIRECT", "https://www.9gag.com"), - tokenQuota: { - turbo: getEnvWithDefault("TOKEN_QUOTA_TURBO", 0), - gpt4: getEnvWithDefault("TOKEN_QUOTA_GPT4", 0), - "gpt4-32k": getEnvWithDefault("TOKEN_QUOTA_GPT4_32K", 0), - "gpt4-turbo": getEnvWithDefault("TOKEN_QUOTA_GPT4_TURBO", 0), - claude: getEnvWithDefault("TOKEN_QUOTA_CLAUDE", 0), - bison: getEnvWithDefault("TOKEN_QUOTA_BISON", 0), - "aws-claude": getEnvWithDefault("TOKEN_QUOTA_AWS_CLAUDE", 0), - }, + tokenQuota: MODEL_FAMILIES.reduce( + (acc, family: ModelFamily) => { + acc[family] = getEnvWithDefault( + `TOKEN_QUOTA_${family.toUpperCase().replace(/-/g, "_")}`, + 0 + ) as number; + return acc; + }, + {} as { [key in ModelFamily]: number } + ), quotaRefreshPeriod: getEnvWithDefault("QUOTA_REFRESH_PERIOD", undefined), allowNicknameChanges: getEnvWithDefault("ALLOW_NICKNAME_CHANGES", true), + showRecentImages: getEnvWithDefault("SHOW_RECENT_IMAGES", true), useInsecureCookies: getEnvWithDefault("USE_INSECURE_COOKIES", isDev), } as const; @@ -252,6 +264,19 @@ function generateCookieSecret() { export const COOKIE_SECRET = generateCookieSecret(); export async function assertConfigIsValid() { + if (process.env.MODEL_RATE_LIMIT !== undefined) { + const limit = + parseInt(process.env.MODEL_RATE_LIMIT, 10) || config.textModelRateLimit; + + config.textModelRateLimit = limit; + config.imageModelRateLimit = Math.max(Math.floor(limit / 2), 1); + + startupLogger.warn( + { textLimit: limit, imageLimit: config.imageModelRateLimit }, + "MODEL_RATE_LIMIT is deprecated. Use TEXT_MODEL_RATE_LIMIT and IMAGE_MODEL_RATE_LIMIT instead." + ); + } + if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) { throw new Error( `Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.` @@ -332,6 +357,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [ "blockMessage", "blockRedirect", "allowNicknameChanges", + "showRecentImages", "useInsecureCookies", ]; @@ -428,5 +454,5 @@ function parseCsv(val: string): string[] { const regex = /(".*?"|[^",]+)(?=\s*,|\s*$)/g; const matches = val.match(regex) || []; - return matches.map(item => item.replace(/^"|"$/g, '').trim()); + return matches.map((item) => item.replace(/^"|"$/g, "").trim()); } diff --git a/src/info-page.ts b/src/info-page.ts index caa38e4..07602a5 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -14,6 +14,7 @@ import { getUniqueIps } from "./proxy/rate-limit"; import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue"; import { getTokenCostUsd, prettyTokens } from "./shared/stats"; import { assertNever } from "./shared/utils"; +import { getLastNImages } from "./shared/file-storage/image-history"; const INFO_PAGE_TTL = 2000; let infoPageHtml: string | undefined; @@ -94,6 +95,8 @@ function cacheInfoPageHtml(baseUrl: string) { const tokens = serviceStats.get("tokens") || 0; const tokenCost = serviceStats.get("tokenCost") || 0; + const allowDalle = config.allowedModelFamilies.includes("dall-e"); + const info = { uptime: Math.floor(process.uptime()), endpoints: { @@ -101,13 +104,16 @@ function cacheInfoPageHtml(baseUrl: string) { ...(openaiKeys ? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" } : {}), + ...(openaiKeys && allowDalle + ? { ["openai-image"]: baseUrl + "/proxy/openai-image" } + : {}), ...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}), ...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}), ...(awsKeys ? 
{ aws: baseUrl + "/proxy/aws/claude" } : {}), }, proompts, tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`, - ...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}), + ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}), openaiKeys, anthropicKeys, palmKeys, @@ -287,7 +293,6 @@ function getOpenAIInfo() { // Don't show trial/revoked keys for non-turbo families. // Generally those stats only make sense for the lowest-tier model. if (f !== "turbo") { - console.log("deleting", f); delete info[f]!.trialKeys; delete info[f]!.revokedKeys; } @@ -457,6 +462,9 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t if (customGreeting) { infoBody += `\n## Server Greeting\n${customGreeting}`; } + + infoBody += buildRecentImageSection(); + return converter.makeHtml(infoBody); } @@ -499,6 +507,43 @@ function getServerTitle() { return "OAI Reverse Proxy"; } +function buildRecentImageSection() { + if ( + !config.allowedModelFamilies.includes("dall-e") || + !config.showRecentImages + ) { + return ""; + } + + let html = `
+<h2>Recent DALL-E Generations</h2>
+`;
+  const recentImages = getLastNImages(12).reverse();
+  if (recentImages.length === 0) {
+    html += `
+<p>No images yet.</p>
+`;
+    return html;
+  }
+
+  html += `
+<div class="recent-images">
+`;
+  for (const { url, prompt } of recentImages) {
+    const thumbUrl = url.replace(/\.png$/, "_t.jpg");
+    const escapedPrompt = escapeHtml(prompt);
+    html += `
+<a href="${url}" target="_blank">
+  <img src="${thumbUrl}" alt="${escapedPrompt}" title="${escapedPrompt}" />
+</a>
+`;
+  }
+  html += `
+</div>
+`;
+
+  return html;
+}
+
+function escapeHtml(unsafe: string) {
+  return unsafe
+    .replace(/&/g, '&amp;')
+    .replace(/</g, '&lt;')
+    .replace(/>/g, '&gt;')
+    .replace(/"/g, '&quot;')
+    .replace(/'/g, '&#039;');
+}
+
 function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
   // Huggingface broke their amazon elb config and no longer sends the
   // x-forwarded-host header. This is a workaround.
diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts
index 98cc98b..ed1fffc 100644
--- a/src/proxy/anthropic.ts
+++ b/src/proxy/anthropic.ts
@@ -87,9 +87,8 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
     body = transformAnthropicResponse(body, req);
   }
 
-  // TODO: Remove once tokenization is stable
-  if (req.debug) {
-    body.proxy_tokenizer_debug_info = req.debug;
+  if (req.tokenizerInfo) {
+    body.proxy_tokenizer = req.tokenizerInfo;
   }
 
   res.status(200).json(body);
diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts
index c80aa6a..2f43762 100644
--- a/src/proxy/aws.ts
+++ b/src/proxy/aws.ts
@@ -73,9 +73,8 @@ const awsResponseHandler: ProxyResHandlerWithBody = async (
     body = transformAwsResponse(body, req);
   }
 
-  // TODO: Remove once tokenization is stable
-  if (req.debug) {
-    body.proxy_tokenizer_debug_info = req.debug;
+  if (req.tokenizerInfo) {
+    body.proxy_tokenizer = req.tokenizerInfo;
   }
 
   // AWS does not confirm the model in the response, so we have to add it
diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts
index 99f8281..24cfef9 100644
--- a/src/proxy/middleware/common.ts
+++ b/src/proxy/middleware/common.ts
@@ -9,11 +9,10 @@ import { QuotaExceededError } from "./request/apply-quota-limits";
 const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
 const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";
 const OPENAI_EMBEDDINGS_ENDPOINT = "/v1/embeddings";
+const OPENAI_IMAGE_COMPLETION_ENDPOINT = "/v1/images/generations";
 const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
 
-/** Returns true if we're making a request to a completion endpoint.
*/ -export function isCompletionRequest(req: Request) { - // 99% sure this function is not needed anymore +export function isTextGenerationRequest(req: Request) { return ( req.method === "POST" && [ @@ -24,6 +23,13 @@ export function isCompletionRequest(req: Request) { ); } +export function isImageGenerationRequest(req: Request) { + return ( + req.method === "POST" && + req.path.startsWith(OPENAI_IMAGE_COMPLETION_ENDPOINT) + ); +} + export function isEmbeddingsRequest(req: Request) { return ( req.method === "POST" && req.path.startsWith(OPENAI_EMBEDDINGS_ENDPOINT) @@ -53,8 +59,8 @@ export function writeErrorResponse( res.write(`data: [DONE]\n\n`); res.end(); } else { - if (req.debug && errorPayload.error) { - errorPayload.error.proxy_tokenizer_debug_info = req.debug; + if (req.tokenizerInfo && errorPayload.error) { + errorPayload.error.proxy_tokenizer = req.tokenizerInfo; } res.status(statusCode).json(errorPayload); } @@ -103,7 +109,7 @@ function classifyError(err: Error): { code: { enabled: false }, maxErrors: 3, transform: ({ issue, ...rest }) => { - return `At '${rest.pathComponent}', ${issue.message}`; + return `At '${rest.pathComponent}': ${issue.message}`; }, }); return { status: 400, userMessage, type: "proxy_validation_error" }; @@ -173,6 +179,8 @@ export function getCompletionFromBody(req: Request, body: Record) { return body.completion.trim(); case "google-palm": return body.candidates[0].output; + case "openai-image": + return body.data?.map((item: any) => item.url).join("\n"); default: assertNever(format); } @@ -184,6 +192,8 @@ export function getModelFromBody(req: Request, body: Record) { case "openai": case "openai-text": return body.model; + case "openai-image": + return req.body.model; case "anthropic": // Anthropic confirms the model in the response, but AWS Claude doesn't. return body.model || req.body.model; diff --git a/src/proxy/middleware/request/add-anthropic-preamble.ts b/src/proxy/middleware/request/add-anthropic-preamble.ts index 35f3602..cdab4f2 100644 --- a/src/proxy/middleware/request/add-anthropic-preamble.ts +++ b/src/proxy/middleware/request/add-anthropic-preamble.ts @@ -1,5 +1,5 @@ import { AnthropicKey, Key } from "../../../shared/key-management"; -import { isCompletionRequest } from "../common"; +import { isTextGenerationRequest } from "../common"; import { ProxyRequestMiddleware } from "."; /** @@ -11,7 +11,7 @@ export const addAnthropicPreamble: ProxyRequestMiddleware = ( _proxyReq, req ) => { - if (!isCompletionRequest(req) || req.key?.service !== "anthropic") { + if (!isTextGenerationRequest(req) || req.key?.service !== "anthropic") { return; } diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/add-key.ts index cb7f9e3..bbf38b7 100644 --- a/src/proxy/middleware/request/add-key.ts +++ b/src/proxy/middleware/request/add-key.ts @@ -1,5 +1,5 @@ import { Key, OpenAIKey, keyPool } from "../../../shared/key-management"; -import { isCompletionRequest, isEmbeddingsRequest } from "../common"; +import { isEmbeddingsRequest } from "../common"; import { ProxyRequestMiddleware } from "."; import { assertNever } from "../../../shared/utils"; @@ -7,18 +7,6 @@ import { assertNever } from "../../../shared/utils"; export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { let assignedKey: Key; - if (!isCompletionRequest(req)) { - // Horrible, horrible hack to stop the proxy from complaining about clients - // not sending a model when they are requesting the list of models (which - // requires a key, but obviously not a model). 
- - // I don't think this is needed anymore since models requests are no longer - // proxied to the upstream API. Everything going through this is either a - // completion request or a special case like OpenAI embeddings. - req.log.warn({ path: req.path }, "addKey called on non-completion request"); - req.body.model = "gpt-3.5-turbo"; - } - if (!req.inboundApi || !req.outboundApi) { const err = new Error( "Request API format missing. Did you forget to add the request preprocessor to your router?" @@ -54,6 +42,9 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { throw new Error( "OpenAI Chat as an API translation target is not supported" ); + case "openai-image": + assignedKey = keyPool.get("dall-e-3"); + break; default: assertNever(req.outboundApi); } diff --git a/src/proxy/middleware/request/apply-quota-limits.ts b/src/proxy/middleware/request/apply-quota-limits.ts index 581de23..e7a637b 100644 --- a/src/proxy/middleware/request/apply-quota-limits.ts +++ b/src/proxy/middleware/request/apply-quota-limits.ts @@ -1,5 +1,5 @@ import { hasAvailableQuota } from "../../../shared/users/user-store"; -import { isCompletionRequest } from "../common"; +import { isImageGenerationRequest, isTextGenerationRequest } from "../common"; import { ProxyRequestMiddleware } from "."; export class QuotaExceededError extends Error { @@ -12,12 +12,19 @@ export class QuotaExceededError extends Error { } export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => { - if (!isCompletionRequest(req) || !req.user) { - return; - } + const subjectToQuota = + isTextGenerationRequest(req) || isImageGenerationRequest(req); + if (!subjectToQuota || !req.user) return; const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0); - if (!hasAvailableQuota(req.user.token, req.body.model, requestedTokens)) { + if ( + !hasAvailableQuota({ + userToken: req.user.token, + model: req.body.model, + api: req.outboundApi, + requested: requestedTokens, + }) + ) { throw new QuotaExceededError( "You have exceeded your proxy token quota for this model.", { diff --git a/src/proxy/middleware/request/block-zoomer-origins.ts b/src/proxy/middleware/request/block-zoomer-origins.ts index 93f4c87..9efa404 100644 --- a/src/proxy/middleware/request/block-zoomer-origins.ts +++ b/src/proxy/middleware/request/block-zoomer-origins.ts @@ -1,4 +1,3 @@ -import { isCompletionRequest } from "../common"; import { ProxyRequestMiddleware } from "."; const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(","); @@ -15,10 +14,6 @@ class ForbiddenError extends Error { * stop getting emails asking for tech support. */ export const blockZoomerOrigins: ProxyRequestMiddleware = (_proxyReq, req) => { - if (!isCompletionRequest(req)) { - return; - } - const origin = req.headers.origin || req.headers.referer; if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) { // Venus-derivatives send a test prompt to check if the proxy is working. 
diff --git a/src/proxy/middleware/request/count-prompt-tokens.ts b/src/proxy/middleware/request/count-prompt-tokens.ts index 5cb0dd4..19f1cca 100644 --- a/src/proxy/middleware/request/count-prompt-tokens.ts +++ b/src/proxy/middleware/request/count-prompt-tokens.ts @@ -35,14 +35,18 @@ export const countPromptTokens: RequestPreprocessor = async (req) => { result = await countTokens({ req, prompt, service }); break; } + case "openai-image": { + req.outputTokens = 1; + result = await countTokens({ req, service }); + break; + } default: assertNever(service); } req.promptTokens = result.token_count; - // TODO: Remove once token counting is stable req.log.debug({ result: result }, "Counted prompt tokens."); - req.debug = req.debug ?? {}; - req.debug = { ...req.debug, ...result }; -}; \ No newline at end of file + req.tokenizerInfo = req.tokenizerInfo ?? {}; + req.tokenizerInfo = { ...req.tokenizerInfo, ...result }; +}; diff --git a/src/proxy/middleware/request/finalize-body.ts b/src/proxy/middleware/request/finalize-body.ts index bc62bf5..ac90e96 100644 --- a/src/proxy/middleware/request/finalize-body.ts +++ b/src/proxy/middleware/request/finalize-body.ts @@ -4,6 +4,11 @@ import type { ProxyRequestMiddleware } from "."; /** Finalize the rewritten request body. Must be the last rewriter. */ export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => { if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) { + // For image generation requests, remove stream flag. + if (req.outboundApi === "openai-image") { + delete req.body.stream; + } + const updatedBody = JSON.stringify(req.body); proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody)); (req as any).rawBody = Buffer.from(updatedBody); diff --git a/src/proxy/middleware/request/language-filter.ts b/src/proxy/middleware/request/language-filter.ts index 4bb5d25..32802bc 100644 --- a/src/proxy/middleware/request/language-filter.ts +++ b/src/proxy/middleware/request/language-filter.ts @@ -58,6 +58,7 @@ function getPromptFromRequest(req: Request) { ) .join("\n\n"); case "openai-text": + case "openai-image": return body.prompt; case "google-palm": return body.prompt.text; diff --git a/src/proxy/middleware/request/limit-completions.ts b/src/proxy/middleware/request/limit-completions.ts index c61fee3..44f583b 100644 --- a/src/proxy/middleware/request/limit-completions.ts +++ b/src/proxy/middleware/request/limit-completions.ts @@ -1,12 +1,12 @@ -import { isCompletionRequest } from "../common"; +import { isTextGenerationRequest } from "../common"; import { ProxyRequestMiddleware } from "."; /** - * Don't allow multiple completions to be requested to prevent abuse. + * Don't allow multiple text completions to be requested to prevent abuse. * OpenAI-only, Anthropic provides no such parameter. **/ export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => { - if (isCompletionRequest(req) && req.outboundApi === "openai") { + if (isTextGenerationRequest(req) && req.outboundApi === "openai") { const originalN = req.body?.n || 1; req.body.n = 1; if (originalN !== req.body.n) { diff --git a/src/proxy/middleware/request/rewrite.ts b/src/proxy/middleware/request/rewrite.ts index 3b62ff2..8cc078d 100644 --- a/src/proxy/middleware/request/rewrite.ts +++ b/src/proxy/middleware/request/rewrite.ts @@ -17,6 +17,7 @@ export const createOnProxyReqHandler = ({ // The streaming flag must be set before any other middleware runs, because // it may influence which other middleware a particular API pipeline wants // to run. 
+ // Image generation requests can't be streamed. req.isStreaming = req.body.stream === true || req.body.stream === "true"; req.body.stream = req.isStreaming; diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/transform-outbound-payload.ts index 5025158..f4de2a7 100644 --- a/src/proxy/middleware/request/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/transform-outbound-payload.ts @@ -2,7 +2,7 @@ import { Request } from "express"; import { z } from "zod"; import { config } from "../../../config"; import { OpenAIPromptMessage } from "../../../shared/tokenization"; -import { isCompletionRequest } from "../common"; +import { isTextGenerationRequest, isImageGenerationRequest } from "../common"; import { RequestPreprocessor } from "."; import { APIFormat } from "../../../shared/key-management"; @@ -88,6 +88,21 @@ const OpenAIV1TextCompletionSchema = z }) .merge(OpenAIV1ChatCompletionSchema.omit({ messages: true })); +// https://platform.openai.com/docs/api-reference/images/create +const OpenAIV1ImagesGenerationSchema = z.object({ + prompt: z.string().max(4000), + model: z.string().optional(), + quality: z.enum(["standard", "hd"]).optional().default("standard"), + n: z.number().int().min(1).max(4).optional().default(1), + response_format: z.enum(["url", "b64_json"]).optional(), + size: z + .enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]) + .optional() + .default("1024x1024"), + style: z.enum(["vivid", "natural"]).optional().default("vivid"), + user: z.string().optional(), +}); + // https://developers.generativeai.google/api/rest/generativelanguage/models/generateText const PalmV1GenerateTextSchema = z.object({ model: z.string(), @@ -110,6 +125,7 @@ const VALIDATORS: Record> = { anthropic: AnthropicV1CompleteSchema, openai: OpenAIV1ChatCompletionSchema, "openai-text": OpenAIV1TextCompletionSchema, + "openai-image": OpenAIV1ImagesGenerationSchema, "google-palm": PalmV1GenerateTextSchema, }; @@ -117,11 +133,10 @@ const VALIDATORS: Record> = { export const transformOutboundPayload: RequestPreprocessor = async (req) => { const sameService = req.inboundApi === req.outboundApi; const alreadyTransformed = req.retryCount > 0; - const notTransformable = !isCompletionRequest(req); + const notTransformable = + !isTextGenerationRequest(req) && !isImageGenerationRequest(req); - if (alreadyTransformed || notTransformable) { - return; - } + if (alreadyTransformed || notTransformable) return; if (sameService) { const result = VALIDATORS[req.inboundApi].safeParse(req.body); @@ -151,6 +166,11 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => { return; } + if (req.inboundApi === "openai" && req.outboundApi === "openai-image") { + req.body = openaiToOpenaiImage(req); + return; + } + throw new Error( `'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.` ); @@ -226,6 +246,49 @@ function openaiToOpenaiText(req: Request) { return OpenAIV1TextCompletionSchema.parse(transformed); } +// Takes the last chat message and uses it verbatim as the image prompt. 
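+// If the message contains an "Image:" marker, only the text after the marker
+// is used (e.g. { role: "user", content: "Draw this. Image: a red fox" }
+// becomes the DALL-E prompt "a red fox").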
+function openaiToOpenaiImage(req: Request) { + const { body } = req; + const result = OpenAIV1ChatCompletionSchema.safeParse(body); + if (!result.success) { + req.log.warn( + { issues: result.error.issues, body }, + "Invalid OpenAI-to-OpenAI-image request" + ); + throw result.error; + } + + const { messages } = result.data; + const prompt = messages.filter((m) => m.role === "user").pop()?.content; + + if (body.stream) { + throw new Error( + "Streaming is not supported for image generation requests." + ); + } + + // Some frontends do weird things with the prompt, like prefixing it with a + // character name or wrapping the entire thing in quotes. We will look for + // the index of "Image:" and use everything after that as the prompt. + + const index = prompt?.toLowerCase().indexOf("image:"); + if (index === -1 || !prompt) { + throw new Error( + `Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).` + ); + } + + // TODO: Add some way to specify parameters via chat message + const transformed = { + model: body.model.includes("dall-e") ? body.model : "dall-e-3", + quality: "standard", + size: "1024x1024", + response_format: "url", + prompt: prompt.slice(index! + 6).trim(), + }; + return OpenAIV1ImagesGenerationSchema.parse(transformed); +} + function openaiToPalm(req: Request): z.infer { const { body } = req; const result = OpenAIV1ChatCompletionSchema.safeParse({ diff --git a/src/proxy/middleware/request/validate-context-size.ts b/src/proxy/middleware/request/validate-context-size.ts index 8883ccb..ee661c9 100644 --- a/src/proxy/middleware/request/validate-context-size.ts +++ b/src/proxy/middleware/request/validate-context-size.ts @@ -34,6 +34,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => { case "google-palm": proxyMax = BISON_MAX_CONTEXT; break; + case "openai-image": + return; default: assertNever(req.outboundApi); } @@ -81,10 +83,10 @@ export const validateContextSize: RequestPreprocessor = async (req) => { "Prompt size validated" ); - req.debug.prompt_tokens = promptTokens; - req.debug.completion_tokens = outputTokens; - req.debug.max_model_tokens = modelMax; - req.debug.max_proxy_tokens = proxyMax; + req.tokenizerInfo.prompt_tokens = promptTokens; + req.tokenizerInfo.completion_tokens = outputTokens; + req.tokenizerInfo.max_model_tokens = modelMax; + req.tokenizerInfo.max_proxy_tokens = proxyMax; }; function assertRequestHasTokenCounts( diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index b398975..c68d347 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -13,13 +13,16 @@ import { incrementTokenCount, } from "../../../shared/users/user-store"; import { assertNever } from "../../../shared/utils"; +import { refundLastAttempt } from "../../rate-limit"; import { getCompletionFromBody, - isCompletionRequest, + isImageGenerationRequest, + isTextGenerationRequest, writeErrorResponse, } from "../common"; import { handleStreamedResponse } from "./handle-streamed-response"; import { logPrompt } from "./log-prompt"; +import { saveImage } from "./save-image"; const DECODER_MAP = { gzip: util.promisify(zlib.gunzip), @@ -106,6 +109,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => { countResponseTokens, incrementUsage, copyHttpHeaders, + saveImage, logPrompt, ...apiMiddleware ); @@ -285,7 +289,16 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( switch (service) { case 
"openai": case "google-palm": - errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`; + if (errorPayload.error?.code === "content_policy_violation") { + errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`; + refundLastAttempt(req); + } else if (errorPayload.error?.code === "billing_hard_limit_reached") { + // For some reason, some models return this 400 error instead of the + // same 429 billing error that other models return. + handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload); + } else { + errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`; + } break; case "anthropic": case "aws": @@ -453,6 +466,7 @@ function handleOpenAIRateLimitError( const type = errorPayload.error?.type; switch (type) { case "insufficient_quota": + case "invalid_request_error": // this is the billing_hard_limit_reached error seen in some cases // Billing quota exceeded (key is dead, disable it) keyPool.disable(req.key!, "quota"); errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`; @@ -487,13 +501,22 @@ function handleOpenAIRateLimitError( } const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => { - if (isCompletionRequest(req)) { + if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) { const model = req.body.model; const tokensUsed = req.promptTokens! + req.outputTokens!; + req.log.debug( + { + model, + tokensUsed, + promptTokens: req.promptTokens, + outputTokens: req.outputTokens, + }, + `Incrementing usage for model` + ); keyPool.incrementUsage(req.key!, model, tokensUsed); if (req.user) { incrementPromptCount(req.user.token); - incrementTokenCount(req.user.token, model, tokensUsed); + incrementTokenCount(req.user.token, model, req.outboundApi, tokensUsed); } } }; @@ -504,6 +527,12 @@ const countResponseTokens: ProxyResHandlerWithBody = async ( _res, body ) => { + if (req.outboundApi === "openai-image") { + req.outputTokens = req.promptTokens; + req.promptTokens = 0; + return; + } + // This function is prone to breaking if the upstream API makes even minor // changes to the response format, especially for SSE responses. 
If you're
   // seeing errors in this function, check the reassembled response body from
@@ -518,8 +547,8 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
     { service, tokens, prevOutputTokens: req.outputTokens },
     `Counted tokens for completion`
   );
-  if (req.debug) {
-    req.debug.completion_tokens = tokens;
+  if (req.tokenizerInfo) {
+    req.tokenizerInfo.completion_tokens = tokens;
   }
 
   req.outputTokens = tokens.token_count;
diff --git a/src/proxy/middleware/response/log-prompt.ts
index 0d634d3..7f3dc65 100644
--- a/src/proxy/middleware/response/log-prompt.ts
+++ b/src/proxy/middleware/response/log-prompt.ts
@@ -4,7 +4,8 @@ import { logQueue } from "../../../shared/prompt-logging";
 import {
   getCompletionFromBody,
   getModelFromBody,
-  isCompletionRequest,
+  isImageGenerationRequest,
+  isTextGenerationRequest,
 } from "../common";
 import { ProxyResHandlerWithBody } from ".";
 import { assertNever } from "../../../shared/utils";
@@ -23,11 +24,11 @@ export const logPrompt: ProxyResHandlerWithBody = async (
     throw new Error("Expected body to be an object");
   }
 
-  if (!isCompletionRequest(req)) {
-    return;
-  }
+  const loggable =
+    isTextGenerationRequest(req) || isImageGenerationRequest(req);
+  if (!loggable) return;
 
-  const promptPayload = getPromptForRequest(req);
+  const promptPayload = getPromptForRequest(req, responseBody);
   const promptFlattened = flattenMessages(promptPayload);
   const response = getCompletionFromBody(req, responseBody);
   const model = getModelFromBody(req, responseBody);
@@ -46,7 +47,18 @@ type OaiMessage = {
   content: string;
 };
 
-const getPromptForRequest = (req: Request): string | OaiMessage[] => {
+type OaiImageResult = {
+  prompt: string;
+  size: string;
+  style: string;
+  quality: string;
+  revisedPrompt?: string;
+};
+
+const getPromptForRequest = (
+  req: Request,
+  responseBody: Record<string, any>
+): string | OaiMessage[] | OaiImageResult => {
   // Since the prompt logger only runs after the request has been proxied, we
   // can assume the body has already been transformed to the target API's
   // format.
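For illustration only (this example is not part of the patch): with the `openai-image` case added in the hunk below, `getPromptForRequest` returns an `OaiImageResult` for image requests, and `flattenMessages` reduces it to just the trimmed prompt string for the log entry. The field values here are hypothetical, and the "vivid" style is an assumed DALL-E 3 default.

// Hypothetical logged payload for a DALL-E request:
const example: OaiImageResult = {
  prompt: "a cat sitting on a windowsill",
  size: "1024x1024",
  quality: "standard",
  style: "vivid", // assumed default, not set by the chat-to-image transform
  revisedPrompt: "A ginger cat lounging on a sunlit windowsill",
};
// flattenMessages(example) === "a cat sitting on a windowsill"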
@@ -55,6 +67,14 @@ const getPromptForRequest = (req: Request): string | OaiMessage[] => { return req.body.messages; case "openai-text": return req.body.prompt; + case "openai-image": + return { + prompt: req.body.prompt, + size: req.body.size, + style: req.body.style, + quality: req.body.quality, + revisedPrompt: responseBody.data[0].revised_prompt, + }; case "anthropic": return req.body.prompt; case "google-palm": @@ -64,9 +84,14 @@ const getPromptForRequest = (req: Request): string | OaiMessage[] => { } }; -const flattenMessages = (messages: string | OaiMessage[]): string => { - if (typeof messages === "string") { - return messages.trim(); +const flattenMessages = ( + val: string | OaiMessage[] | OaiImageResult +): string => { + if (typeof val === "string") { + return val.trim(); } - return messages.map((m) => `${m.role}: ${m.content}`).join("\n"); + if (Array.isArray(val)) { + return val.map((m) => `${m.role}: ${m.content}`).join("\n"); + } + return val.prompt.trim(); }; diff --git a/src/proxy/middleware/response/save-image.ts b/src/proxy/middleware/response/save-image.ts new file mode 100644 index 0000000..937e9fc --- /dev/null +++ b/src/proxy/middleware/response/save-image.ts @@ -0,0 +1,27 @@ +import { ProxyResHandlerWithBody } from "./index"; +import { mirrorGeneratedImage, OpenAIImageGenerationResult } from "../../../shared/file-storage/mirror-generated-image"; + +export const saveImage: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + _res, + body, +) => { + if (req.outboundApi !== "openai-image") { + return; + } + + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + if (body.data) { + const baseUrl = req.protocol + "://" + req.get("host"); + const prompt = body.data[0].revised_prompt ?? req.body.prompt; + await mirrorGeneratedImage( + baseUrl, + prompt, + body as OpenAIImageGenerationResult + ); + } +}; diff --git a/src/proxy/middleware/response/streaming/event-aggregator.ts b/src/proxy/middleware/response/streaming/event-aggregator.ts index 55f0fb3..8db3da2 100644 --- a/src/proxy/middleware/response/streaming/event-aggregator.ts +++ b/src/proxy/middleware/response/streaming/event-aggregator.ts @@ -33,9 +33,10 @@ export class EventAggregator { case "anthropic": return mergeEventsForAnthropic(this.events); case "google-palm": - throw new Error("Google PaLM API does not support streaming responses"); + case "openai-image": + throw new Error(`SSE aggregation not supported for ${this.format}`); default: assertNever(this.format); } } -} \ No newline at end of file +} diff --git a/src/proxy/middleware/response/streaming/sse-message-transformer.ts b/src/proxy/middleware/response/streaming/sse-message-transformer.ts index 5bd0b8e..6da55b9 100644 --- a/src/proxy/middleware/response/streaming/sse-message-transformer.ts +++ b/src/proxy/middleware/response/streaming/sse-message-transformer.ts @@ -99,7 +99,8 @@ function getTransformer( ? 
anthropicV1ToOpenAI
       : anthropicV2ToOpenAI;
     case "google-palm":
-      throw new Error("Google PaLM does not support streaming responses");
+    case "openai-image":
+      throw new Error(`SSE transformation not supported for ${responseApi}`);
     default:
       assertNever(responseApi);
   }
diff --git a/src/proxy/openai-image.ts
new file mode 100644
index 0000000..2c5a63c
--- /dev/null
+++ b/src/proxy/openai-image.ts
@@ -0,0 +1,153 @@
+import { RequestHandler, Router, Request } from "express";
+import { createProxyMiddleware } from "http-proxy-middleware";
+import { config } from "../config";
+import { logger } from "../logger";
+import { createQueueMiddleware } from "./queue";
+import { ipLimiter } from "./rate-limit";
+import { handleProxyError } from "./middleware/common";
+import {
+  addKey,
+  applyQuotaLimits,
+  blockZoomerOrigins,
+  createPreprocessorMiddleware,
+  finalizeBody,
+  stripHeaders,
+  createOnProxyReqHandler,
+} from "./middleware/request";
+import {
+  createOnProxyResHandler,
+  ProxyResHandlerWithBody,
+} from "./middleware/response";
+import { generateModelList } from "./openai";
+import {
+  mirrorGeneratedImage,
+  OpenAIImageGenerationResult,
+} from "../shared/file-storage/mirror-generated-image";
+
+const KNOWN_MODELS = ["dall-e-2", "dall-e-3"];
+
+let modelListCache: any = null;
+let modelListValid = 0;
+const handleModelRequest: RequestHandler = (_req, res) => {
+  if (new Date().getTime() - modelListValid < 1000 * 60) return res.status(200).json(modelListCache);
+  const result = generateModelList(KNOWN_MODELS);
+  modelListCache = { object: "list", data: result };
+  modelListValid = new Date().getTime();
+  res.status(200).json(modelListCache);
+};
+
+const openaiImagesResponseHandler: ProxyResHandlerWithBody = async (
+  _proxyRes,
+  req,
+  res,
+  body
+) => {
+  if (typeof body !== "object") {
+    throw new Error("Expected body to be an object");
+  }
+
+  if (config.promptLogging) {
+    const host = req.get("host");
+    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
+  }
+
+  if (req.inboundApi === "openai") {
+    req.log.info("Transforming OpenAI image response to OpenAI chat format");
+    body = transformResponseForChat(body as OpenAIImageGenerationResult, req);
+  }
+
+  if (req.tokenizerInfo) {
+    body.proxy_tokenizer = req.tokenizerInfo;
+  }
+
+  res.status(200).json(body);
+};
+
+/**
+ * Transforms a DALL-E image generation response into a chat response, simply
+ * embedding the image URL into the chat message as a Markdown image.
+ */
+function transformResponseForChat(
+  imageBody: OpenAIImageGenerationResult,
+  req: Request
+): Record<string, any> {
+  const prompt = imageBody.data[0].revised_prompt ??
req.body.prompt; + const content = imageBody.data + .map((item) => { + const { url, b64_json } = item; + if (b64_json) { + return `![${prompt}](data:image/png;base64,${b64_json})`; + } else { + return `![${prompt}](${url})`; + } + }) + .join("\n\n"); + + return { + id: "dalle-" + req.id, + object: "chat.completion", + created: Date.now(), + model: req.body.model, + usage: { + prompt_tokens: 0, + completion_tokens: req.outputTokens, + total_tokens: req.outputTokens, + }, + choices: [ + { + message: { role: "assistant", content }, + finish_reason: "stop", + index: 0, + }, + ], + }; +} + +const openaiImagesProxy = createQueueMiddleware({ + proxyMiddleware: createProxyMiddleware({ + target: "https://api.openai.com", + changeOrigin: true, + selfHandleResponse: true, + logger, + pathRewrite: { + "^/v1/chat/completions": "/v1/images/generations", + }, + on: { + proxyReq: createOnProxyReqHandler({ + pipeline: [ + applyQuotaLimits, + addKey, + blockZoomerOrigins, + stripHeaders, + finalizeBody, + ], + }), + proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]), + error: handleProxyError, + }, + }), +}); + +const openaiImagesRouter = Router(); +openaiImagesRouter.get("/v1/models", handleModelRequest); +openaiImagesRouter.post( + "/v1/images/generations", + ipLimiter, + createPreprocessorMiddleware({ + inApi: "openai-image", + outApi: "openai-image", + service: "openai", + }), + openaiImagesProxy +); +openaiImagesRouter.post( + "/v1/chat/completions", + ipLimiter, + createPreprocessorMiddleware({ + inApi: "openai", + outApi: "openai-image", + service: "openai", + }), + openaiImagesProxy +); +export const openaiImage = openaiImagesRouter; diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 68f78ce..e779d40 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -2,61 +2,50 @@ import { RequestHandler, Router } from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; import { keyPool } from "../shared/key-management"; -import { - ModelFamily, - OpenAIModelFamily, - getOpenAIModelFamily, -} from "../shared/models"; +import { getOpenAIModelFamily, ModelFamily, OpenAIModelFamily } from "../shared/models"; import { logger } from "../logger"; import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; import { handleProxyError } from "./middleware/common"; import { - RequestPreprocessor, addKey, addKeyForEmbeddingsRequest, applyQuotaLimits, blockZoomerOrigins, createEmbeddingsPreprocessorMiddleware, + createOnProxyReqHandler, createPreprocessorMiddleware, finalizeBody, forceModel, limitCompletions, + RequestPreprocessor, stripHeaders, - createOnProxyReqHandler, } from "./middleware/request"; -import { - createOnProxyResHandler, - ProxyResHandlerWithBody, -} from "./middleware/response"; +import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response"; + +// https://platform.openai.com/docs/models/overview +const KNOWN_MODELS = [ + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-0613", + "gpt-4-0314", // EOL 2024-06-13 + "gpt-4-32k", + "gpt-4-32k-0613", + "gpt-4-32k-0314", // EOL 2024-06-13 + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", // EOL 2024-06-13 + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-16k-0613", + "gpt-3.5-turbo-instruct", + "gpt-3.5-turbo-instruct-0914", + "text-embedding-ada-002", +]; let modelsCache: any = null; let modelsCacheTime = 0; -function getModelsResponse() { - if (new Date().getTime() - modelsCacheTime < 1000 * 60) { - return 
modelsCache; - } - - // https://platform.openai.com/docs/models/overview - const knownModels = [ - "gpt-4-1106-preview", - "gpt-4", - "gpt-4-0613", - "gpt-4-0314", // EOL 2024-06-13 - "gpt-4-32k", - "gpt-4-32k-0613", - "gpt-4-32k-0314", // EOL 2024-06-13 - "gpt-3.5-turbo", - "gpt-3.5-turbo-0301", // EOL 2024-06-13 - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k-0613", - "gpt-3.5-turbo-instruct", - "gpt-3.5-turbo-instruct-0914", - "text-embedding-ada-002", - ]; - +export function generateModelList(models = KNOWN_MODELS) { let available = new Set(); for (const key of keyPool.list()) { if (key.isDisabled || key.service !== "openai") continue; @@ -67,7 +56,7 @@ function getModelsResponse() { const allowed = new Set(config.allowedModelFamilies); available = new Set([...available].filter((x) => allowed.has(x))); - const models = knownModels + return models .map((id) => ({ id, object: "model", @@ -87,15 +76,14 @@ function getModelsResponse() { parent: null, })) .filter((model) => available.has(getOpenAIModelFamily(model.id))); - - modelsCache = { object: "list", data: models }; - modelsCacheTime = new Date().getTime(); - - return modelsCache; } const handleModelRequest: RequestHandler = (_req, res) => { - res.status(200).json(getModelsResponse()); + if (new Date().getTime() - modelsCacheTime < 1000 * 60) return modelsCache; + const result = generateModelList(); + modelsCache = { object: "list", data: result }; + modelsCacheTime = new Date().getTime(); + res.status(200).json(modelsCache); }; /** Handles some turbo-instruct special cases. */ @@ -137,9 +125,8 @@ const openaiResponseHandler: ProxyResHandlerWithBody = async ( body = transformTurboInstructResponse(body); } - // TODO: Remove once tokenization is stable - if (req.debug) { - body.proxy_tokenizer_debug_info = req.debug; + if (req.tokenizerInfo) { + body.proxy_tokenizer = req.tokenizerInfo; } res.status(200).json(body); diff --git a/src/proxy/palm.ts b/src/proxy/palm.ts index 0c9d67b..0137fd3 100644 --- a/src/proxy/palm.ts +++ b/src/proxy/palm.ts @@ -75,9 +75,8 @@ const palmResponseHandler: ProxyResHandlerWithBody = async ( body = transformPalmResponse(body, req); } - // TODO: Remove once tokenization is stable - if (req.debug) { - body.proxy_tokenizer_debug_info = req.debug; + if (req.tokenizerInfo) { + body.proxy_tokenizer = req.tokenizerInfo; } // TODO: PaLM has no streaming capability which will pose a problem here if diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index f0aefab..0c92a72 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -12,11 +12,11 @@ */ import type { Handler, Request } from "express"; -import { keyPool, SupportedModel } from "../shared/key-management"; +import { keyPool } from "../shared/key-management"; import { getClaudeModelFamily, getGooglePalmModelFamily, - getOpenAIModelFamily, + getOpenAIModelFamily, MODEL_FAMILIES, ModelFamily, } from "../shared/models"; import { buildFakeSse, initializeSseStream } from "../shared/streaming"; @@ -132,7 +132,7 @@ function getPartitionForRequest(req: Request): ModelFamily { // There is a single request queue, but it is partitioned by model family. // Model families are typically separated on cost/rate limit boundaries so // they should be treated as separate queues. - const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo"; + const model = req.body.model ?? "gpt-3.5-turbo"; // Weird special case for AWS because they serve multiple models from // different vendors, even if currently only one is supported. 
@@ -145,6 +145,7 @@ function getPartitionForRequest(req: Request): ModelFamily { return getClaudeModelFamily(model); case "openai": case "openai-text": + case "openai-image": return getOpenAIModelFamily(model); case "google-palm": return getGooglePalmModelFamily(model); @@ -207,40 +208,15 @@ export function dequeue(partition: ModelFamily): Request | undefined { function processQueue() { // This isn't completely correct, because a key can service multiple models. // Currently if a key is locked out on one model it will also stop servicing - // the others, because we only track one rate limit per key. - - // TODO: `getLockoutPeriod` uses model names instead of model families - // TODO: genericize this it's really ugly - const gpt4TurboLockout = keyPool.getLockoutPeriod("gpt-4-1106"); - const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k"); - const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4"); - const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo"); - const claudeLockout = keyPool.getLockoutPeriod("claude-v1"); - const palmLockout = keyPool.getLockoutPeriod("text-bison-001"); - const awsClaudeLockout = keyPool.getLockoutPeriod("anthropic.claude-v2"); + // the others, because we only track rate limits for the key as a whole. const reqs: (Request | undefined)[] = []; - if (gpt4TurboLockout === 0) { - reqs.push(dequeue("gpt4-turbo")); - } - if (gpt432kLockout === 0) { - reqs.push(dequeue("gpt4-32k")); - } - if (gpt4Lockout === 0) { - reqs.push(dequeue("gpt4")); - } - if (turboLockout === 0) { - reqs.push(dequeue("turbo")); - } - if (claudeLockout === 0) { - reqs.push(dequeue("claude")); - } - if (palmLockout === 0) { - reqs.push(dequeue("bison")); - } - if (awsClaudeLockout === 0) { - reqs.push(dequeue("aws-claude")); - } + MODEL_FAMILIES.forEach((modelFamily) => { + const lockout = keyPool.getLockoutPeriod(modelFamily); + if (lockout === 0) { + reqs.push(dequeue(modelFamily)); + } + }); reqs.filter(Boolean).forEach((req) => { if (req?.proceed) { diff --git a/src/proxy/rate-limit.ts b/src/proxy/rate-limit.ts index 5ee0251..c8acf8f 100644 --- a/src/proxy/rate-limit.ts +++ b/src/proxy/rate-limit.ts @@ -9,8 +9,6 @@ export const SHARED_IP_ADDRESSES = new Set([ "209.97.162.44", ]); -const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit); -const RATE_LIMIT = Math.max(1, config.modelRateLimit); const ONE_MINUTE_MS = 60 * 1000; type Timestamp = number; @@ -22,12 +20,15 @@ const exemptedRequests: Timestamp[] = []; const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) => attempt > now - ONE_MINUTE_MS; -const getTryAgainInMs = (ip: string) => { +const getTryAgainInMs = (ip: string, type: "text" | "image") => { const now = Date.now(); const attempts = lastAttempts.get(ip) || []; const validAttempts = attempts.filter(isRecentAttempt(now)); - if (validAttempts.length >= RATE_LIMIT) { + const limit = + type === "text" ? config.textModelRateLimit : config.imageModelRateLimit; + + if (validAttempts.length >= limit) { return validAttempts[0] - now + ONE_MINUTE_MS; } else { lastAttempts.set(ip, [...validAttempts, now]); @@ -35,12 +36,16 @@ const getTryAgainInMs = (ip: string) => { } }; -const getStatus = (ip: string) => { +const getStatus = (ip: string, type: "text" | "image") => { const now = Date.now(); const attempts = lastAttempts.get(ip) || []; const validAttempts = attempts.filter(isRecentAttempt(now)); + + const limit = + type === "text" ? 
config.textModelRateLimit : config.imageModelRateLimit;
+
   return {
-    remaining: Math.max(0, RATE_LIMIT - validAttempts.length),
+    remaining: Math.max(0, limit - validAttempts.length),
     reset: validAttempts.length > 0 ? validAttempts[0] + ONE_MINUTE_MS : now,
   };
 };
@@ -69,12 +74,26 @@ setInterval(clearOldExemptions, 10 * 1000);
 
 export const getUniqueIps = () => lastAttempts.size;
 
+/**
+ * Can be used to manually remove the most recent attempt from an IP address,
+ * i.e. in case a prompt triggered OpenAI's content filter and therefore did
+ * not result in a generation.
+ */
+export const refundLastAttempt = (req: Request) => {
+  const key = req.user?.token || req.risuToken || req.ip;
+  const attempts = lastAttempts.get(key) || [];
+  attempts.pop();
+};
+
 export const ipLimiter = async (
   req: Request,
   res: Response,
   next: NextFunction
 ) => {
-  if (!RATE_LIMIT_ENABLED) return next();
+  const imageLimit = config.imageModelRateLimit;
+  const textLimit = config.textModelRateLimit;
+
+  if (!textLimit && !imageLimit) return next();
+
   if (req.user?.type === "special") return next();
 
   // Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
@@ -90,24 +109,25 @@ export const ipLimiter = async (
     return next();
   }
 
+  const type = (req.baseUrl + req.path).includes("image") ? "image" : "text";
+  const limit = type === "image" ? imageLimit : textLimit;
+
   // If user is authenticated, key rate limiting by their token. Otherwise, key
   // rate limiting by their IP address. Mitigates key sharing.
   const rateLimitKey = req.user?.token || req.risuToken || req.ip;
-  const { remaining, reset } = getStatus(rateLimitKey);
-  res.set("X-RateLimit-Limit", config.modelRateLimit.toString());
+  const { remaining, reset } = getStatus(rateLimitKey, type);
+  res.set("X-RateLimit-Limit", limit.toString());
   res.set("X-RateLimit-Remaining", remaining.toString());
   res.set("X-RateLimit-Reset", reset.toString());
 
-  const tryAgainInMs = getTryAgainInMs(rateLimitKey);
+  const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
   if (tryAgainInMs > 0) {
     res.set("Retry-After", tryAgainInMs.toString());
     res.status(429).json({
       error: {
         type: "proxy_rate_limited",
-        message: `This proxy is rate limited to ${
-          config.modelRateLimit
-        } prompts per minute. Please try again in ${Math.ceil(
+        message: `This model type is rate limited to ${limit} prompts per minute.
Please try again in ${Math.ceil( tryAgainInMs / 1000 )} seconds.`, }, diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 424f104..dc76048 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -2,6 +2,7 @@ import express, { Request, Response, NextFunction } from "express"; import { gatekeeper } from "./gatekeeper"; import { checkRisuToken } from "./check-risu-token"; import { openai } from "./openai"; +import { openaiImage } from "./openai-image"; import { anthropic } from "./anthropic"; import { googlePalm } from "./palm"; import { aws } from "./aws"; @@ -27,6 +28,7 @@ proxyRouter.use((req, _res, next) => { next(); }); proxyRouter.use("/openai", addV1, openai); +proxyRouter.use("/openai-image", addV1, openaiImage); proxyRouter.use("/anthropic", addV1, anthropic); proxyRouter.use("/google-palm", addV1, googlePalm); proxyRouter.use("/aws/claude", addV1, aws); diff --git a/src/server.ts b/src/server.ts index 55aa95f..2b9e708 100644 --- a/src/server.ts +++ b/src/server.ts @@ -1,11 +1,14 @@ -import { assertConfigIsValid, config } from "./config"; +import { assertConfigIsValid, config, USER_ASSETS_DIR } from "./config"; import "source-map-support/register"; +import checkDiskSpace from "check-disk-space"; import express from "express"; import cors from "cors"; import path from "path"; import pinoHttp from "pino-http"; +import os from "os"; import childProcess from "child_process"; import { logger } from "./logger"; +import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir"; import { keyPool } from "./shared/key-management"; import { adminRouter } from "./admin/routes"; import { proxyRouter } from "./proxy/routes"; @@ -58,6 +61,8 @@ app.set("views", [ path.join(__dirname, "shared/views"), ]); +app.use("/user_content", express.static(USER_ASSETS_DIR)); + app.get("/health", (_req, res) => res.sendStatus(200)); app.use(cors()); app.use(checkOrigin); @@ -99,13 +104,17 @@ async function start() { await initTokenizers(); + if (config.allowedModelFamilies.includes("dall-e")) { + await setupAssetsDir(); + } + if (config.gatekeeper === "user_token") { await initUserStore(); } if (config.promptLogging) { logger.info("Starting prompt logging..."); - logQueue.start(); + await logQueue.start(); } logger.info("Starting request queue..."); @@ -116,8 +125,12 @@ async function start() { registerUncaughtExceptionHandler(); }); + const diskSpace = await checkDiskSpace( + __dirname.startsWith("/app") ? "/app" : os.homedir() + ); + logger.info( - { build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV }, + { build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV, diskSpace }, "Startup complete." 
); } diff --git a/src/shared/file-storage/image-history.ts b/src/shared/file-storage/image-history.ts new file mode 100644 index 0000000..6e8c87a --- /dev/null +++ b/src/shared/file-storage/image-history.ts @@ -0,0 +1,35 @@ + +type ImageHistory = { + url: string; + prompt: string; +} + +const IMAGE_HISTORY_SIZE = 30; +const imageHistory = new Array(IMAGE_HISTORY_SIZE); +let imageHistoryIndex = 0; + +export function getImageHistory() { + return imageHistory.filter((url) => url); +} + +export function addToImageHistory(image: ImageHistory) { + imageHistory[imageHistoryIndex] = image; + imageHistoryIndex = (imageHistoryIndex + 1) % IMAGE_HISTORY_SIZE; +} + +export function getLastNImages(n: number) { + const result: ImageHistory[] = []; + let currentIndex = (imageHistoryIndex - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE; + + for (let i = 0; i < n; i++) { + // Check if the current index is valid (not undefined). + if (imageHistory[currentIndex]) { + result.unshift(imageHistory[currentIndex]); + } + + // Move to the previous item, wrapping around if necessary. + currentIndex = (currentIndex - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE; + } + + return result; +} diff --git a/src/shared/file-storage/mirror-generated-image.ts b/src/shared/file-storage/mirror-generated-image.ts new file mode 100644 index 0000000..74fd93a --- /dev/null +++ b/src/shared/file-storage/mirror-generated-image.ts @@ -0,0 +1,75 @@ +import axios from "axios"; +import { promises as fs } from "fs"; +import path from "path"; +import { v4 } from "uuid"; +import { USER_ASSETS_DIR } from "../../config"; +import { logger } from "../../logger"; +import { addToImageHistory } from "./image-history"; +import sharp from "sharp"; + +const log = logger.child({ module: "file-storage" }); + +export type OpenAIImageGenerationResult = { + created: number; + data: { + revised_prompt?: string; + url: string; + b64_json: string; + }[]; +}; + +async function downloadImage(url: string) { + const { data } = await axios.get(url, { responseType: "arraybuffer" }); + const buffer = Buffer.from(data, "binary"); + const newFilename = `${v4()}.png`; + + const filepath = path.join(USER_ASSETS_DIR, newFilename); + await fs.writeFile(filepath, buffer); + return filepath; +} + +async function saveB64Image(b64: string) { + const buffer = Buffer.from(b64, "base64"); + const newFilename = `${v4()}.png`; + + const filepath = path.join(USER_ASSETS_DIR, newFilename); + await fs.writeFile(filepath, buffer); + return filepath; +} + +async function createThumbnail(filepath: string) { + const thumbnailPath = filepath.replace(/(\.[\wd_-]+)$/i, "_t.jpg"); + + await sharp(filepath) + .resize(150, 150, { + fit: "inside", + withoutEnlargement: true, + }) + .toFormat("jpeg") + .toFile(thumbnailPath); + + return thumbnailPath; +} + +/** + * Downloads generated images and mirrors them to the user_content directory. + * Mutates the result object. 
+ */
+export async function mirrorGeneratedImage(
+  host: string,
+  prompt: string,
+  result: OpenAIImageGenerationResult
+): Promise<OpenAIImageGenerationResult> {
+  for (const item of result.data) {
+    let mirror: string;
+    if (item.b64_json) {
+      mirror = await saveB64Image(item.b64_json);
+    } else {
+      mirror = await downloadImage(item.url);
+    }
+    item.url = `${host}/user_content/${path.basename(mirror)}`;
+    await createThumbnail(mirror);
+    addToImageHistory({ url: item.url, prompt });
+  }
+  return result;
+}
diff --git a/src/shared/file-storage/setup-assets-dir.ts
new file mode 100644
index 0000000..2958810
--- /dev/null
+++ b/src/shared/file-storage/setup-assets-dir.ts
@@ -0,0 +1,20 @@
+import { promises as fs } from "fs";
+import { logger } from "../../logger";
+import { USER_ASSETS_DIR } from "../../config";
+
+const log = logger.child({ module: "file-storage" });
+
+export async function setupAssetsDir() {
+  try {
+    log.info({ dir: USER_ASSETS_DIR }, "Setting up user assets directory");
+    await fs.mkdir(USER_ASSETS_DIR, { recursive: true });
+    const stats = await fs.stat(USER_ASSETS_DIR);
+    const mode = stats.mode | 0o666;
+    if (stats.mode !== mode) {
+      await fs.chmod(USER_ASSETS_DIR, mode);
+    }
+  } catch (e) {
+    log.error(e);
+    throw new Error("Could not create user assets directory for DALL-E image generation. You may need to update your Dockerfile to `chown` the working directory to user 1000. See the proxy docs for more information.");
+  }
+}
diff --git a/src/shared/key-management/anthropic/provider.ts
index a50131f..7cebe03 100644
--- a/src/shared/key-management/anthropic/provider.ts
+++ b/src/shared/key-management/anthropic/provider.ts
@@ -6,14 +6,12 @@ import type { AnthropicModelFamily } from "../../models";
 import { AnthropicKeyChecker } from "./checker";
 
 // https://docs.anthropic.com/claude/reference/selecting-a-model
-export const ANTHROPIC_SUPPORTED_MODELS = [
-  "claude-instant-v1",
-  "claude-instant-v1-100k",
-  "claude-v1",
-  "claude-v1-100k",
-  "claude-2",
-] as const;
-export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
+export type AnthropicModel =
+  | "claude-instant-v1"
+  | "claude-instant-v1-100k"
+  | "claude-v1"
+  | "claude-v1-100k"
+  | "claude-2";
 
 export type AnthropicKeyUpdate = Omit<
   Partial<AnthropicKey>,
@@ -180,7 +178,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
     key.claudeTokens += tokens;
   }
 
-  public getLockoutPeriod(_model: AnthropicModel) {
+  public getLockoutPeriod() {
     const activeKeys = this.keys.filter((k) => !k.isDisabled);
     // Don't lock out if there are no keys available or the queue will stall.
     // Just let it through so the add-key middleware can throw an error.
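Illustrative sketch, not part of the patch: the provider changes in this area drop the per-model lockout argument, and `KeyPool.getLockoutPeriod` (further below) now takes a model family, which is how the rewritten queue loop polls each partition.

// Hypothetical caller, mirroring the new queue loop:
const lockoutMs = keyPool.getLockoutPeriod("dall-e");
if (lockoutMs === 0) {
  // A key can serve this family right now; safe to dequeue a request.
}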
diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index 31db723..96ceed5 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -6,12 +6,10 @@ import type { AwsBedrockModelFamily } from "../../models"; import { AwsKeyChecker } from "./checker"; // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html -export const AWS_BEDROCK_SUPPORTED_MODELS = [ - "anthropic.claude-v1", - "anthropic.claude-v2", - "anthropic.claude-instant-v1", -] as const; -export type AwsBedrockModel = (typeof AWS_BEDROCK_SUPPORTED_MODELS)[number]; +export type AwsBedrockModel = + | "anthropic.claude-v1" + | "anthropic.claude-v2" + | "anthropic.claude-instant-v1"; type AwsBedrockKeyUsage = { [K in AwsBedrockModelFamily as `${K}Tokens`]: number; @@ -158,7 +156,7 @@ export class AwsBedrockKeyProvider implements KeyProvider { key["aws-claudeTokens"] += tokens; } - public getLockoutPeriod(_model: AwsBedrockModel) { + public getLockoutPeriod() { // TODO: same exact behavior for three providers, should be refactored const activeKeys = this.keys.filter((k) => !k.isDisabled); // Don't lock out if there are no keys available or the queue will stall. diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 3fe4597..c647a14 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -1,15 +1,17 @@ -import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider"; -import { - ANTHROPIC_SUPPORTED_MODELS, - AnthropicModel, -} from "./anthropic/provider"; -import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider"; -import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider"; +import { OpenAIModel } from "./openai/provider"; +import { AnthropicModel } from "./anthropic/provider"; +import { GooglePalmModel } from "./palm/provider"; +import { AwsBedrockModel } from "./aws/provider"; import { KeyPool } from "./key-pool"; import type { ModelFamily } from "../models"; /** The request and response format used by a model's API. */ -export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text"; +export type APIFormat = + | "openai" + | "anthropic" + | "google-palm" + | "openai-text" + | "openai-image"; /** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. 
*/
 export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
 
 export type Model =
@@ -60,23 +62,12 @@ export interface KeyProvider<T extends Key = Key> {
   update(hash: string, update: Partial<T>): void;
   available(): number;
   incrementUsage(hash: string, model: string, tokens: number): void;
-  getLockoutPeriod(model: Model): number;
+  getLockoutPeriod(model: ModelFamily): number;
   markRateLimited(hash: string): void;
   recheck(): void;
 }
 
 export const keyPool = new KeyPool();
-export const SUPPORTED_MODELS = [
-  ...OPENAI_SUPPORTED_MODELS,
-  ...ANTHROPIC_SUPPORTED_MODELS,
-] as const;
-export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
-export {
-  OPENAI_SUPPORTED_MODELS,
-  ANTHROPIC_SUPPORTED_MODELS,
-  GOOGLE_PALM_SUPPORTED_MODELS,
-  AWS_BEDROCK_SUPPORTED_MODELS,
-};
 export { AnthropicKey } from "./anthropic/provider";
 export { OpenAIKey } from "./openai/provider";
 export { GooglePalmKey } from "./palm/provider";
diff --git a/src/shared/key-management/key-pool.ts
index 07a7b4c..95e7f26 100644
--- a/src/shared/key-management/key-pool.ts
+++ b/src/shared/key-management/key-pool.ts
@@ -9,6 +9,8 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
 import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
 import { GooglePalmKeyProvider } from "./palm/provider";
 import { AwsBedrockKeyProvider } from "./aws/provider";
+import { ModelFamily } from "../models";
+import { assertNever } from "../utils";
 
 type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
 
@@ -37,7 +39,7 @@ export class KeyPool {
   }
 
   public get(model: Model): Key {
-    const service = this.getService(model);
+    const service = this.getServiceForModel(model);
     return this.getKeyProvider(service).get(model);
   }
 
@@ -67,7 +69,7 @@ public available(model: Model | "all" = "all"): number {
     return this.keyProviders.reduce((sum, provider) => {
       const includeProvider =
-        model === "all" || this.getService(model) === provider.service;
+        model === "all" || this.getServiceForModel(model) === provider.service;
       return sum + (includeProvider ?
provider.available() : 0); }, 0); } @@ -77,9 +79,9 @@ export class KeyPool { provider.incrementUsage(key.hash, model, tokens); } - public getLockoutPeriod(model: Model): number { - const service = this.getService(model); - return this.getKeyProvider(service).getLockoutPeriod(model); + public getLockoutPeriod(family: ModelFamily): number { + const service = this.getServiceForModelFamily(family); + return this.getKeyProvider(service).getLockoutPeriod(family); } public markRateLimited(key: Key): void { @@ -104,8 +106,12 @@ export class KeyPool { provider.recheck(); } - private getService(model: Model): LLMService { - if (model.startsWith("gpt") || model.startsWith("text-embedding-ada")) { + private getServiceForModel(model: Model): LLMService { + if ( + model.startsWith("gpt") || + model.startsWith("text-embedding-ada") || + model.startsWith("dall-e") + ) { // https://platform.openai.com/docs/models/model-endpoint-compatibility return "openai"; } else if (model.startsWith("claude-")) { @@ -122,6 +128,25 @@ export class KeyPool { throw new Error(`Unknown service for model '${model}'`); } + private getServiceForModelFamily(modelFamily: ModelFamily): LLMService { + switch (modelFamily) { + case "gpt4": + case "gpt4-32k": + case "gpt4-turbo": + case "turbo": + case "dall-e": + return "openai"; + case "claude": + return "anthropic"; + case "bison": + return "google-palm"; + case "aws-claude": + return "aws"; + default: + assertNever(modelFamily); + } + } + private getKeyProvider(service: LLMService): KeyProvider { return this.keyProviders.find((provider) => provider.service === service)!; } diff --git a/src/shared/key-management/openai/checker.ts b/src/shared/key-management/openai/checker.ts index e7f28b8..88d7b59 100644 --- a/src/shared/key-management/openai/checker.ts +++ b/src/shared/key-management/openai/checker.ts @@ -95,10 +95,15 @@ export class OpenAIKeyChecker extends KeyCheckerBase { const { data } = await axios.get(GET_MODELS_URL, opts); const models = data.data; - // const families: OpenAIModelFamily[] = []; const families = new Set(); models.forEach(({ id }) => families.add(getOpenAIModelFamily(id, "turbo"))); + // For now we remove dall-e from the list of provisioned models if only + // dall-e-2 is available. + if (families.has("dall-e") && !models.find(({ id }) => id === "dall-e-3")) { + families.delete("dall-e"); + } + // We want to update the key's model families here, but we don't want to // update its `lastChecked` timestamp because we need to let the liveness // check run before we can consider the key checked. diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 4e527dd..d98bd39 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -15,12 +15,9 @@ export type OpenAIModel = | "gpt-4" | "gpt-4-32k" | "gpt-4-1106" - | "text-embedding-ada-002"; -export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [ - "gpt-3.5-turbo", - "gpt-3.5-turbo-instruct", - "gpt-4", -] as const; + | "text-embedding-ada-002" + | "dall-e-2" + | "dall-e-3" // Flattening model families instead of using a nested object for easier // cloning. 
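Worked example, not part of the patch: DALL-E billing is converted to pseudo-tokens at DALLE_TOKENS_PER_DOLLAR (100000 tokens per US$1.00, defined in src/shared/tokenization/openai.ts later in this patch), which is what the new "dall-eTokens" counter in the next hunk accumulates.

// dall-e-3, standard quality, 1024x1024 costs US$0.04 per image:
// 0.04 * 100000 = 4000 tokens per image, so for n = 2 images:
// getOpenAIImageCost({ model: "dall-e-3", quality: "standard",
//   resolution: "1024x1024", n: 2 }) -> { token_count: 8000, ... }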
@@ -127,6 +124,7 @@ export class OpenAIKeyProvider implements KeyProvider { gpt4Tokens: 0, "gpt4-32kTokens": 0, "gpt4-turboTokens": 0, + "dall-eTokens": 0, gpt4Rpm: 0, }; this.keys.push(newKey); @@ -284,10 +282,9 @@ export class OpenAIKeyProvider implements KeyProvider { * Given a model, returns the period until a key will be available to service * the request, or returns 0 if a key is ready immediately. */ - public getLockoutPeriod(model: Model = "gpt-4"): number { - const neededFamily = getOpenAIModelFamily(model); + public getLockoutPeriod(family: OpenAIModelFamily): number { const activeKeys = this.keys.filter( - (key) => !key.isDisabled && key.modelFamilies.includes(neededFamily) + (key) => !key.isDisabled && key.modelFamilies.includes(family) ); if (activeKeys.length === 0) { @@ -335,6 +332,10 @@ export class OpenAIKeyProvider implements KeyProvider { this.log.debug({ key: keyHash }, "Key rate limited"); const key = this.keys.find((k) => k.hash === keyHash)!; key.rateLimitedAt = Date.now(); + // DALL-E requests do not send headers telling us when the rate limit will + // be reset so we need to set a fallback value here. Other models will have + // this overwritten by the `updateRateLimits` method. + key.rateLimitRequestsReset = 5000; } public incrementUsage(keyHash: string, model: string, tokens: number) { diff --git a/src/shared/key-management/palm/provider.ts b/src/shared/key-management/palm/provider.ts index 5b3e3d7..dccfa08 100644 --- a/src/shared/key-management/palm/provider.ts +++ b/src/shared/key-management/palm/provider.ts @@ -5,11 +5,7 @@ import { logger } from "../../../logger"; import type { GooglePalmModelFamily } from "../../models"; // https://developers.generativeai.google.com/models/language -export const GOOGLE_PALM_SUPPORTED_MODELS = [ - "text-bison-001", - // "chat-bison-001", no adjustable safety settings, so it's useless -] as const; -export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number]; +export type GooglePalmModel = "text-bison-001"; export type GooglePalmKeyUpdate = Omit< Partial, @@ -149,7 +145,7 @@ export class GooglePalmKeyProvider implements KeyProvider { key.bisonTokens += tokens; } - public getLockoutPeriod(_model: GooglePalmModel) { + public getLockoutPeriod() { const activeKeys = this.keys.filter((k) => !k.isDisabled); // Don't lock out if there are no keys available or the queue will stall. // Just let it through so the add-key middleware can throw an error. 
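Illustrative sketch, not part of the patch: the new `dall-e` family added to src/shared/models.ts below is resolved through OPENAI_MODEL_FAMILY_MAP, so both DALL-E versions share one quota bucket.

// "^dall-e-\\d{1}$" matches both current versions:
getOpenAIModelFamily("dall-e-2"); // -> "dall-e"
getOpenAIModelFamily("dall-e-3"); // -> "dall-e"
getOpenAIModelFamily("mystery-model"); // -> falls back to defaultFamily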
diff --git a/src/shared/models.ts b/src/shared/models.ts index f89a266..861c151 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -1,6 +1,8 @@ -import { logger } from "../logger"; +// Don't import anything here, this is imported by config.ts -export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo"; +import pino from "pino"; + +export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e"; export type AnthropicModelFamily = "claude"; export type GooglePalmModelFamily = "bison"; export type AwsBedrockModelFamily = "aws-claude"; @@ -17,6 +19,7 @@ export const MODEL_FAMILIES = (( "gpt4", "gpt4-32k", "gpt4-turbo", + "dall-e", "claude", "bison", "aws-claude", @@ -30,8 +33,11 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = { "^gpt-4$": "gpt4", "^gpt-3.5-turbo": "turbo", "^text-embedding-ada-002$": "turbo", + "^dall-e-\\d{1}$": "dall-e", }; +const modelLogger = pino({ level: "debug" }).child({ module: "startup" }); + export function getOpenAIModelFamily( model: string, defaultFamily: OpenAIModelFamily = "gpt4" @@ -42,14 +48,14 @@ export function getOpenAIModelFamily( return defaultFamily; } -export function getClaudeModelFamily(_model: string): ModelFamily { +export function getClaudeModelFamily(model: string): ModelFamily { + if (model.startsWith("anthropic.")) return getAwsBedrockModelFamily(model); return "claude"; } export function getGooglePalmModelFamily(model: string): ModelFamily { if (model.match(/^\w+-bison-\d{3}$/)) return "bison"; - const stack = new Error().stack; - logger.warn({ model, stack }, "Unmapped PaLM model family"); + modelLogger.warn({ model }, "Could not determine Google PaLM model family"); return "bison"; } diff --git a/src/shared/prompt-logging/backends/sheets.ts b/src/shared/prompt-logging/backends/sheets.ts index e8c60e1..6b2450f 100644 --- a/src/shared/prompt-logging/backends/sheets.ts +++ b/src/shared/prompt-logging/backends/sheets.ts @@ -396,7 +396,7 @@ export const init = async (onStop: () => void) => { await loadIndexSheet(false); await writeIndexSheet(); } catch (e) { - log.info("Creating new index sheet."); + log.warn(e, "Could not load index sheet. 
Creating a new one."); await createIndexSheet(); } }; diff --git a/src/shared/stats.ts b/src/shared/stats.ts index aeb816e..7729a3c 100644 --- a/src/shared/stats.ts +++ b/src/shared/stats.ts @@ -17,6 +17,9 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) { case "turbo": cost = 0.000001; break; + case "dall-e": + cost = 0.00001; + break; case "aws-claude": case "claude": cost = 0.00001102; diff --git a/src/shared/streaming.ts b/src/shared/streaming.ts index 6611521..f3943b2 100644 --- a/src/shared/streaming.ts +++ b/src/shared/streaming.ts @@ -79,7 +79,8 @@ export function buildFakeSse( }; break; case "google-palm": - throw new Error("PaLM not supported as an inbound API format"); + case "openai-image": + throw new Error(`SSE not supported for ${req.inboundApi} requests`); default: assertNever(req.inboundApi); } @@ -92,4 +93,4 @@ export function buildFakeSse( } return `data: ${JSON.stringify(fakeEvent)}\n\n`; -} \ No newline at end of file +} diff --git a/src/shared/tokenization/openai.ts b/src/shared/tokenization/openai.ts index b71585d..a984299 100644 --- a/src/shared/tokenization/openai.ts +++ b/src/shared/tokenization/openai.ts @@ -78,3 +78,63 @@ export type OpenAIPromptMessage = { content: string; role: string; }; + +// Model Resolution Price +// DALL·E 3 1024×1024 $0.040 / image +// 1024×1792, 1792×1024 $0.080 / image +// DALL·E 3 HD 1024×1024 $0.080 / image +// 1024×1792, 1792×1024 $0.120 / image +// DALL·E 2 1024×1024 $0.020 / image +// 512×512 $0.018 / image +// 256×256 $0.016 / image + +export const DALLE_TOKENS_PER_DOLLAR = 100000; + +/** + * OpenAI image generation with DALL-E doesn't use tokens but everything else + * in the application does. There is a fixed cost for each image generation + * request depending on the model and selected quality/resolution parameters, + * which we convert to tokens at a rate of 100000 tokens per dollar. + */ +export function getOpenAIImageCost(params: { + model: "dall-e-2" | "dall-e-3"; + quality: "standard" | "hd"; + resolution: "512x512" | "256x256" | "1024x1024" | "1024x1792" | "1792x1024"; + n: number | null; +}) { + const { model, quality, resolution, n } = params; + const usd = (() => { + switch (model) { + case "dall-e-2": + switch (resolution) { + case "512x512": + return 0.018; + case "256x256": + return 0.016; + case "1024x1024": + return 0.02; + default: + throw new Error("Invalid resolution"); + } + case "dall-e-3": + switch (resolution) { + case "1024x1024": + return quality === "standard" ? 0.04 : 0.08; + case "1024x1792": + case "1792x1024": + return quality === "standard" ? 0.08 : 0.12; + default: + throw new Error("Invalid resolution"); + } + default: + throw new Error("Invalid image generation model"); + } + })(); + + const tokens = (n ?? 
1) * (usd * DALLE_TOKENS_PER_DOLLAR); + + return { + tokenizer: `openai-image cost`, + token_count: Math.ceil(tokens), + }; +} diff --git a/src/shared/tokenization/tokenizer.ts b/src/shared/tokenization/tokenizer.ts index 6b5491c..9d30b4e 100644 --- a/src/shared/tokenization/tokenizer.ts +++ b/src/shared/tokenization/tokenizer.ts @@ -8,6 +8,7 @@ import { init as initOpenAi, getTokenCount as getOpenAITokenCount, OpenAIPromptMessage, + getOpenAIImageCost, } from "./openai"; import { APIFormat } from "../key-management"; @@ -26,6 +27,7 @@ type TokenCountRequest = { req: Request } & ( service: "openai-text" | "anthropic" | "google-palm"; } | { prompt?: never; completion: string; service: APIFormat } + | { prompt?: never; completion?: never; service: "openai-image" } ); type TokenCountResult = { @@ -53,6 +55,16 @@ export async function countTokens({ ...getOpenAITokenCount(prompt ?? completion, req.body.model), tokenization_duration_ms: getElapsedMs(time), }; + case "openai-image": + return { + ...getOpenAIImageCost({ + model: req.body.model, + quality: req.body.quality, + resolution: req.body.size, + n: parseInt(req.body.n, 10) || null, + }), + tokenization_duration_ms: getElapsedMs(time), + }; case "google-palm": // TODO: Can't find a tokenization library for PaLM. There is an API // endpoint for it but it adds significant latency to the request. diff --git a/src/shared/users/schema.ts b/src/shared/users/schema.ts index eb9df44..492b4b8 100644 --- a/src/shared/users/schema.ts +++ b/src/shared/users/schema.ts @@ -7,6 +7,7 @@ export const tokenCountsSchema: ZodType = z.object({ gpt4: z.number().optional().default(0), "gpt4-32k": z.number().optional().default(0), "gpt4-turbo": z.number().optional().default(0), + "dall-e": z.number().optional().default(0), claude: z.number().optional().default(0), bison: z.number().optional().default(0), "aws-claude": z.number().optional().default(0), diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts index 556c513..8fbdb58 100644 --- a/src/shared/users/user-store.ts +++ b/src/shared/users/user-store.ts @@ -11,9 +11,17 @@ import admin from "firebase-admin"; import schedule from "node-schedule"; import { v4 as uuid } from "uuid"; import { config, getFirebaseApp } from "../../config"; -import { MODEL_FAMILIES, ModelFamily } from "../models"; +import { + getClaudeModelFamily, + getGooglePalmModelFamily, + getOpenAIModelFamily, + MODEL_FAMILIES, + ModelFamily, +} from "../models"; import { logger } from "../../logger"; import { User, UserTokenCounts, UserUpdate } from "./schema"; +import { APIFormat } from "../key-management"; +import { assertNever } from "../utils"; const log = logger.child({ module: "users" }); @@ -22,6 +30,7 @@ const INITIAL_TOKENS: Required = { gpt4: 0, "gpt4-32k": 0, "gpt4-turbo": 0, + "dall-e": 0, claude: 0, bison: 0, "aws-claude": 0, @@ -166,11 +175,12 @@ export function incrementPromptCount(token: string) { export function incrementTokenCount( token: string, model: string, + api: APIFormat, consumption: number ) { const user = users.get(token); if (!user) return; - const modelFamily = getModelFamilyForQuotaUsage(model); + const modelFamily = getModelFamilyForQuotaUsage(model, api); const existing = user.tokenCounts[modelFamily] ?? 0; user.tokenCounts[modelFamily] = existing + consumption; usersToFlush.add(token); @@ -181,9 +191,10 @@ export function incrementTokenCount( * to the user's list of IPs. Returns the user if they exist and are not * disabled, otherwise returns undefined. 
*/
-export function authenticate(token: string, ip: string):
-  { user?: User; result: "success" | "disabled" | "not_found" | "limited" }
-  {
+export function authenticate(
+  token: string,
+  ip: string
+): { user?: User; result: "success" | "disabled" | "not_found" | "limited" } {
   const user = users.get(token);
   if (!user) return { result: "not_found" };
   if (user.disabledAt) return { result: "disabled" };
@@ -210,16 +221,22 @@
   return { user, result: "success" };
 }
 
-export function hasAvailableQuota(
-  token: string,
-  model: string,
-  requested: number
-) {
-  const user = users.get(token);
+export function hasAvailableQuota({
+  userToken,
+  model,
+  api,
+  requested,
+}: {
+  userToken: string;
+  model: string;
+  api: APIFormat;
+  requested: number;
+}) {
+  const user = users.get(userToken);
   if (!user) return false;
   if (user.type === "special") return true;
 
-  const modelFamily = getModelFamilyForQuotaUsage(model);
+  const modelFamily = getModelFamilyForQuotaUsage(model, api);
 
   const { tokenCounts, tokenLimits } = user;
   const tokenLimit = tokenLimits[modelFamily];
@@ -361,30 +378,22 @@ async function flushUsers() {
   );
 }
 
-// TODO: use key-management/models.ts for family mapping
-function getModelFamilyForQuotaUsage(model: string): ModelFamily {
-  if (model.startsWith("gpt-4-1106")) {
-    return "gpt4-turbo";
+function getModelFamilyForQuotaUsage(
+  model: string,
+  api: APIFormat
+): ModelFamily {
+  switch (api) {
+    case "openai":
+    case "openai-text":
+    case "openai-image":
+      return getOpenAIModelFamily(model);
+    case "anthropic":
+      return getClaudeModelFamily(model);
+    case "google-palm":
+      return getGooglePalmModelFamily(model);
+    default:
+      assertNever(api);
   }
-  if (model.includes("32k")) {
-    return "gpt4-32k";
-  }
-  if (model.startsWith("gpt-4")) {
-    return "gpt4";
-  }
-  if (model.startsWith("gpt-3.5")) {
-    return "turbo";
-  }
-  if (model.includes("bison")) {
-    return "bison";
-  }
-  if (model.startsWith("claude")) {
-    return "claude";
-  }
-  if (model.startsWith("anthropic.claude")) {
-    return "aws-claude";
-  }
-  throw new Error(`Unknown quota model family for model ${model}`);
 }
 
 function getRefreshCrontab() {
diff --git a/src/types/custom.d.ts
index 503bf76..bbafc84 100644
--- a/src/types/custom.d.ts
+++ b/src/types/custom.d.ts
@@ -24,8 +24,7 @@ declare global {
       heartbeatInterval?: NodeJS.Timeout;
       promptTokens?: number;
       outputTokens?: number;
-      // TODO: remove later
-      debug: Record<string, any>;
+      tokenizerInfo: Record<string, any>;
       signedRequest: HttpRequest;
     }
   }
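Illustrative usage sketch, not part of the patch. The path comes from src/proxy/routes.ts plus the `addV1` rewrite; the proxy hostname and user token are placeholders, and the body fields follow the OpenAI images schema the proxy validates against.

// Direct image generation through the proxy:
const res = await fetch(
  "https://your-proxy.example/proxy/openai-image/v1/images/generations",
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer <user-token>", // if gatekeeper is user_token
    },
    body: JSON.stringify({
      model: "dall-e-3",
      prompt: "a watercolor fox in a forest",
      size: "1024x1024",
      quality: "standard",
      response_format: "url",
    }),
  }
);
const result = await res.json();
// result.data[0].url points at the proxy's mirrored copy under /user_content/.
// Alternatively, POST /proxy/openai-image/v1/chat/completions with a user
// message starting with "Image:" to get the result back as a Markdown image
// in an assistant message.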