OpenAI DALL-E Image Generation (khanon/oai-reverse-proxy!52)

This commit is contained in:
khanon 2023-11-14 05:41:19 +00:00
parent 3ea23760c3
commit 20c064394a
56 changed files with 1401 additions and 305 deletions

View File

@ -11,8 +11,10 @@
# The title displayed on the info page.
# SERVER_TITLE=Coom Tunnel
# Model requests allowed per minute per user.
# MODEL_RATE_LIMIT=4
# Text model requests allowed per minute per user.
# TEXT_MODEL_RATE_LIMIT=4
# Image model requests allowed per minute per user.
# IMAGE_MODEL_RATE_LIMIT=2
# Max number of context tokens a user can request at once.
# Increase this if your proxy allows GPT-4 32k or 128k context
@ -31,10 +33,11 @@
# CHECK_KEYS=true
# Which model types users are allowed to access.
# If you want to restrict access to certain models, uncomment the line below and list only the models you want to allow,
# separated by commas. By default, all models are allowed. The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | claude | bison | aws-claude
# ALLOWED_MODEL_FAMILIES=turbo,gpt4-turbo,aws-claude
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
# generation, uncomment the line below and add 'dall-e' to the list.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude
# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com
@ -82,10 +85,18 @@
# ALLOW_NICKNAME_CHANGES=true
# Default token quotas for each model family. (0 for unlimited)
# DALL-E "tokens" are counted at a rate of 100000 tokens per US$1.00 generated,
# which is similar to the cost of GPT-4 Turbo.
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
# See `docs/dall-e-configuration.md` for more information.
# TOKEN_QUOTA_TURBO=0
# TOKEN_QUOTA_GPT4=0
# TOKEN_QUOTA_GPT4_32K=0
# TOKEN_QUOTA_GPT4_TURBO=0
# TOKEN_QUOTA_DALL_E=0
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_BISON=0
# TOKEN_QUOTA_AWS_CLAUDE=0
# How often to refresh token quotas. (hourly | daily)
# Leave unset to never automatically refresh quotas.

2
data/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitkeep

0
data/user-files/.gitkeep Normal file
View File

View File

@ -3,6 +3,8 @@ RUN apt-get update && \
apt-get install -y git
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app
RUN chown -R 1000:1000 /app
USER 1000
RUN npm install
COPY Dockerfile greeting.md* .env* ./
RUN npm run build

View File

@ -17,6 +17,8 @@ ARG GREETING_URL
RUN if [ -n "$GREETING_URL" ]; then \
curl -sL "$GREETING_URL" > greeting.md; \
fi
RUN chown -R 1000:1000 /app
USER 1000
COPY package*.json greeting.md* ./
RUN npm install
COPY . .

View File

@ -0,0 +1,71 @@
# Configuring the proxy for DALL-E
The proxy supports DALL-E 2 and DALL-E 3 image generation via the `/proxy/openai-image` endpoint. It is disabled by default, as image generation is somewhat expensive and potentially more open to abuse than text generation.
- [Updating your Dockerfile](#updating-your-dockerfile)
- [Enabling DALL-E](#enabling-dall-e)
- [Setting quotas](#setting-quotas)
- [Rate limiting](#rate-limiting)
## Updating your Dockerfile
If you are using a previous version of the Dockerfile supplied with the proxy, it doesn't grant the permissions the proxy needs to save temporary files.
You can replace the entire thing with the new Dockerfile at [./docker/huggingface/Dockerfile](../docker/huggingface/Dockerfile) (or the equivalent for Render deployments).
You can also modify your existing Dockerfile; just add the following lines after the `WORKDIR` line:
```Dockerfile
# Existing
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app
# Take ownership of the app directory and switch to the non-root user
RUN chown -R 1000:1000 /app
USER 1000
# Existing
RUN npm install
```
## Enabling DALL-E
Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL-E. For example:
```
# GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, and DALL-E
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e
# All models as of this writing
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
```
Refer to [.env.example](../.env.example) for the full list of supported model families; adding `dall-e` to the default list enables every model.
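Once enabled, clients can call the image endpoint directly. Below is a rough sketch of such a request; the host, user token, and mount path are assumptions based on this proxy's routing, so adjust them for your deployment.
```ts
// Hypothetical client-side sketch; not part of the proxy codebase.
async function generateImage(): Promise<string> {
  const res = await fetch(
    "https://your-proxy.example/proxy/openai-image/v1/images/generations",
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        // Only needed if the proxy uses the user_token gatekeeper.
        Authorization: "Bearer your-user-token",
      },
      body: JSON.stringify({
        model: "dall-e-3",
        prompt: "a watercolor painting of a fox",
        size: "1024x1024",
      }),
    }
  );
  const json = await res.json();
  return json.data[0].url; // response_format defaults to "url"
}
```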
## Setting quotas
DALL-E doesn't bill by token the way text generation models do. Instead, there is a fixed cost per generated image that depends on the model, image size, and selected quality.
The proxy still uses tokens to set quotas for users. The cost for each generated image will be converted to "tokens" at a rate of 100000 tokens per US$1.00. This works out to a similar cost-per-token as GPT-4 Turbo, so you can use similar token quotas for both.
Use `TOKEN_QUOTA_DALL_E` to set the default quota for image generation. Otherwise it works the same as token quotas for other models.
```
# ~50 standard DALL-E images per refresh period, or US$2.00
TOKEN_QUOTA_DALL_E=200000
```
Refer to [https://openai.com/pricing](https://openai.com/pricing) for the latest pricing information. As of this writing, the cheapest DALL-E 3 image costs $0.04 per generation, which works out to 4000 tokens. Higher resolution and quality settings can cost up to $0.12 per image, or 12000 tokens.
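Since quotas and image prices both reduce to the same 100000-tokens-per-dollar rate, the conversion is just a fixed multiplier. A minimal sketch (the helper name is illustrative, not part of the proxy):
```ts
// Converts a per-image dollar cost into proxy "tokens" at the rate described above.
const TOKENS_PER_USD = 100_000;

function imageCostToTokens(costUsd: number): number {
  return Math.round(costUsd * TOKENS_PER_USD);
}

imageCostToTokens(0.04); // 4000 tokens  (cheapest DALL-E 3 image)
imageCostToTokens(0.12); // 12000 tokens (highest resolution/quality)
```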
## Rate limiting
The old `MODEL_RATE_LIMIT` setting has been split into `TEXT_MODEL_RATE_LIMIT` and `IMAGE_MODEL_RATE_LIMIT`. Whatever value you previously set for `MODEL_RATE_LIMIT` will be used for text models.
If you don't specify an `IMAGE_MODEL_RATE_LIMIT`, it defaults to half of the `TEXT_MODEL_RATE_LIMIT`, with a minimum of 1 image per minute.
```
# 4 text generations per minute, 2 images per minute
TEXT_MODEL_RATE_LIMIT=4
IMAGE_MODEL_RATE_LIMIT=2
```
If a prompt is filtered by OpenAI's content filter, it won't count towards the rate limit.
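For reference, the derived default mirrors the legacy `MODEL_RATE_LIMIT` handling in `src/config.ts`; a sketch with illustrative values:
```ts
// Illustrative values; the formula matches the deprecation path in src/config.ts.
const textModelRateLimit = 4; // TEXT_MODEL_RATE_LIMIT
const imageModelRateLimit = Math.max(Math.floor(textModelRateLimit / 2), 1); // -> 2
```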
## Hiding recent images
By default, the proxy's info page shows the 12 most recently generated user images. You can hide this section by setting `SHOW_RECENT_IMAGES` to `false`.
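For example, to disable the gallery:
```
SHOW_RECENT_IMAGES=false
```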

View File

@ -25,6 +25,8 @@ RUN apt-get update && \
apt-get install -y git
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app
RUN chown -R 1000:1000 /app
USER 1000
RUN npm install
COPY Dockerfile greeting.md* .env* ./
RUN npm run build

406
package-lock.json generated
View File

@ -15,6 +15,7 @@
"@smithy/signature-v4": "^2.0.10",
"@smithy/types": "^2.3.4",
"axios": "^1.3.5",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
"copyfiles": "^2.4.1",
"cors": "^2.8.5",
@ -33,6 +34,7 @@
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"sanitize-html": "^2.11.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
@ -1373,15 +1375,20 @@
}
},
"node_modules/axios": {
"version": "1.3.5",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.3.5.tgz",
"integrity": "sha512-glL/PvG/E+xCWwV8S6nCHcrfg1exGx7vxyUIivIA1iL7BIh6bePylCfVHwp6k13ao7SATxB6imau2kqY+I67kw==",
"version": "1.6.1",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
"dependencies": {
"follow-redirects": "^1.15.0",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}
},
"node_modules/b4a": {
"version": "1.6.4",
"resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz",
"integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw=="
},
"node_modules/balanced-match": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
@ -1423,6 +1430,52 @@
"node": ">=8"
}
},
"node_modules/bl": {
"version": "4.1.0",
"resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz",
"integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==",
"dependencies": {
"buffer": "^5.5.0",
"inherits": "^2.0.4",
"readable-stream": "^3.4.0"
}
},
"node_modules/bl/node_modules/buffer": {
"version": "5.7.1",
"resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz",
"integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==",
"funding": [
{
"type": "github",
"url": "https://github.com/sponsors/feross"
},
{
"type": "patreon",
"url": "https://www.patreon.com/feross"
},
{
"type": "consulting",
"url": "https://feross.org/support"
}
],
"dependencies": {
"base64-js": "^1.3.1",
"ieee754": "^1.1.13"
}
},
"node_modules/bl/node_modules/readable-stream": {
"version": "3.6.2",
"resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz",
"integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==",
"dependencies": {
"inherits": "^2.0.3",
"string_decoder": "^1.1.1",
"util-deprecate": "^1.0.1"
},
"engines": {
"node": ">= 6"
}
},
"node_modules/bluebird": {
"version": "3.7.2",
"resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.7.2.tgz",
@ -1582,6 +1635,14 @@
"node": ">=8"
}
},
"node_modules/check-disk-space": {
"version": "3.4.0",
"resolved": "https://registry.npmjs.org/check-disk-space/-/check-disk-space-3.4.0.tgz",
"integrity": "sha512-drVkSqfwA+TvuEhFipiR1OC9boEGZL5RrWvVsOthdcvQNXyCCuKkEiTOTXZ7qxSf/GLwq4GvzfrQD/Wz325hgw==",
"engines": {
"node": ">=16"
}
},
"node_modules/chokidar": {
"version": "3.5.3",
"resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz",
@ -1609,6 +1670,11 @@
"fsevents": "~2.3.2"
}
},
"node_modules/chownr": {
"version": "1.1.4",
"resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz",
"integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg=="
},
"node_modules/cliui": {
"version": "8.0.1",
"resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz",
@ -1623,6 +1689,18 @@
"node": ">=12"
}
},
"node_modules/color": {
"version": "4.2.3",
"resolved": "https://registry.npmjs.org/color/-/color-4.2.3.tgz",
"integrity": "sha512-1rXeuUUiGGrykh+CeBdu5Ie7OJwinCgQY0bc7GCRxy5xVHy+moaqkpL/jqQq0MtQOeYcrqEz4abc5f0KtU7W4A==",
"dependencies": {
"color-convert": "^2.0.1",
"color-string": "^1.9.0"
},
"engines": {
"node": ">=12.5.0"
}
},
"node_modules/color-convert": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
@ -1639,6 +1717,15 @@
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA=="
},
"node_modules/color-string": {
"version": "1.9.1",
"resolved": "https://registry.npmjs.org/color-string/-/color-string-1.9.1.tgz",
"integrity": "sha512-shrVawQFojnZv6xM40anx4CkoDP+fZsw/ZerEMsW/pyzsRbElpsL/DBVW7q3ExxwusdNXI3lXpuhEZkzs8p5Eg==",
"dependencies": {
"color-name": "^1.0.0",
"simple-swizzle": "^0.2.2"
}
},
"node_modules/colorette": {
"version": "2.0.20",
"resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz",
@ -2000,6 +2087,28 @@
"ms": "2.0.0"
}
},
"node_modules/decompress-response": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz",
"integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==",
"dependencies": {
"mimic-response": "^3.1.0"
},
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/deep-extend": {
"version": "0.6.0",
"resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz",
"integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==",
"engines": {
"node": ">=4.0.0"
}
},
"node_modules/deep-is": {
"version": "0.1.4",
"resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz",
@ -2039,6 +2148,14 @@
"npm": "1.2.8000 || >= 1.4.16"
}
},
"node_modules/detect-libc": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz",
"integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==",
"engines": {
"node": ">=8"
}
},
"node_modules/diff": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
@ -2188,7 +2305,6 @@
"version": "1.4.4",
"resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz",
"integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==",
"devOptional": true,
"dependencies": {
"once": "^1.4.0"
}
@ -2473,6 +2589,14 @@
"node": ">=0.8.x"
}
},
"node_modules/expand-template": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz",
"integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==",
"engines": {
"node": ">=6"
}
},
"node_modules/express": {
"version": "4.18.2",
"resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz",
@ -2557,6 +2681,11 @@
"integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==",
"optional": true
},
"node_modules/fast-fifo": {
"version": "1.3.2",
"resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz",
"integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ=="
},
"node_modules/fast-levenshtein": {
"version": "2.0.6",
"resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz",
@ -2718,6 +2847,11 @@
"node": ">= 0.6"
}
},
"node_modules/fs-constants": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz",
"integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow=="
},
"node_modules/fs.realpath": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz",
@ -2795,6 +2929,11 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/github-from-package": {
"version": "0.0.0",
"resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz",
"integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw=="
},
"node_modules/glob": {
"version": "8.1.0",
"resolved": "https://registry.npmjs.org/glob/-/glob-8.1.0.tgz",
@ -3246,6 +3385,11 @@
"resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz",
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
},
"node_modules/ini": {
"version": "1.3.8",
"resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
"integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
},
"node_modules/ipaddr.js": {
"version": "1.9.1",
"resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",
@ -3254,6 +3398,11 @@
"node": ">= 0.10"
}
},
"node_modules/is-arrayish": {
"version": "0.3.2",
"resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.3.2.tgz",
"integrity": "sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ=="
},
"node_modules/is-binary-path": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz",
@ -3780,6 +3929,17 @@
"node": ">= 0.6"
}
},
"node_modules/mimic-response": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz",
"integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==",
"engines": {
"node": ">=10"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
@ -3810,6 +3970,11 @@
"node": ">=10"
}
},
"node_modules/mkdirp-classic": {
"version": "0.5.3",
"resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz",
"integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A=="
},
"node_modules/ms": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz",
@ -3860,6 +4025,11 @@
"node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1"
}
},
"node_modules/napi-build-utils": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-1.0.2.tgz",
"integrity": "sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg=="
},
"node_modules/negotiator": {
"version": "0.6.3",
"resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz",
@ -3868,6 +4038,22 @@
"node": ">= 0.6"
}
},
"node_modules/node-abi": {
"version": "3.51.0",
"resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.51.0.tgz",
"integrity": "sha512-SQkEP4hmNWjlniS5zdnfIXTk1x7Ome85RDzHlTbBtzE97Gfwz/Ipw4v/Ryk20DWIy3yCNVLVlGKApCnmvYoJbA==",
"dependencies": {
"semver": "^7.3.5"
},
"engines": {
"node": ">=10"
}
},
"node_modules/node-addon-api": {
"version": "6.1.0",
"resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz",
"integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA=="
},
"node_modules/node-fetch": {
"version": "2.6.9",
"resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.9.tgz",
@ -4217,6 +4403,70 @@
"node": "^10 || ^12 || >=14"
}
},
"node_modules/prebuild-install": {
"version": "7.1.1",
"resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz",
"integrity": "sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==",
"dependencies": {
"detect-libc": "^2.0.0",
"expand-template": "^2.0.3",
"github-from-package": "0.0.0",
"minimist": "^1.2.3",
"mkdirp-classic": "^0.5.3",
"napi-build-utils": "^1.0.1",
"node-abi": "^3.3.0",
"pump": "^3.0.0",
"rc": "^1.2.7",
"simple-get": "^4.0.0",
"tar-fs": "^2.0.0",
"tunnel-agent": "^0.6.0"
},
"bin": {
"prebuild-install": "bin.js"
},
"engines": {
"node": ">=10"
}
},
"node_modules/prebuild-install/node_modules/readable-stream": {
"version": "3.6.2",
"resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz",
"integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==",
"dependencies": {
"inherits": "^2.0.3",
"string_decoder": "^1.1.1",
"util-deprecate": "^1.0.1"
},
"engines": {
"node": ">= 6"
}
},
"node_modules/prebuild-install/node_modules/tar-fs": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz",
"integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==",
"dependencies": {
"chownr": "^1.1.1",
"mkdirp-classic": "^0.5.2",
"pump": "^3.0.0",
"tar-stream": "^2.1.4"
}
},
"node_modules/prebuild-install/node_modules/tar-stream": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz",
"integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==",
"dependencies": {
"bl": "^4.0.3",
"end-of-stream": "^1.4.1",
"fs-constants": "^1.0.0",
"inherits": "^2.0.3",
"readable-stream": "^3.1.1"
},
"engines": {
"node": ">=6"
}
},
"node_modules/prettier": {
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz",
@ -4352,7 +4602,6 @@
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz",
"integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==",
"dev": true,
"dependencies": {
"end-of-stream": "^1.1.0",
"once": "^1.3.1"
@ -4372,6 +4621,11 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/queue-tick": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz",
"integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag=="
},
"node_modules/quick-format-unescaped": {
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/quick-format-unescaped/-/quick-format-unescaped-4.0.4.tgz",
@ -4407,6 +4661,28 @@
"node": ">= 0.8"
}
},
"node_modules/rc": {
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz",
"integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==",
"dependencies": {
"deep-extend": "^0.6.0",
"ini": "~1.3.0",
"minimist": "^1.2.0",
"strip-json-comments": "~2.0.1"
},
"bin": {
"rc": "cli.js"
}
},
"node_modules/rc/node_modules/strip-json-comments": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz",
"integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==",
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/readable-stream": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.3.0.tgz",
@ -4615,9 +4891,9 @@
"dev": true
},
"node_modules/semver": {
"version": "7.5.3",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.5.3.tgz",
"integrity": "sha512-QBlUtyVk/5EeHbi7X0fw6liDZc7BBmEaSYn01fMU1OUYbf6GPsbTtd8WmnqbI20SeycoHSeiybkE/q1Q+qlThQ==",
"version": "7.5.4",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
"integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
"dependencies": {
"lru-cache": "^6.0.0"
},
@ -4675,6 +4951,28 @@
"resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz",
"integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw=="
},
"node_modules/sharp": {
"version": "0.32.6",
"resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz",
"integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==",
"hasInstallScript": true,
"dependencies": {
"color": "^4.2.3",
"detect-libc": "^2.0.2",
"node-addon-api": "^6.1.0",
"prebuild-install": "^7.1.1",
"semver": "^7.5.4",
"simple-get": "^4.0.1",
"tar-fs": "^3.0.4",
"tunnel-agent": "^0.6.0"
},
"engines": {
"node": ">=14.15.0"
},
"funding": {
"url": "https://opencollective.com/libvips"
}
},
"node_modules/shell-quote": {
"version": "1.8.1",
"resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz",
@ -4712,6 +5010,57 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/simple-concat": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz",
"integrity": "sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==",
"funding": [
{
"type": "github",
"url": "https://github.com/sponsors/feross"
},
{
"type": "patreon",
"url": "https://www.patreon.com/feross"
},
{
"type": "consulting",
"url": "https://feross.org/support"
}
]
},
"node_modules/simple-get": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz",
"integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==",
"funding": [
{
"type": "github",
"url": "https://github.com/sponsors/feross"
},
{
"type": "patreon",
"url": "https://www.patreon.com/feross"
},
{
"type": "consulting",
"url": "https://feross.org/support"
}
],
"dependencies": {
"decompress-response": "^6.0.0",
"once": "^1.3.1",
"simple-concat": "^1.0.0"
}
},
"node_modules/simple-swizzle": {
"version": "0.2.2",
"resolved": "https://registry.npmjs.org/simple-swizzle/-/simple-swizzle-0.2.2.tgz",
"integrity": "sha512-JA//kQgZtbuY83m+xT+tXJkmJncGMTFT+C+g2h2R9uxkYIrE2yy9sgmcLhCnw57/WSD+Eh3J97FPEDFnbXnDUg==",
"dependencies": {
"is-arrayish": "^0.3.1"
}
},
"node_modules/simple-update-notifier": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/simple-update-notifier/-/simple-update-notifier-2.0.0.tgz",
@ -4809,11 +5158,19 @@
"node": ">=10.0.0"
}
},
"node_modules/streamx": {
"version": "2.15.4",
"resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.4.tgz",
"integrity": "sha512-uSXKl88bibiUCQ1eMpItRljCzDENcDx18rsfDmV79r0e/ThfrAwxG4Y2FarQZ2G4/21xcOKmFFd1Hue+ZIDwHw==",
"dependencies": {
"fast-fifo": "^1.1.0",
"queue-tick": "^1.0.1"
}
},
"node_modules/string_decoder": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz",
"integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==",
"devOptional": true,
"dependencies": {
"safe-buffer": "~5.2.0"
}
@ -4872,6 +5229,26 @@
"node": ">=4"
}
},
"node_modules/tar-fs": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz",
"integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==",
"dependencies": {
"mkdirp-classic": "^0.5.2",
"pump": "^3.0.0",
"tar-stream": "^3.1.5"
}
},
"node_modules/tar-stream": {
"version": "3.1.6",
"resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz",
"integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==",
"dependencies": {
"b4a": "^1.6.4",
"fast-fifo": "^1.2.0",
"streamx": "^2.15.0"
}
},
"node_modules/teeny-request": {
"version": "8.0.3",
"resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-8.0.3.tgz",
@ -5047,6 +5424,17 @@
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
},
"node_modules/tunnel-agent": {
"version": "0.6.0",
"resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz",
"integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==",
"dependencies": {
"safe-buffer": "^5.0.1"
},
"engines": {
"node": "*"
}
},
"node_modules/type-is": {
"version": "1.6.18",
"resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz",

View File

@ -23,6 +23,7 @@
"@smithy/signature-v4": "^2.0.10",
"@smithy/types": "^2.3.4",
"axios": "^1.3.5",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
"copyfiles": "^2.4.1",
"cors": "^2.8.5",
@ -41,6 +42,7 @@
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"sanitize-html": "^2.11.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",

View File

@ -1,12 +1,17 @@
import dotenv from "dotenv";
import type firebase from "firebase-admin";
import path from "path";
import pino from "pino";
import type { ModelFamily } from "./shared/models";
import { MODEL_FAMILIES } from "./shared/models";
dotenv.config();
const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
const isDev = process.env.NODE_ENV !== "production";
export const DATA_DIR = path.join(__dirname, "..", "data");
export const USER_ASSETS_DIR = path.join(DATA_DIR, "user-files");
type Config = {
/** The port the proxy server will listen on. */
port: number;
@ -75,8 +80,10 @@ type Config = {
* `maxIpsPerUser` limit, or if only connections from new IPs are rejected.
*/
maxIpsAutoBan: boolean;
/** Per-IP limit for requests per minute to OpenAI's completions endpoint. */
modelRateLimit: number;
/** Per-IP limit for requests per minute to text and chat models. */
textModelRateLimit: number;
/** Per-IP limit for requests per minute to image generation models. */
imageModelRateLimit: number;
/**
* For OpenAI, the maximum number of context tokens (prompt + max output) a
* user can request before their request is rejected.
@ -157,6 +164,8 @@ type Config = {
quotaRefreshPeriod?: "hourly" | "daily" | string;
/** Whether to allow users to change their own nicknames via the UI. */
allowNicknameChanges: boolean;
/** Whether to show recent DALL-E image generations on the homepage. */
showRecentImages: boolean;
/**
* If true, cookies will be set without the `Secure` attribute, allowing
* the admin UI to be used over HTTP.
@ -180,7 +189,8 @@ export const config: Config = {
maxIpsAutoBan: getEnvWithDefault("MAX_IPS_AUTO_BAN", true),
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4),
imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384),
maxContextTokensAnthropic: getEnvWithDefault(
"MAX_CONTEXT_TOKENS_ANTHROPIC",
@ -225,17 +235,19 @@ export const config: Config = {
"You must be over the age of majority in your country to use this service."
),
blockRedirect: getEnvWithDefault("BLOCK_REDIRECT", "https://www.9gag.com"),
tokenQuota: {
turbo: getEnvWithDefault("TOKEN_QUOTA_TURBO", 0),
gpt4: getEnvWithDefault("TOKEN_QUOTA_GPT4", 0),
"gpt4-32k": getEnvWithDefault("TOKEN_QUOTA_GPT4_32K", 0),
"gpt4-turbo": getEnvWithDefault("TOKEN_QUOTA_GPT4_TURBO", 0),
claude: getEnvWithDefault("TOKEN_QUOTA_CLAUDE", 0),
bison: getEnvWithDefault("TOKEN_QUOTA_BISON", 0),
"aws-claude": getEnvWithDefault("TOKEN_QUOTA_AWS_CLAUDE", 0),
},
tokenQuota: MODEL_FAMILIES.reduce(
(acc, family: ModelFamily) => {
acc[family] = getEnvWithDefault(
`TOKEN_QUOTA_${family.toUpperCase().replace(/-/g, "_")}`,
0
) as number;
return acc;
},
{} as { [key in ModelFamily]: number }
),
quotaRefreshPeriod: getEnvWithDefault("QUOTA_REFRESH_PERIOD", undefined),
allowNicknameChanges: getEnvWithDefault("ALLOW_NICKNAME_CHANGES", true),
showRecentImages: getEnvWithDefault("SHOW_RECENT_IMAGES", true),
useInsecureCookies: getEnvWithDefault("USE_INSECURE_COOKIES", isDev),
} as const;
@ -252,6 +264,19 @@ function generateCookieSecret() {
export const COOKIE_SECRET = generateCookieSecret();
export async function assertConfigIsValid() {
if (process.env.MODEL_RATE_LIMIT !== undefined) {
const limit =
parseInt(process.env.MODEL_RATE_LIMIT, 10) || config.textModelRateLimit;
config.textModelRateLimit = limit;
config.imageModelRateLimit = Math.max(Math.floor(limit / 2), 1);
startupLogger.warn(
{ textLimit: limit, imageLimit: config.imageModelRateLimit },
"MODEL_RATE_LIMIT is deprecated. Use TEXT_MODEL_RATE_LIMIT and IMAGE_MODEL_RATE_LIMIT instead."
);
}
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
throw new Error(
`Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.`
@ -332,6 +357,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"blockMessage",
"blockRedirect",
"allowNicknameChanges",
"showRecentImages",
"useInsecureCookies",
];
@ -428,5 +454,5 @@ function parseCsv(val: string): string[] {
const regex = /(".*?"|[^",]+)(?=\s*,|\s*$)/g;
const matches = val.match(regex) || [];
return matches.map(item => item.replace(/^"|"$/g, '').trim());
return matches.map((item) => item.replace(/^"|"$/g, "").trim());
}

View File

@ -14,6 +14,7 @@ import { getUniqueIps } from "./proxy/rate-limit";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { getTokenCostUsd, prettyTokens } from "./shared/stats";
import { assertNever } from "./shared/utils";
import { getLastNImages } from "./shared/file-storage/image-history";
const INFO_PAGE_TTL = 2000;
let infoPageHtml: string | undefined;
@ -94,6 +95,8 @@ function cacheInfoPageHtml(baseUrl: string) {
const tokens = serviceStats.get("tokens") || 0;
const tokenCost = serviceStats.get("tokenCost") || 0;
const allowDalle = config.allowedModelFamilies.includes("dall-e");
const info = {
uptime: Math.floor(process.uptime()),
endpoints: {
@ -101,13 +104,16 @@ function cacheInfoPageHtml(baseUrl: string) {
...(openaiKeys
? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" }
: {}),
...(openaiKeys && allowDalle
? { ["openai-image"]: baseUrl + "/proxy/openai-image" }
: {}),
...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}),
...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}),
},
proompts,
tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
openaiKeys,
anthropicKeys,
palmKeys,
@ -287,7 +293,6 @@ function getOpenAIInfo() {
// Don't show trial/revoked keys for non-turbo families.
// Generally those stats only make sense for the lowest-tier model.
if (f !== "turbo") {
console.log("deleting", f);
delete info[f]!.trialKeys;
delete info[f]!.revokedKeys;
}
@ -457,6 +462,9 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t
if (customGreeting) {
infoBody += `\n## Server Greeting\n${customGreeting}`;
}
infoBody += buildRecentImageSection();
return converter.makeHtml(infoBody);
}
@ -499,6 +507,43 @@ function getServerTitle() {
return "OAI Reverse Proxy";
}
function buildRecentImageSection() {
if (
!config.allowedModelFamilies.includes("dall-e") ||
!config.showRecentImages
) {
return "";
}
let html = `<h2>Recent DALL-E Generations</h2>`;
const recentImages = getLastNImages(12).reverse();
if (recentImages.length === 0) {
html += `<p>No images yet.</p>`;
return html;
}
html += `<div style="display: flex; flex-wrap: wrap;" id="recent-images">`;
for (const { url, prompt } of recentImages) {
const thumbUrl = url.replace(/\.png$/, "_t.jpg");
const escapedPrompt = escapeHtml(prompt);
html += `<div style="margin: 0.5em;" class="recent-image">
<a href="${url}" target="_blank"><img src="${thumbUrl}" title="${escapedPrompt}" alt="${escapedPrompt}" style="max-width: 150px; max-height: 150px;" /></a>
</div>`;
}
html += `</div>`;
return html;
}
function escapeHtml(unsafe: string) {
return unsafe
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#39;');
}
function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
// Huggingface broke their amazon elb config and no longer sends the
// x-forwarded-host header. This is a workaround.

View File

@ -87,9 +87,8 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
body = transformAnthropicResponse(body, req);
}
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);

View File

@ -73,9 +73,8 @@ const awsResponseHandler: ProxyResHandlerWithBody = async (
body = transformAwsResponse(body, req);
}
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
// AWS does not confirm the model in the response, so we have to add it

View File

@ -9,11 +9,10 @@ import { QuotaExceededError } from "./request/apply-quota-limits";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";
const OPENAI_EMBEDDINGS_ENDPOINT = "/v1/embeddings";
const OPENAI_IMAGE_COMPLETION_ENDPOINT = "/v1/images/generations";
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
/** Returns true if we're making a request to a completion endpoint. */
export function isCompletionRequest(req: Request) {
// 99% sure this function is not needed anymore
export function isTextGenerationRequest(req: Request) {
return (
req.method === "POST" &&
[
@ -24,6 +23,13 @@ export function isCompletionRequest(req: Request) {
);
}
export function isImageGenerationRequest(req: Request) {
return (
req.method === "POST" &&
req.path.startsWith(OPENAI_IMAGE_COMPLETION_ENDPOINT)
);
}
export function isEmbeddingsRequest(req: Request) {
return (
req.method === "POST" && req.path.startsWith(OPENAI_EMBEDDINGS_ENDPOINT)
@ -53,8 +59,8 @@ export function writeErrorResponse(
res.write(`data: [DONE]\n\n`);
res.end();
} else {
if (req.debug && errorPayload.error) {
errorPayload.error.proxy_tokenizer_debug_info = req.debug;
if (req.tokenizerInfo && errorPayload.error) {
errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
}
res.status(statusCode).json(errorPayload);
}
@ -103,7 +109,7 @@ function classifyError(err: Error): {
code: { enabled: false },
maxErrors: 3,
transform: ({ issue, ...rest }) => {
return `At '${rest.pathComponent}', ${issue.message}`;
return `At '${rest.pathComponent}': ${issue.message}`;
},
});
return { status: 400, userMessage, type: "proxy_validation_error" };
@ -173,6 +179,8 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
return body.completion.trim();
case "google-palm":
return body.candidates[0].output;
case "openai-image":
return body.data?.map((item: any) => item.url).join("\n");
default:
assertNever(format);
}
@ -184,6 +192,8 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
case "openai":
case "openai-text":
return body.model;
case "openai-image":
return req.body.model;
case "anthropic":
// Anthropic confirms the model in the response, but AWS Claude doesn't.
return body.model || req.body.model;

View File

@ -1,5 +1,5 @@
import { AnthropicKey, Key } from "../../../shared/key-management";
import { isCompletionRequest } from "../common";
import { isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/**
@ -11,7 +11,7 @@ export const addAnthropicPreamble: ProxyRequestMiddleware = (
_proxyReq,
req
) => {
if (!isCompletionRequest(req) || req.key?.service !== "anthropic") {
if (!isTextGenerationRequest(req) || req.key?.service !== "anthropic") {
return;
}

View File

@ -1,5 +1,5 @@
import { Key, OpenAIKey, keyPool } from "../../../shared/key-management";
import { isCompletionRequest, isEmbeddingsRequest } from "../common";
import { isEmbeddingsRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
import { assertNever } from "../../../shared/utils";
@ -7,18 +7,6 @@ import { assertNever } from "../../../shared/utils";
export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
let assignedKey: Key;
if (!isCompletionRequest(req)) {
// Horrible, horrible hack to stop the proxy from complaining about clients
// not sending a model when they are requesting the list of models (which
// requires a key, but obviously not a model).
// I don't think this is needed anymore since models requests are no longer
// proxied to the upstream API. Everything going through this is either a
// completion request or a special case like OpenAI embeddings.
req.log.warn({ path: req.path }, "addKey called on non-completion request");
req.body.model = "gpt-3.5-turbo";
}
if (!req.inboundApi || !req.outboundApi) {
const err = new Error(
"Request API format missing. Did you forget to add the request preprocessor to your router?"
@ -54,6 +42,9 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
throw new Error(
"OpenAI Chat as an API translation target is not supported"
);
case "openai-image":
assignedKey = keyPool.get("dall-e-3");
break;
default:
assertNever(req.outboundApi);
}

View File

@ -1,5 +1,5 @@
import { hasAvailableQuota } from "../../../shared/users/user-store";
import { isCompletionRequest } from "../common";
import { isImageGenerationRequest, isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
export class QuotaExceededError extends Error {
@ -12,12 +12,19 @@ export class QuotaExceededError extends Error {
}
export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
if (!isCompletionRequest(req) || !req.user) {
return;
}
const subjectToQuota =
isTextGenerationRequest(req) || isImageGenerationRequest(req);
if (!subjectToQuota || !req.user) return;
const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
if (!hasAvailableQuota(req.user.token, req.body.model, requestedTokens)) {
if (
!hasAvailableQuota({
userToken: req.user.token,
model: req.body.model,
api: req.outboundApi,
requested: requestedTokens,
})
) {
throw new QuotaExceededError(
"You have exceeded your proxy token quota for this model.",
{

View File

@ -1,4 +1,3 @@
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");
@ -15,10 +14,6 @@ class ForbiddenError extends Error {
* stop getting emails asking for tech support.
*/
export const blockZoomerOrigins: ProxyRequestMiddleware = (_proxyReq, req) => {
if (!isCompletionRequest(req)) {
return;
}
const origin = req.headers.origin || req.headers.referer;
if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) {
// Venus-derivatives send a test prompt to check if the proxy is working.

View File

@ -35,14 +35,18 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service });
break;
}
case "openai-image": {
req.outputTokens = 1;
result = await countTokens({ req, service });
break;
}
default:
assertNever(service);
}
req.promptTokens = result.token_count;
// TODO: Remove once token counting is stable
req.log.debug({ result: result }, "Counted prompt tokens.");
req.debug = req.debug ?? {};
req.debug = { ...req.debug, ...result };
};
req.tokenizerInfo = req.tokenizerInfo ?? {};
req.tokenizerInfo = { ...req.tokenizerInfo, ...result };
};

View File

@ -4,6 +4,11 @@ import type { ProxyRequestMiddleware } from ".";
/** Finalize the rewritten request body. Must be the last rewriter. */
export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => {
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
// For image generation requests, remove stream flag.
if (req.outboundApi === "openai-image") {
delete req.body.stream;
}
const updatedBody = JSON.stringify(req.body);
proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody));
(req as any).rawBody = Buffer.from(updatedBody);

View File

@ -58,6 +58,7 @@ function getPromptFromRequest(req: Request) {
)
.join("\n\n");
case "openai-text":
case "openai-image":
return body.prompt;
case "google-palm":
return body.prompt.text;

View File

@ -1,12 +1,12 @@
import { isCompletionRequest } from "../common";
import { isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/**
* Don't allow multiple completions to be requested to prevent abuse.
* Don't allow multiple text completions to be requested to prevent abuse.
* OpenAI-only, Anthropic provides no such parameter.
**/
export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => {
if (isCompletionRequest(req) && req.outboundApi === "openai") {
if (isTextGenerationRequest(req) && req.outboundApi === "openai") {
const originalN = req.body?.n || 1;
req.body.n = 1;
if (originalN !== req.body.n) {

View File

@ -17,6 +17,7 @@ export const createOnProxyReqHandler = ({
// The streaming flag must be set before any other middleware runs, because
// it may influence which other middleware a particular API pipeline wants
// to run.
// Image generation requests can't be streamed.
req.isStreaming = req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;

View File

@ -2,7 +2,7 @@ import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { OpenAIPromptMessage } from "../../../shared/tokenization";
import { isCompletionRequest } from "../common";
import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
import { RequestPreprocessor } from ".";
import { APIFormat } from "../../../shared/key-management";
@ -88,6 +88,21 @@ const OpenAIV1TextCompletionSchema = z
})
.merge(OpenAIV1ChatCompletionSchema.omit({ messages: true }));
// https://platform.openai.com/docs/api-reference/images/create
const OpenAIV1ImagesGenerationSchema = z.object({
prompt: z.string().max(4000),
model: z.string().optional(),
quality: z.enum(["standard", "hd"]).optional().default("standard"),
n: z.number().int().min(1).max(4).optional().default(1),
response_format: z.enum(["url", "b64_json"]).optional(),
size: z
.enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
.optional()
.default("1024x1024"),
style: z.enum(["vivid", "natural"]).optional().default("vivid"),
user: z.string().optional(),
});
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
const PalmV1GenerateTextSchema = z.object({
model: z.string(),
@ -110,6 +125,7 @@ const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
anthropic: AnthropicV1CompleteSchema,
openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema,
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-palm": PalmV1GenerateTextSchema,
};
@ -117,11 +133,10 @@ const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
const sameService = req.inboundApi === req.outboundApi;
const alreadyTransformed = req.retryCount > 0;
const notTransformable = !isCompletionRequest(req);
const notTransformable =
!isTextGenerationRequest(req) && !isImageGenerationRequest(req);
if (alreadyTransformed || notTransformable) {
return;
}
if (alreadyTransformed || notTransformable) return;
if (sameService) {
const result = VALIDATORS[req.inboundApi].safeParse(req.body);
@ -151,6 +166,11 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "openai-image") {
req.body = openaiToOpenaiImage(req);
return;
}
throw new Error(
`'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
);
@ -226,6 +246,49 @@ function openaiToOpenaiText(req: Request) {
return OpenAIV1TextCompletionSchema.parse(transformed);
}
// Takes the last chat message and uses it verbatim as the image prompt.
function openaiToOpenaiImage(req: Request) {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-OpenAI-image request"
);
throw result.error;
}
const { messages } = result.data;
const prompt = messages.filter((m) => m.role === "user").pop()?.content;
if (body.stream) {
throw new Error(
"Streaming is not supported for image generation requests."
);
}
// Some frontends do weird things with the prompt, like prefixing it with a
// character name or wrapping the entire thing in quotes. We will look for
// the index of "Image:" and use everything after that as the prompt.
const index = prompt?.toLowerCase().indexOf("image:");
if (index === -1 || !prompt) {
throw new Error(
`Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`
);
}
// TODO: Add some way to specify parameters via chat message
const transformed = {
model: body.model.includes("dall-e") ? body.model : "dall-e-3",
quality: "standard",
size: "1024x1024",
response_format: "url",
prompt: prompt.slice(index! + 6).trim(),
};
return OpenAIV1ImagesGenerationSchema.parse(transformed);
}
function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({

View File

@ -34,6 +34,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
case "google-palm":
proxyMax = BISON_MAX_CONTEXT;
break;
case "openai-image":
return;
default:
assertNever(req.outboundApi);
}
@ -81,10 +83,10 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
"Prompt size validated"
);
req.debug.prompt_tokens = promptTokens;
req.debug.completion_tokens = outputTokens;
req.debug.max_model_tokens = modelMax;
req.debug.max_proxy_tokens = proxyMax;
req.tokenizerInfo.prompt_tokens = promptTokens;
req.tokenizerInfo.completion_tokens = outputTokens;
req.tokenizerInfo.max_model_tokens = modelMax;
req.tokenizerInfo.max_proxy_tokens = proxyMax;
};
function assertRequestHasTokenCounts(

View File

@ -13,13 +13,16 @@ import {
incrementTokenCount,
} from "../../../shared/users/user-store";
import { assertNever } from "../../../shared/utils";
import { refundLastAttempt } from "../../rate-limit";
import {
getCompletionFromBody,
isCompletionRequest,
isImageGenerationRequest,
isTextGenerationRequest,
writeErrorResponse,
} from "../common";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
import { saveImage } from "./save-image";
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
@ -106,6 +109,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
countResponseTokens,
incrementUsage,
copyHttpHeaders,
saveImage,
logPrompt,
...apiMiddleware
);
@ -285,7 +289,16 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
switch (service) {
case "openai":
case "google-palm":
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
if (errorPayload.error?.code === "content_policy_violation") {
errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`;
refundLastAttempt(req);
} else if (errorPayload.error?.code === "billing_hard_limit_reached") {
// For some reason, some models return this 400 error instead of the
// same 429 billing error that other models return.
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
} else {
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
}
break;
case "anthropic":
case "aws":
@ -453,6 +466,7 @@ function handleOpenAIRateLimitError(
const type = errorPayload.error?.type;
switch (type) {
case "insufficient_quota":
case "invalid_request_error": // this is the billing_hard_limit_reached error seen in some cases
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
@ -487,13 +501,22 @@ function handleOpenAIRateLimitError(
}
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (isCompletionRequest(req)) {
if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
const model = req.body.model;
const tokensUsed = req.promptTokens! + req.outputTokens!;
req.log.debug(
{
model,
tokensUsed,
promptTokens: req.promptTokens,
outputTokens: req.outputTokens,
},
`Incrementing usage for model`
);
keyPool.incrementUsage(req.key!, model, tokensUsed);
if (req.user) {
incrementPromptCount(req.user.token);
incrementTokenCount(req.user.token, model, tokensUsed);
incrementTokenCount(req.user.token, model, req.outboundApi, tokensUsed);
}
}
};
@ -504,6 +527,12 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
_res,
body
) => {
if (req.outboundApi === "openai-image") {
req.outputTokens = req.promptTokens;
req.promptTokens = 0;
return;
}
// This function is prone to breaking if the upstream API makes even minor
// changes to the response format, especially for SSE responses. If you're
// seeing errors in this function, check the reassembled response body from
@ -518,8 +547,8 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
{ service, tokens, prevOutputTokens: req.outputTokens },
`Counted tokens for completion`
);
if (req.debug) {
req.debug.completion_tokens = tokens;
if (req.tokenizerInfo) {
req.tokenizerInfo.completion_tokens = tokens;
}
req.outputTokens = tokens.token_count;

View File

@ -4,7 +4,8 @@ import { logQueue } from "../../../shared/prompt-logging";
import {
getCompletionFromBody,
getModelFromBody,
isCompletionRequest,
isImageGenerationRequest,
isTextGenerationRequest,
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
@ -23,11 +24,11 @@ export const logPrompt: ProxyResHandlerWithBody = async (
throw new Error("Expected body to be an object");
}
if (!isCompletionRequest(req)) {
return;
}
const loggable =
isTextGenerationRequest(req) || isImageGenerationRequest(req);
if (!loggable) return;
const promptPayload = getPromptForRequest(req);
const promptPayload = getPromptForRequest(req, responseBody);
const promptFlattened = flattenMessages(promptPayload);
const response = getCompletionFromBody(req, responseBody);
const model = getModelFromBody(req, responseBody);
@ -46,7 +47,18 @@ type OaiMessage = {
content: string;
};
const getPromptForRequest = (req: Request): string | OaiMessage[] => {
type OaiImageResult = {
prompt: string;
size: string;
style: string;
quality: string;
revisedPrompt?: string;
};
const getPromptForRequest = (
req: Request,
responseBody: Record<string, any>
): string | OaiMessage[] | OaiImageResult => {
// Since the prompt logger only runs after the request has been proxied, we
// can assume the body has already been transformed to the target API's
// format.
@ -55,6 +67,14 @@ const getPromptForRequest = (req: Request): string | OaiMessage[] => {
return req.body.messages;
case "openai-text":
return req.body.prompt;
case "openai-image":
return {
prompt: req.body.prompt,
size: req.body.size,
style: req.body.style,
quality: req.body.quality,
revisedPrompt: responseBody.data[0].revised_prompt,
};
case "anthropic":
return req.body.prompt;
case "google-palm":
@ -64,9 +84,14 @@ const getPromptForRequest = (req: Request): string | OaiMessage[] => {
}
};
const flattenMessages = (messages: string | OaiMessage[]): string => {
if (typeof messages === "string") {
return messages.trim();
const flattenMessages = (
val: string | OaiMessage[] | OaiImageResult
): string => {
if (typeof val === "string") {
return val.trim();
}
return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
if (Array.isArray(val)) {
return val.map((m) => `${m.role}: ${m.content}`).join("\n");
}
return val.prompt.trim();
};

View File

@ -0,0 +1,27 @@
import { ProxyResHandlerWithBody } from "./index";
import { mirrorGeneratedImage, OpenAIImageGenerationResult } from "../../../shared/file-storage/mirror-generated-image";
export const saveImage: ProxyResHandlerWithBody = async (
_proxyRes,
req,
_res,
body,
) => {
if (req.outboundApi !== "openai-image") {
return;
}
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (body.data) {
const baseUrl = req.protocol + "://" + req.get("host");
const prompt = body.data[0].revised_prompt ?? req.body.prompt;
await mirrorGeneratedImage(
baseUrl,
prompt,
body as OpenAIImageGenerationResult
);
}
};

View File

@ -33,9 +33,10 @@ export class EventAggregator {
case "anthropic":
return mergeEventsForAnthropic(this.events);
case "google-palm":
throw new Error("Google PaLM API does not support streaming responses");
case "openai-image":
throw new Error(`SSE aggregation not supported for ${this.format}`);
default:
assertNever(this.format);
}
}
}
}

View File

@ -99,7 +99,8 @@ function getTransformer(
? anthropicV1ToOpenAI
: anthropicV2ToOpenAI;
case "google-palm":
throw new Error("Google PaLM does not support streaming responses");
case "openai-image":
throw new Error(`SSE transformation not supported for ${responseApi}`);
default:
assertNever(responseApi);
}

153
src/proxy/openai-image.ts Normal file
View File

@ -0,0 +1,153 @@
import { RequestHandler, Router, Request } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
applyQuotaLimits,
blockZoomerOrigins,
createPreprocessorMiddleware,
finalizeBody,
stripHeaders,
createOnProxyReqHandler,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { generateModelList } from "./openai";
import {
mirrorGeneratedImage,
OpenAIImageGenerationResult,
} from "../shared/file-storage/mirror-generated-image";
const KNOWN_MODELS = ["dall-e-2", "dall-e-3"];
let modelListCache: any = null;
let modelListValid = 0;
const handleModelRequest: RequestHandler = (_req, res) => {
if (new Date().getTime() - modelListValid < 1000 * 60) return res.status(200).json(modelListCache);
const result = generateModelList(KNOWN_MODELS);
modelListCache = { object: "list", data: result };
modelListValid = new Date().getTime();
res.status(200).json(modelListCache);
};
const openaiImagesResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.inboundApi === "openai") {
req.log.info("Transforming OpenAI image response to OpenAI chat format");
body = transformResponseForChat(body as OpenAIImageGenerationResult, req);
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
};
/**
* Transforms a DALL-E image generation response into a chat response, simply
* embedding the image URL into the chat message as a Markdown image.
*/
function transformResponseForChat(
imageBody: OpenAIImageGenerationResult,
req: Request
): Record<string, any> {
const prompt = imageBody.data[0].revised_prompt ?? req.body.prompt;
const content = imageBody.data
.map((item) => {
const { url, b64_json } = item;
if (b64_json) {
return `![${prompt}](data:image/png;base64,${b64_json})`;
} else {
return `![${prompt}](${url})`;
}
})
.join("\n\n");
return {
id: "dalle-" + req.id,
object: "chat.completion",
created: Math.floor(Date.now() / 1000), // OpenAI timestamps are Unix seconds
model: req.body.model,
usage: {
prompt_tokens: 0,
completion_tokens: req.outputTokens,
total_tokens: req.outputTokens,
},
choices: [
{
message: { role: "assistant", content },
finish_reason: "stop",
index: 0,
},
],
};
}
const openaiImagesProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.openai.com",
changeOrigin: true,
selfHandleResponse: true,
logger,
pathRewrite: {
"^/v1/chat/completions": "/v1/images/generations",
},
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
addKey,
blockZoomerOrigins,
stripHeaders,
finalizeBody,
],
}),
proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]),
error: handleProxyError,
},
}),
});
const openaiImagesRouter = Router();
openaiImagesRouter.get("/v1/models", handleModelRequest);
openaiImagesRouter.post(
"/v1/images/generations",
ipLimiter,
createPreprocessorMiddleware({
inApi: "openai-image",
outApi: "openai-image",
service: "openai",
}),
openaiImagesProxy
);
openaiImagesRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware({
inApi: "openai",
outApi: "openai-image",
service: "openai",
}),
openaiImagesProxy
);
export const openaiImage = openaiImagesRouter;
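For orientation, here is a rough sketch of how a client might call the two routes registered above. The host, port, mount prefix, and user token are placeholders, and the conversion of chat messages into an image prompt happens in a preprocessor that is not part of this hunk.

```ts
// Placeholder values; only the route paths come from the router above.
const BASE = "http://localhost:7860/proxy/openai-image";
const headers = {
  "Content-Type": "application/json",
  Authorization: "Bearer <user-token>",
};

async function demo() {
  // Native image request, proxied to OpenAI's /v1/images/generations.
  const image = await fetch(`${BASE}/v1/images/generations`, {
    method: "POST",
    headers,
    body: JSON.stringify({
      model: "dall-e-3",
      prompt: "a lighthouse at dusk",
      size: "1024x1024",
    }),
  }).then((r) => r.json());

  // Chat-style request; pathRewrite sends it to the same upstream endpoint and
  // the response comes back as a chat.completion containing a Markdown image link.
  const chat = await fetch(`${BASE}/v1/chat/completions`, {
    method: "POST",
    headers,
    body: JSON.stringify({
      model: "dall-e-3",
      messages: [{ role: "user", content: "a lighthouse at dusk" }],
    }),
  }).then((r) => r.json());

  console.log(image, chat);
}
```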

View File

@ -2,61 +2,50 @@ import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
ModelFamily,
OpenAIModelFamily,
getOpenAIModelFamily,
} from "../shared/models";
import { getOpenAIModelFamily, ModelFamily, OpenAIModelFamily } from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
RequestPreprocessor,
addKey,
addKeyForEmbeddingsRequest,
applyQuotaLimits,
blockZoomerOrigins,
createEmbeddingsPreprocessorMiddleware,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
limitCompletions,
RequestPreprocessor,
stripHeaders,
createOnProxyReqHandler,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";
// https://platform.openai.com/docs/models/overview
const KNOWN_MODELS = [
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-4-0314", // EOL 2024-06-13
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-4-32k-0314", // EOL 2024-06-13
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301", // EOL 2024-06-13
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-instruct-0914",
"text-embedding-ada-002",
];
let modelsCache: any = null;
let modelsCacheTime = 0;
function getModelsResponse() {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
// https://platform.openai.com/docs/models/overview
const knownModels = [
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-4-0314", // EOL 2024-06-13
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-4-32k-0314", // EOL 2024-06-13
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301", // EOL 2024-06-13
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-instruct-0914",
"text-embedding-ada-002",
];
export function generateModelList(models = KNOWN_MODELS) {
let available = new Set<OpenAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "openai") continue;
@ -67,7 +56,7 @@ function getModelsResponse() {
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));
const models = knownModels
return models
.map((id) => ({
id,
object: "model",
@ -87,15 +76,14 @@ function getModelsResponse() {
parent: null,
}))
.filter((model) => available.has(getOpenAIModelFamily(model.id)));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
}
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
if (new Date().getTime() - modelsCacheTime < 1000 * 60) return res.status(200).json(modelsCache);
const result = generateModelList();
modelsCache = { object: "list", data: result };
modelsCacheTime = new Date().getTime();
res.status(200).json(modelsCache);
};
/** Handles some turbo-instruct special cases. */
@ -137,9 +125,8 @@ const openaiResponseHandler: ProxyResHandlerWithBody = async (
body = transformTurboInstructResponse(body);
}
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);

View File

@ -75,9 +75,8 @@ const palmResponseHandler: ProxyResHandlerWithBody = async (
body = transformPalmResponse(body, req);
}
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
// TODO: PaLM has no streaming capability which will pose a problem here if

View File

@ -12,11 +12,11 @@
*/
import type { Handler, Request } from "express";
import { keyPool, SupportedModel } from "../shared/key-management";
import { keyPool } from "../shared/key-management";
import {
getClaudeModelFamily,
getGooglePalmModelFamily,
getOpenAIModelFamily,
getOpenAIModelFamily, MODEL_FAMILIES,
ModelFamily,
} from "../shared/models";
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
@ -132,7 +132,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
// There is a single request queue, but it is partitioned by model family.
// Model families are typically separated on cost/rate limit boundaries so
// they should be treated as separate queues.
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
const model = req.body.model ?? "gpt-3.5-turbo";
// Weird special case for AWS because they serve multiple models from
// different vendors, even if currently only one is supported.
@ -145,6 +145,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
return getClaudeModelFamily(model);
case "openai":
case "openai-text":
case "openai-image":
return getOpenAIModelFamily(model);
case "google-palm":
return getGooglePalmModelFamily(model);
@ -207,40 +208,15 @@ export function dequeue(partition: ModelFamily): Request | undefined {
function processQueue() {
// This isn't completely correct, because a key can service multiple models.
// Currently if a key is locked out on one model it will also stop servicing
// the others, because we only track one rate limit per key.
// TODO: `getLockoutPeriod` uses model names instead of model families
// TODO: genericize this it's really ugly
const gpt4TurboLockout = keyPool.getLockoutPeriod("gpt-4-1106");
const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
const palmLockout = keyPool.getLockoutPeriod("text-bison-001");
const awsClaudeLockout = keyPool.getLockoutPeriod("anthropic.claude-v2");
// the others, because we only track rate limits for the key as a whole.
const reqs: (Request | undefined)[] = [];
if (gpt4TurboLockout === 0) {
reqs.push(dequeue("gpt4-turbo"));
}
if (gpt432kLockout === 0) {
reqs.push(dequeue("gpt4-32k"));
}
if (gpt4Lockout === 0) {
reqs.push(dequeue("gpt4"));
}
if (turboLockout === 0) {
reqs.push(dequeue("turbo"));
}
if (claudeLockout === 0) {
reqs.push(dequeue("claude"));
}
if (palmLockout === 0) {
reqs.push(dequeue("bison"));
}
if (awsClaudeLockout === 0) {
reqs.push(dequeue("aws-claude"));
}
MODEL_FAMILIES.forEach((modelFamily) => {
const lockout = keyPool.getLockoutPeriod(modelFamily);
if (lockout === 0) {
reqs.push(dequeue(modelFamily));
}
});
reqs.filter(Boolean).forEach((req) => {
if (req?.proceed) {

View File

@ -9,8 +9,6 @@ export const SHARED_IP_ADDRESSES = new Set([
"209.97.162.44",
]);
const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit);
const RATE_LIMIT = Math.max(1, config.modelRateLimit);
const ONE_MINUTE_MS = 60 * 1000;
type Timestamp = number;
@ -22,12 +20,15 @@ const exemptedRequests: Timestamp[] = [];
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
attempt > now - ONE_MINUTE_MS;
const getTryAgainInMs = (ip: string) => {
const getTryAgainInMs = (ip: string, type: "text" | "image") => {
const now = Date.now();
const attempts = lastAttempts.get(ip) || [];
const validAttempts = attempts.filter(isRecentAttempt(now));
if (validAttempts.length >= RATE_LIMIT) {
const limit =
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
if (validAttempts.length >= limit) {
return validAttempts[0] - now + ONE_MINUTE_MS;
} else {
lastAttempts.set(ip, [...validAttempts, now]);
@ -35,12 +36,16 @@ const getTryAgainInMs = (ip: string) => {
}
};
const getStatus = (ip: string) => {
const getStatus = (ip: string, type: "text" | "image") => {
const now = Date.now();
const attempts = lastAttempts.get(ip) || [];
const validAttempts = attempts.filter(isRecentAttempt(now));
const limit =
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
return {
remaining: Math.max(0, RATE_LIMIT - validAttempts.length),
remaining: Math.max(0, limit - validAttempts.length),
reset: validAttempts.length > 0 ? validAttempts[0] + ONE_MINUTE_MS : now,
};
};
@ -69,12 +74,26 @@ setInterval(clearOldExemptions, 10 * 1000);
export const getUniqueIps = () => lastAttempts.size;
/**
* Can be used to manually remove the most recent attempt from an IP address,
* ie. in case a prompt triggered OpenAI's content filter and therefore did not
* result in a generation.
*/
export const refundLastAttempt = (req: Request) => {
const key = req.user?.token || req.risuToken || req.ip;
const attempts = lastAttempts.get(key) || [];
attempts.pop();
};
export const ipLimiter = async (
req: Request,
res: Response,
next: NextFunction
) => {
if (!RATE_LIMIT_ENABLED) return next();
const imageLimit = config.imageModelRateLimit;
const textLimit = config.textModelRateLimit;
if (!textLimit && !imageLimit) return next();
if (req.user?.type === "special") return next();
// Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
@ -90,24 +109,25 @@ export const ipLimiter = async (
return next();
}
const type = (req.baseUrl + req.path).includes("image") ? "image" : "text";
const limit = type === "image" ? imageLimit : textLimit;
// If user is authenticated, key rate limiting by their token. Otherwise, key
// rate limiting by their IP address. Mitigates key sharing.
const rateLimitKey = req.user?.token || req.risuToken || req.ip;
const { remaining, reset } = getStatus(rateLimitKey);
res.set("X-RateLimit-Limit", config.modelRateLimit.toString());
const { remaining, reset } = getStatus(rateLimitKey, type);
res.set("X-RateLimit-Limit", limit.toString());
res.set("X-RateLimit-Remaining", remaining.toString());
res.set("X-RateLimit-Reset", reset.toString());
const tryAgainInMs = getTryAgainInMs(rateLimitKey);
const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
if (tryAgainInMs > 0) {
res.set("Retry-After", tryAgainInMs.toString());
res.status(429).json({
error: {
type: "proxy_rate_limited",
message: `This proxy is rate limited to ${
config.modelRateLimit
} prompts per minute. Please try again in ${Math.ceil(
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${Math.ceil(
tryAgainInMs / 1000
)} seconds.`,
},
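For example, if the image limit were 2, a third image request from the same token or IP within a minute would be rejected roughly as follows; the numbers are illustrative and the header names come from the code above (note that Retry-After is written in milliseconds here).

```ts
// Headers set before the 429 is sent:
//   X-RateLimit-Limit: 2
//   X-RateLimit-Remaining: 0
//   X-RateLimit-Reset: <epoch ms when the oldest attempt expires>
//   Retry-After: <ms until a slot frees up>
const exampleRateLimitBody = {
  error: {
    type: "proxy_rate_limited",
    message:
      "This model type is rate limited to 2 prompts per minute. Please try again in 42 seconds.",
  },
};
```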

View File

@ -2,6 +2,7 @@ import express, { Request, Response, NextFunction } from "express";
import { gatekeeper } from "./gatekeeper";
import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { aws } from "./aws";
@ -27,6 +28,7 @@ proxyRouter.use((req, _res, next) => {
next();
});
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/aws/claude", addV1, aws);

View File

@ -1,11 +1,14 @@
import { assertConfigIsValid, config } from "./config";
import { assertConfigIsValid, config, USER_ASSETS_DIR } from "./config";
import "source-map-support/register";
import checkDiskSpace from "check-disk-space";
import express from "express";
import cors from "cors";
import path from "path";
import pinoHttp from "pino-http";
import os from "os";
import childProcess from "child_process";
import { logger } from "./logger";
import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir";
import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
@ -58,6 +61,8 @@ app.set("views", [
path.join(__dirname, "shared/views"),
]);
app.use("/user_content", express.static(USER_ASSETS_DIR));
app.get("/health", (_req, res) => res.sendStatus(200));
app.use(cors());
app.use(checkOrigin);
@ -99,13 +104,17 @@ async function start() {
await initTokenizers();
if (config.allowedModelFamilies.includes("dall-e")) {
await setupAssetsDir();
}
if (config.gatekeeper === "user_token") {
await initUserStore();
}
if (config.promptLogging) {
logger.info("Starting prompt logging...");
logQueue.start();
await logQueue.start();
}
logger.info("Starting request queue...");
@ -116,8 +125,12 @@ async function start() {
registerUncaughtExceptionHandler();
});
const diskSpace = await checkDiskSpace(
__dirname.startsWith("/app") ? "/app" : os.homedir()
);
logger.info(
{ build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV },
{ build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV, diskSpace },
"Startup complete."
);
}

View File

@ -0,0 +1,35 @@
type ImageHistory = {
url: string;
prompt: string;
}
const IMAGE_HISTORY_SIZE = 30;
const imageHistory = new Array<ImageHistory>(IMAGE_HISTORY_SIZE);
let imageHistoryIndex = 0;
export function getImageHistory() {
return imageHistory.filter((image) => image);
}
export function addToImageHistory(image: ImageHistory) {
imageHistory[imageHistoryIndex] = image;
imageHistoryIndex = (imageHistoryIndex + 1) % IMAGE_HISTORY_SIZE;
}
export function getLastNImages(n: number) {
const result: ImageHistory[] = [];
let currentIndex = (imageHistoryIndex - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE;
for (let i = 0; i < n; i++) {
// Check if the current index is valid (not undefined).
if (imageHistory[currentIndex]) {
result.unshift(imageHistory[currentIndex]);
}
// Move to the previous item, wrapping around if necessary.
currentIndex = (currentIndex - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE;
}
return result;
}
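A quick illustration of the ring buffer above; the URLs and prompts are invented.

```ts
// With an empty history, two additions land in slots 0 and 1.
addToImageHistory({ url: "https://proxy.example/user_content/a.png", prompt: "first" });
addToImageHistory({ url: "https://proxy.example/user_content/b.png", prompt: "second" });

// Walks backwards from the most recent slot and skips empty ones, so this
// returns [first, second] in oldest-to-newest order. After IMAGE_HISTORY_SIZE
// (30) additions, the oldest entries start being overwritten.
const recent = getLastNImages(5);
```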

View File

@ -0,0 +1,75 @@
import axios from "axios";
import { promises as fs } from "fs";
import path from "path";
import { v4 } from "uuid";
import { USER_ASSETS_DIR } from "../../config";
import { logger } from "../../logger";
import { addToImageHistory } from "./image-history";
import sharp from "sharp";
const log = logger.child({ module: "file-storage" });
export type OpenAIImageGenerationResult = {
created: number;
data: {
revised_prompt?: string;
url: string;
b64_json: string;
}[];
};
async function downloadImage(url: string) {
const { data } = await axios.get(url, { responseType: "arraybuffer" });
const buffer = Buffer.from(data, "binary");
const newFilename = `${v4()}.png`;
const filepath = path.join(USER_ASSETS_DIR, newFilename);
await fs.writeFile(filepath, buffer);
return filepath;
}
async function saveB64Image(b64: string) {
const buffer = Buffer.from(b64, "base64");
const newFilename = `${v4()}.png`;
const filepath = path.join(USER_ASSETS_DIR, newFilename);
await fs.writeFile(filepath, buffer);
return filepath;
}
async function createThumbnail(filepath: string) {
const thumbnailPath = filepath.replace(/(\.[\wd_-]+)$/i, "_t.jpg");
await sharp(filepath)
.resize(150, 150, {
fit: "inside",
withoutEnlargement: true,
})
.toFormat("jpeg")
.toFile(thumbnailPath);
return thumbnailPath;
}
/**
* Downloads generated images and mirrors them to the user_content directory.
* Mutates the result object.
*/
export async function mirrorGeneratedImage(
host: string,
prompt: string,
result: OpenAIImageGenerationResult
): Promise<OpenAIImageGenerationResult> {
for (const item of result.data) {
let mirror: string;
if (item.b64_json) {
mirror = await saveB64Image(item.b64_json);
} else {
mirror = await downloadImage(item.url);
}
item.url = `${host}/user_content/${path.basename(mirror)}`;
await createThumbnail(mirror);
addToImageHistory({ url: item.url, prompt });
}
return result;
}
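To make the mutation concrete, a hypothetical before/after for a single-image result; the hostnames are invented.

```ts
async function mirrorExample() {
  // Stand-in for whatever URL OpenAI returns for the generated image.
  const result: OpenAIImageGenerationResult = {
    created: 1700000000,
    data: [{ url: "https://upstream.example/generated.png", b64_json: "" }],
  };
  await mirrorGeneratedImage("https://my-proxy.example", "a lighthouse at dusk", result);
  // result.data[0].url now points at https://my-proxy.example/user_content/<uuid>.png,
  // a 150px JPEG thumbnail was written next to it, and the entry was recorded
  // in the in-memory image history.
}
```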

View File

@ -0,0 +1,20 @@
import { promises as fs } from "fs";
import { logger } from "../../logger";
import { USER_ASSETS_DIR } from "../../config";
const log = logger.child({ module: "file-storage" });
export async function setupAssetsDir() {
try {
log.info({ dir: USER_ASSETS_DIR }, "Setting up user assets directory");
await fs.mkdir(USER_ASSETS_DIR, { recursive: true });
const stats = await fs.stat(USER_ASSETS_DIR);
const mode = stats.mode | 0o666;
if (stats.mode !== mode) {
await fs.chmod(USER_ASSETS_DIR, mode);
}
} catch (e) {
log.error(e);
throw new Error("Could not create user assets directory for DALL-E image generation. You may need to update your Dockerfile to `chown` the working directory to user 1000. See the proxy docs for more information.");
}
}

View File

@ -6,14 +6,12 @@ import type { AnthropicModelFamily } from "../../models";
import { AnthropicKeyChecker } from "./checker";
// https://docs.anthropic.com/claude/reference/selecting-a-model
export const ANTHROPIC_SUPPORTED_MODELS = [
"claude-instant-v1",
"claude-instant-v1-100k",
"claude-v1",
"claude-v1-100k",
"claude-2",
] as const;
export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
export type AnthropicModel =
| "claude-instant-v1"
| "claude-instant-v1-100k"
| "claude-v1"
| "claude-v1-100k"
| "claude-2";
export type AnthropicKeyUpdate = Omit<
Partial<AnthropicKey>,
@ -180,7 +178,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
key.claudeTokens += tokens;
}
public getLockoutPeriod(_model: AnthropicModel) {
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.

View File

@ -6,12 +6,10 @@ import type { AwsBedrockModelFamily } from "../../models";
import { AwsKeyChecker } from "./checker";
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
export const AWS_BEDROCK_SUPPORTED_MODELS = [
"anthropic.claude-v1",
"anthropic.claude-v2",
"anthropic.claude-instant-v1",
] as const;
export type AwsBedrockModel = (typeof AWS_BEDROCK_SUPPORTED_MODELS)[number];
export type AwsBedrockModel =
| "anthropic.claude-v1"
| "anthropic.claude-v2"
| "anthropic.claude-instant-v1";
type AwsBedrockKeyUsage = {
[K in AwsBedrockModelFamily as `${K}Tokens`]: number;
@ -158,7 +156,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
key["aws-claudeTokens"] += tokens;
}
public getLockoutPeriod(_model: AwsBedrockModel) {
public getLockoutPeriod() {
// TODO: same exact behavior for three providers, should be refactored
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.

View File

@ -1,15 +1,17 @@
import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider";
import {
ANTHROPIC_SUPPORTED_MODELS,
AnthropicModel,
} from "./anthropic/provider";
import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider";
import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider";
import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider";
import { AwsBedrockModel } from "./aws/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "../models";
/** The request and response format used by a model's API. */
export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";
export type APIFormat =
| "openai"
| "anthropic"
| "google-palm"
| "openai-text"
| "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
export type Model =
@ -60,23 +62,12 @@ export interface KeyProvider<T extends Key = Key> {
update(hash: string, update: Partial<T>): void;
available(): number;
incrementUsage(hash: string, model: string, tokens: number): void;
getLockoutPeriod(model: Model): number;
getLockoutPeriod(model: ModelFamily): number;
markRateLimited(hash: string): void;
recheck(): void;
}
export const keyPool = new KeyPool();
export const SUPPORTED_MODELS = [
...OPENAI_SUPPORTED_MODELS,
...ANTHROPIC_SUPPORTED_MODELS,
] as const;
export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
export {
OPENAI_SUPPORTED_MODELS,
ANTHROPIC_SUPPORTED_MODELS,
GOOGLE_PALM_SUPPORTED_MODELS,
AWS_BEDROCK_SUPPORTED_MODELS,
};
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";

View File

@ -9,6 +9,8 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models";
import { assertNever } from "../utils";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@ -37,7 +39,7 @@ export class KeyPool {
}
public get(model: Model): Key {
const service = this.getService(model);
const service = this.getServiceForModel(model);
return this.getKeyProvider(service).get(model);
}
@ -67,7 +69,7 @@ export class KeyPool {
public available(model: Model | "all" = "all"): number {
return this.keyProviders.reduce((sum, provider) => {
const includeProvider =
model === "all" || this.getService(model) === provider.service;
model === "all" || this.getServiceForModel(model) === provider.service;
return sum + (includeProvider ? provider.available() : 0);
}, 0);
}
@ -77,9 +79,9 @@ export class KeyPool {
provider.incrementUsage(key.hash, model, tokens);
}
public getLockoutPeriod(model: Model): number {
const service = this.getService(model);
return this.getKeyProvider(service).getLockoutPeriod(model);
public getLockoutPeriod(family: ModelFamily): number {
const service = this.getServiceForModelFamily(family);
return this.getKeyProvider(service).getLockoutPeriod(family);
}
public markRateLimited(key: Key): void {
@ -104,8 +106,12 @@ export class KeyPool {
provider.recheck();
}
private getService(model: Model): LLMService {
if (model.startsWith("gpt") || model.startsWith("text-embedding-ada")) {
private getServiceForModel(model: Model): LLMService {
if (
model.startsWith("gpt") ||
model.startsWith("text-embedding-ada") ||
model.startsWith("dall-e")
) {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
return "openai";
} else if (model.startsWith("claude-")) {
@ -122,6 +128,25 @@ export class KeyPool {
throw new Error(`Unknown service for model '${model}'`);
}
private getServiceForModelFamily(modelFamily: ModelFamily): LLMService {
switch (modelFamily) {
case "gpt4":
case "gpt4-32k":
case "gpt4-turbo":
case "turbo":
case "dall-e":
return "openai";
case "claude":
return "anthropic";
case "bison":
return "google-palm";
case "aws-claude":
return "aws";
default:
assertNever(modelFamily);
}
}
private getKeyProvider(service: LLMService): KeyProvider {
return this.keyProviders.find((provider) => provider.service === service)!;
}

View File

@ -95,10 +95,15 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
const models = data.data;
// const families: OpenAIModelFamily[] = [];
const families = new Set<OpenAIModelFamily>();
models.forEach(({ id }) => families.add(getOpenAIModelFamily(id, "turbo")));
// For now we remove dall-e from the list of provisioned models if only
// dall-e-2 is available.
if (families.has("dall-e") && !models.find(({ id }) => id === "dall-e-3")) {
families.delete("dall-e");
}
// We want to update the key's model families here, but we don't want to
// update its `lastChecked` timestamp because we need to let the liveness
// check run before we can consider the key checked.

View File

@ -15,12 +15,9 @@ export type OpenAIModel =
| "gpt-4"
| "gpt-4-32k"
| "gpt-4-1106"
| "text-embedding-ada-002";
export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-4",
] as const;
| "text-embedding-ada-002"
| "dall-e-2"
| "dall-e-3"
// Flattening model families instead of using a nested object for easier
// cloning.
@ -127,6 +124,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
gpt4Tokens: 0,
"gpt4-32kTokens": 0,
"gpt4-turboTokens": 0,
"dall-eTokens": 0,
gpt4Rpm: 0,
};
this.keys.push(newKey);
@ -284,10 +282,9 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
* Given a model, returns the period until a key will be available to service
* the request, or returns 0 if a key is ready immediately.
*/
public getLockoutPeriod(model: Model = "gpt-4"): number {
const neededFamily = getOpenAIModelFamily(model);
public getLockoutPeriod(family: OpenAIModelFamily): number {
const activeKeys = this.keys.filter(
(key) => !key.isDisabled && key.modelFamilies.includes(neededFamily)
(key) => !key.isDisabled && key.modelFamilies.includes(family)
);
if (activeKeys.length === 0) {
@ -335,6 +332,10 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
key.rateLimitedAt = Date.now();
// DALL-E requests do not send headers telling us when the rate limit will
// be reset so we need to set a fallback value here. Other models will have
// this overwritten by the `updateRateLimits` method.
key.rateLimitRequestsReset = 5000;
}
public incrementUsage(keyHash: string, model: string, tokens: number) {

View File

@ -5,11 +5,7 @@ import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
// https://developers.generativeai.google.com/models/language
export const GOOGLE_PALM_SUPPORTED_MODELS = [
"text-bison-001",
// "chat-bison-001", no adjustable safety settings, so it's useless
] as const;
export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number];
export type GooglePalmModel = "text-bison-001";
export type GooglePalmKeyUpdate = Omit<
Partial<GooglePalmKey>,
@ -149,7 +145,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
key.bisonTokens += tokens;
}
public getLockoutPeriod(_model: GooglePalmModel) {
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.

View File

@ -1,6 +1,8 @@
import { logger } from "../logger";
// Don't import anything here, this is imported by config.ts
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo";
import pino from "pino";
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type AwsBedrockModelFamily = "aws-claude";
@ -17,6 +19,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"gpt4",
"gpt4-32k",
"gpt4-turbo",
"dall-e",
"claude",
"bison",
"aws-claude",
@ -30,8 +33,11 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^gpt-4$": "gpt4",
"^gpt-3.5-turbo": "turbo",
"^text-embedding-ada-002$": "turbo",
"^dall-e-\\d{1}$": "dall-e",
};
const modelLogger = pino({ level: "debug" }).child({ module: "startup" });
export function getOpenAIModelFamily(
model: string,
defaultFamily: OpenAIModelFamily = "gpt4"
@ -42,14 +48,14 @@ export function getOpenAIModelFamily(
return defaultFamily;
}
export function getClaudeModelFamily(_model: string): ModelFamily {
export function getClaudeModelFamily(model: string): ModelFamily {
if (model.startsWith("anthropic.")) return getAwsBedrockModelFamily(model);
return "claude";
}
export function getGooglePalmModelFamily(model: string): ModelFamily {
if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
const stack = new Error().stack;
logger.warn({ model, stack }, "Unmapped PaLM model family");
modelLogger.warn({ model }, "Could not determine Google PaLM model family");
return "bison";
}

View File

@ -396,7 +396,7 @@ export const init = async (onStop: () => void) => {
await loadIndexSheet(false);
await writeIndexSheet();
} catch (e) {
log.info("Creating new index sheet.");
log.warn(e, "Could not load index sheet. Creating a new one.");
await createIndexSheet();
}
};

View File

@ -17,6 +17,9 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
case "turbo":
cost = 0.000001;
break;
case "dall-e":
cost = 0.00001;
break;
case "aws-claude":
case "claude":
cost = 0.00001102;

View File

@ -79,7 +79,8 @@ export function buildFakeSse(
};
break;
case "google-palm":
throw new Error("PaLM not supported as an inbound API format");
case "openai-image":
throw new Error(`SSE not supported for ${req.inboundApi} requests`);
default:
assertNever(req.inboundApi);
}
@ -92,4 +93,4 @@ export function buildFakeSse(
}
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}
}

View File

@ -78,3 +78,63 @@ export type OpenAIPromptMessage = {
content: string;
role: string;
};
// Model         Resolution             Price
// DALL·E 3      1024×1024              $0.040 / image
//               1024×1792, 1792×1024   $0.080 / image
// DALL·E 3 HD   1024×1024              $0.080 / image
//               1024×1792, 1792×1024   $0.120 / image
// DALL·E 2      1024×1024              $0.020 / image
//               512×512                $0.018 / image
//               256×256                $0.016 / image
export const DALLE_TOKENS_PER_DOLLAR = 100000;
/**
* OpenAI image generation with DALL-E doesn't use tokens but everything else
* in the application does. There is a fixed cost for each image generation
* request depending on the model and selected quality/resolution parameters,
* which we convert to tokens at a rate of 100000 tokens per dollar.
*/
export function getOpenAIImageCost(params: {
model: "dall-e-2" | "dall-e-3";
quality: "standard" | "hd";
resolution: "512x512" | "256x256" | "1024x1024" | "1024x1792" | "1792x1024";
n: number | null;
}) {
const { model, quality, resolution, n } = params;
const usd = (() => {
switch (model) {
case "dall-e-2":
switch (resolution) {
case "512x512":
return 0.018;
case "256x256":
return 0.016;
case "1024x1024":
return 0.02;
default:
throw new Error("Invalid resolution");
}
case "dall-e-3":
switch (resolution) {
case "1024x1024":
return quality === "standard" ? 0.04 : 0.08;
case "1024x1792":
case "1792x1024":
return quality === "standard" ? 0.08 : 0.12;
default:
throw new Error("Invalid resolution");
}
default:
throw new Error("Invalid image generation model");
}
})();
const tokens = (n ?? 1) * (usd * DALLE_TOKENS_PER_DOLLAR);
return {
tokenizer: `openai-image cost`,
token_count: Math.ceil(tokens),
};
}
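As a sanity check on the conversion above, a worked example with arbitrary request parameters:

```ts
// Two HD DALL·E 3 images at 1792×1024: 2 × $0.12 × 100,000 tokens/USD = 24,000 tokens.
const cost = getOpenAIImageCost({
  model: "dall-e-3",
  quality: "hd",
  resolution: "1792x1024",
  n: 2,
});
// => { tokenizer: "openai-image cost", token_count: 24000 }
```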

View File

@ -8,6 +8,7 @@ import {
init as initOpenAi,
getTokenCount as getOpenAITokenCount,
OpenAIPromptMessage,
getOpenAIImageCost,
} from "./openai";
import { APIFormat } from "../key-management";
@ -26,6 +27,7 @@ type TokenCountRequest = { req: Request } & (
service: "openai-text" | "anthropic" | "google-palm";
}
| { prompt?: never; completion: string; service: APIFormat }
| { prompt?: never; completion?: never; service: "openai-image" }
);
type TokenCountResult = {
@ -53,6 +55,16 @@ export async function countTokens({
...getOpenAITokenCount(prompt ?? completion, req.body.model),
tokenization_duration_ms: getElapsedMs(time),
};
case "openai-image":
return {
...getOpenAIImageCost({
model: req.body.model,
quality: req.body.quality,
resolution: req.body.size,
n: parseInt(req.body.n, 10) || null,
}),
tokenization_duration_ms: getElapsedMs(time),
};
case "google-palm":
// TODO: Can't find a tokenization library for PaLM. There is an API
// endpoint for it but it adds significant latency to the request.

View File

@ -7,6 +7,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
gpt4: z.number().optional().default(0),
"gpt4-32k": z.number().optional().default(0),
"gpt4-turbo": z.number().optional().default(0),
"dall-e": z.number().optional().default(0),
claude: z.number().optional().default(0),
bison: z.number().optional().default(0),
"aws-claude": z.number().optional().default(0),

View File

@ -11,9 +11,17 @@ import admin from "firebase-admin";
import schedule from "node-schedule";
import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import { MODEL_FAMILIES, ModelFamily } from "../models";
import {
getClaudeModelFamily,
getGooglePalmModelFamily,
getOpenAIModelFamily,
MODEL_FAMILIES,
ModelFamily,
} from "../models";
import { logger } from "../../logger";
import { User, UserTokenCounts, UserUpdate } from "./schema";
import { APIFormat } from "../key-management";
import { assertNever } from "../utils";
const log = logger.child({ module: "users" });
@ -22,6 +30,7 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
gpt4: 0,
"gpt4-32k": 0,
"gpt4-turbo": 0,
"dall-e": 0,
claude: 0,
bison: 0,
"aws-claude": 0,
@ -166,11 +175,12 @@ export function incrementPromptCount(token: string) {
export function incrementTokenCount(
token: string,
model: string,
api: APIFormat,
consumption: number
) {
const user = users.get(token);
if (!user) return;
const modelFamily = getModelFamilyForQuotaUsage(model);
const modelFamily = getModelFamilyForQuotaUsage(model, api);
const existing = user.tokenCounts[modelFamily] ?? 0;
user.tokenCounts[modelFamily] = existing + consumption;
usersToFlush.add(token);
@ -181,9 +191,10 @@ export function incrementTokenCount(
* to the user's list of IPs. Returns the user if they exist and are not
* disabled, otherwise returns undefined.
*/
export function authenticate(token: string, ip: string):
{ user?: User; result: "success" | "disabled" | "not_found" | "limited" }
{
export function authenticate(
token: string,
ip: string
): { user?: User; result: "success" | "disabled" | "not_found" | "limited" } {
const user = users.get(token);
if (!user) return { result: "not_found" };
if (user.disabledAt) return { result: "disabled" };
@ -210,16 +221,22 @@ export function authenticate(token: string, ip: string):
return { user, result: "success" };
}
export function hasAvailableQuota(
token: string,
model: string,
requested: number
) {
const user = users.get(token);
export function hasAvailableQuota({
userToken,
model,
api,
requested,
}: {
userToken: string;
model: string;
api: APIFormat;
requested: number;
}) {
const user = users.get(userToken);
if (!user) return false;
if (user.type === "special") return true;
const modelFamily = getModelFamilyForQuotaUsage(model);
const modelFamily = getModelFamilyForQuotaUsage(model, api);
const { tokenCounts, tokenLimits } = user;
const tokenLimit = tokenLimits[modelFamily];
@ -361,30 +378,22 @@ async function flushUsers() {
);
}
// TODO: use key-management/models.ts for family mapping
function getModelFamilyForQuotaUsage(model: string): ModelFamily {
if (model.startsWith("gpt-4-1106")) {
return "gpt4-turbo";
function getModelFamilyForQuotaUsage(
model: string,
api: APIFormat
): ModelFamily {
switch (api) {
case "openai":
case "openai-text":
case "openai-image":
return getOpenAIModelFamily(model);
case "anthropic":
return getClaudeModelFamily(model);
case "google-palm":
return getGooglePalmModelFamily(model);
default:
assertNever(api);
}
if (model.includes("32k")) {
return "gpt4-32k";
}
if (model.startsWith("gpt-4")) {
return "gpt4";
}
if (model.startsWith("gpt-3.5")) {
return "turbo";
}
if (model.includes("bison")) {
return "bison";
}
if (model.startsWith("claude")) {
return "claude";
}
if (model.startsWith("anthropic.claude")) {
return "aws-claude";
}
throw new Error(`Unknown quota model family for model ${model}`);
}
function getRefreshCrontab() {

View File

@ -24,8 +24,7 @@ declare global {
heartbeatInterval?: NodeJS.Timeout;
promptTokens?: number;
outputTokens?: number;
// TODO: remove later
debug: Record<string, any>;
tokenizerInfo: Record<string, any>;
signedRequest: HttpRequest;
}
}