diff --git a/.env.example b/.env.example
index 8357a18..a1e9cd2 100644
--- a/.env.example
+++ b/.env.example
@@ -11,8 +11,10 @@
# The title displayed on the info page.
# SERVER_TITLE=Coom Tunnel
-# Model requests allowed per minute per user.
-# MODEL_RATE_LIMIT=4
+# Text model requests allowed per minute per user.
+# TEXT_MODEL_RATE_LIMIT=4
+# Image model requests allowed per minute per user.
+# IMAGE_MODEL_RATE_LIMIT=2
# Max number of context tokens a user can request at once.
# Increase this if your proxy allows GPT 32k or 128k context
@@ -31,10 +33,11 @@
# CHECK_KEYS=true
# Which model types users are allowed to access.
-# If you want to restrict access to certain models, uncomment the line below and list only the models you want to allow,
-# separated by commas. By default, all models are allowed. The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | claude | bison | aws-claude
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4-turbo,aws-claude
+# The following model families are recognized:
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude
+# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
+# generation, uncomment the line below and add 'dall-e' to the list.
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude
# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -82,10 +85,18 @@
# ALLOW_NICKNAME_CHANGES=true
# Default token quotas for each model family. (0 for unlimited)
+# DALL-E "tokens" are counted at a rate of 100000 tokens per US$1.00 generated,
+# which is similar to the cost of GPT-4 Turbo.
+# DALL-E 3 costs around US$0.10 per image (10000 tokens).
+# See `docs/dall-e-configuration.md` for more information.
# TOKEN_QUOTA_TURBO=0
# TOKEN_QUOTA_GPT4=0
# TOKEN_QUOTA_GPT4_32K=0
+# TOKEN_QUOTA_GPT4_TURBO=0
+# TOKEN_QUOTA_DALL_E=0
# TOKEN_QUOTA_CLAUDE=0
+# TOKEN_QUOTA_BISON=0
+# TOKEN_QUOTA_AWS_CLAUDE=0
# How often to refresh token quotas. (hourly | daily)
# Leave unset to never automatically refresh quotas.
diff --git a/data/.gitignore b/data/.gitignore
new file mode 100644
index 0000000..377ccd3
--- /dev/null
+++ b/data/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitkeep
diff --git a/data/user-files/.gitkeep b/data/user-files/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/docker/huggingface/Dockerfile b/docker/huggingface/Dockerfile
index eef259f..7ab1c36 100644
--- a/docker/huggingface/Dockerfile
+++ b/docker/huggingface/Dockerfile
@@ -3,6 +3,8 @@ RUN apt-get update && \
apt-get install -y git
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app
+RUN chown -R 1000:1000 /app
+USER 1000
RUN npm install
COPY Dockerfile greeting.md* .env* ./
RUN npm run build
diff --git a/docker/render/Dockerfile b/docker/render/Dockerfile
index 67731c1..3bc7ed8 100644
--- a/docker/render/Dockerfile
+++ b/docker/render/Dockerfile
@@ -17,6 +17,8 @@ ARG GREETING_URL
RUN if [ -n "$GREETING_URL" ]; then \
curl -sL "$GREETING_URL" > greeting.md; \
fi
+RUN chown -R 1000:1000 /app
+USER 1000
COPY package*.json greeting.md* ./
RUN npm install
COPY . .
diff --git a/docs/dall-e-configuration.md b/docs/dall-e-configuration.md
new file mode 100644
index 0000000..3377e32
--- /dev/null
+++ b/docs/dall-e-configuration.md
@@ -0,0 +1,71 @@
+# Configuring the proxy for DALL-E
+
+The proxy supports DALL-E 2 and DALL-E 3 image generation via the `/proxy/openai-image` endpoint. It is disabled by default because image generation is relatively expensive and potentially more open to abuse than text generation.
+
+- [Updating your Dockerfile](#updating-your-dockerfile)
+- [Enabling DALL-E](#enabling-dall-e)
+- [Setting quotas](#setting-quotas)
+- [Rate limiting](#rate-limiting)
+- [Hiding recent images](#hiding-recent-images)
+
+## Updating your Dockerfile
+If you are using an older version of the Dockerfile supplied with the proxy, the container doesn't grant the proxy the permissions it needs to save temporary files.
+
+You can replace the entire thing with the new Dockerfile at [./docker/huggingface/Dockerfile](../docker/huggingface/Dockerfile) (or the equivalent for Render deployments).
+
+You can also modify your existing Dockerfile; just add the following lines after the `WORKDIR` line:
+
+```Dockerfile
+# Existing
+RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
+WORKDIR /app
+
+# Take ownership of the app directory and switch to the non-root user
+RUN chown -R 1000:1000 /app
+USER 1000
+
+# Existing
+RUN npm install
+```
+
+## Enabling DALL-E
+Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL-E. For example:
+
+```
+# GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, and DALL-E
+ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e
+
+# All models as of this writing
+ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
+```
+
+Refer to [.env.example](../.env.example) for the full list of supported model families; adding `dall-e` to the default list enables every model family.
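+
+When a chat-completion frontend is proxied to DALL-E, the proxy takes the last `user` message, looks for an `Image:` marker in it, and sends everything after the marker as the image prompt (see `openaiToOpenaiImage` in this change). For example:
+
+```
+Image: a watercolor painting of a lighthouse at dawn
+```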
+
+## Setting quotas
+DALL-E doesn't bill by token like text generation models. Instead, there is a fixed cost per generated image that depends on the model, image size, and selected quality.
+
+The proxy still uses tokens to set quotas for users. The cost for each generated image will be converted to "tokens" at a rate of 100000 tokens per US$1.00. This works out to a similar cost-per-token as GPT-4 Turbo, so you can use similar token quotas for both.
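+
+As a rough sketch of the conversion described above (the constant and helper here are illustrative, not the proxy's actual internals):
+
+```ts
+// 100000 proxy "tokens" per US$1.00 of image generation cost.
+const TOKENS_PER_USD = 100_000;
+
+// Convert a per-image cost in USD into the proxy's token accounting.
+function imageCostToTokens(costUsd: number): number {
+  return Math.round(costUsd * TOKENS_PER_USD);
+}
+
+imageCostToTokens(0.04); // => 4000 tokens for a standard 1024x1024 image
+```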
+
+Use `TOKEN_QUOTA_DALL_E` to set the default quota for image generation. Otherwise it works the same as token quotas for other models.
+
+```
+# ~50 standard DALL-E images per refresh period, or US$2.00
+TOKEN_QUOTA_DALL_E=200000
+```
+
+Refer to [https://openai.com/pricing](https://openai.com/pricing) for the latest pricing information. As of this writing, the cheapest DALL-E 3 image costs $0.04 per generation, which works out to 4000 tokens. Higher resolution and quality settings can cost up to $0.12 per image, or 12000 tokens.
+
+## Rate limiting
+The old `MODEL_RATE_LIMIT` setting has been split into `TEXT_MODEL_RATE_LIMIT` and `IMAGE_MODEL_RATE_LIMIT`. Whatever value you previously set for `MODEL_RATE_LIMIT` will be used for text models.
+
+If you set the deprecated `MODEL_RATE_LIMIT` but don't specify an `IMAGE_MODEL_RATE_LIMIT`, the image limit defaults to half of the text limit, with a minimum of 1 image per minute.
+
+```
+# 4 text generations per minute, 2 images per minute
+TEXT_MODEL_RATE_LIMIT=4
+IMAGE_MODEL_RATE_LIMIT=2
+```
+
+If a prompt is filtered by OpenAI's content filter, it won't count towards the rate limit.
+
+## Hiding recent images
+By default, the proxy shows the 12 most recent images generated by users. You can hide this section by setting `SHOW_RECENT_IMAGES` to `false`.
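+
+```
+# Hide the recent image gallery on the info page
+SHOW_RECENT_IMAGES=false
+```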
diff --git a/docs/deploy-huggingface.md b/docs/deploy-huggingface.md
index 4a112c0..968b1c9 100644
--- a/docs/deploy-huggingface.md
+++ b/docs/deploy-huggingface.md
@@ -25,6 +25,8 @@ RUN apt-get update && \
apt-get install -y git
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app
+RUN chown -R 1000:1000 /app
+USER 1000
RUN npm install
COPY Dockerfile greeting.md* .env* ./
RUN npm run build
diff --git a/package-lock.json b/package-lock.json
index 72a8979..d87d555 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -15,6 +15,7 @@
"@smithy/signature-v4": "^2.0.10",
"@smithy/types": "^2.3.4",
"axios": "^1.3.5",
+ "check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
"copyfiles": "^2.4.1",
"cors": "^2.8.5",
@@ -33,6 +34,7 @@
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"sanitize-html": "^2.11.0",
+ "sharp": "^0.32.6",
"showdown": "^2.1.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
@@ -1373,15 +1375,20 @@
}
},
"node_modules/axios": {
- "version": "1.3.5",
- "resolved": "https://registry.npmjs.org/axios/-/axios-1.3.5.tgz",
- "integrity": "sha512-glL/PvG/E+xCWwV8S6nCHcrfg1exGx7vxyUIivIA1iL7BIh6bePylCfVHwp6k13ao7SATxB6imau2kqY+I67kw==",
+ "version": "1.6.1",
+ "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
+ "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
"dependencies": {
"follow-redirects": "^1.15.0",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}
},
+ "node_modules/b4a": {
+ "version": "1.6.4",
+ "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz",
+ "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw=="
+ },
"node_modules/balanced-match": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
@@ -1423,6 +1430,52 @@
"node": ">=8"
}
},
+ "node_modules/bl": {
+ "version": "4.1.0",
+ "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz",
+ "integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==",
+ "dependencies": {
+ "buffer": "^5.5.0",
+ "inherits": "^2.0.4",
+ "readable-stream": "^3.4.0"
+ }
+ },
+ "node_modules/bl/node_modules/buffer": {
+ "version": "5.7.1",
+ "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz",
+ "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==",
+ "funding": [
+ {
+ "type": "github",
+ "url": "https://github.com/sponsors/feross"
+ },
+ {
+ "type": "patreon",
+ "url": "https://www.patreon.com/feross"
+ },
+ {
+ "type": "consulting",
+ "url": "https://feross.org/support"
+ }
+ ],
+ "dependencies": {
+ "base64-js": "^1.3.1",
+ "ieee754": "^1.1.13"
+ }
+ },
+ "node_modules/bl/node_modules/readable-stream": {
+ "version": "3.6.2",
+ "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz",
+ "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==",
+ "dependencies": {
+ "inherits": "^2.0.3",
+ "string_decoder": "^1.1.1",
+ "util-deprecate": "^1.0.1"
+ },
+ "engines": {
+ "node": ">= 6"
+ }
+ },
"node_modules/bluebird": {
"version": "3.7.2",
"resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.7.2.tgz",
@@ -1582,6 +1635,14 @@
"node": ">=8"
}
},
+ "node_modules/check-disk-space": {
+ "version": "3.4.0",
+ "resolved": "https://registry.npmjs.org/check-disk-space/-/check-disk-space-3.4.0.tgz",
+ "integrity": "sha512-drVkSqfwA+TvuEhFipiR1OC9boEGZL5RrWvVsOthdcvQNXyCCuKkEiTOTXZ7qxSf/GLwq4GvzfrQD/Wz325hgw==",
+ "engines": {
+ "node": ">=16"
+ }
+ },
"node_modules/chokidar": {
"version": "3.5.3",
"resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz",
@@ -1609,6 +1670,11 @@
"fsevents": "~2.3.2"
}
},
+ "node_modules/chownr": {
+ "version": "1.1.4",
+ "resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz",
+ "integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg=="
+ },
"node_modules/cliui": {
"version": "8.0.1",
"resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz",
@@ -1623,6 +1689,18 @@
"node": ">=12"
}
},
+ "node_modules/color": {
+ "version": "4.2.3",
+ "resolved": "https://registry.npmjs.org/color/-/color-4.2.3.tgz",
+ "integrity": "sha512-1rXeuUUiGGrykh+CeBdu5Ie7OJwinCgQY0bc7GCRxy5xVHy+moaqkpL/jqQq0MtQOeYcrqEz4abc5f0KtU7W4A==",
+ "dependencies": {
+ "color-convert": "^2.0.1",
+ "color-string": "^1.9.0"
+ },
+ "engines": {
+ "node": ">=12.5.0"
+ }
+ },
"node_modules/color-convert": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
@@ -1639,6 +1717,15 @@
"resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
"integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA=="
},
+ "node_modules/color-string": {
+ "version": "1.9.1",
+ "resolved": "https://registry.npmjs.org/color-string/-/color-string-1.9.1.tgz",
+ "integrity": "sha512-shrVawQFojnZv6xM40anx4CkoDP+fZsw/ZerEMsW/pyzsRbElpsL/DBVW7q3ExxwusdNXI3lXpuhEZkzs8p5Eg==",
+ "dependencies": {
+ "color-name": "^1.0.0",
+ "simple-swizzle": "^0.2.2"
+ }
+ },
"node_modules/colorette": {
"version": "2.0.20",
"resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz",
@@ -2000,6 +2087,28 @@
"ms": "2.0.0"
}
},
+ "node_modules/decompress-response": {
+ "version": "6.0.0",
+ "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz",
+ "integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==",
+ "dependencies": {
+ "mimic-response": "^3.1.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/deep-extend": {
+ "version": "0.6.0",
+ "resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz",
+ "integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==",
+ "engines": {
+ "node": ">=4.0.0"
+ }
+ },
"node_modules/deep-is": {
"version": "0.1.4",
"resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz",
@@ -2039,6 +2148,14 @@
"npm": "1.2.8000 || >= 1.4.16"
}
},
+ "node_modules/detect-libc": {
+ "version": "2.0.2",
+ "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz",
+ "integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==",
+ "engines": {
+ "node": ">=8"
+ }
+ },
"node_modules/diff": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
@@ -2188,7 +2305,6 @@
"version": "1.4.4",
"resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz",
"integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==",
- "devOptional": true,
"dependencies": {
"once": "^1.4.0"
}
@@ -2473,6 +2589,14 @@
"node": ">=0.8.x"
}
},
+ "node_modules/expand-template": {
+ "version": "2.0.3",
+ "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz",
+ "integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==",
+ "engines": {
+ "node": ">=6"
+ }
+ },
"node_modules/express": {
"version": "4.18.2",
"resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz",
@@ -2557,6 +2681,11 @@
"integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==",
"optional": true
},
+ "node_modules/fast-fifo": {
+ "version": "1.3.2",
+ "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz",
+ "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ=="
+ },
"node_modules/fast-levenshtein": {
"version": "2.0.6",
"resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz",
@@ -2718,6 +2847,11 @@
"node": ">= 0.6"
}
},
+ "node_modules/fs-constants": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz",
+ "integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow=="
+ },
"node_modules/fs.realpath": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz",
@@ -2795,6 +2929,11 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/github-from-package": {
+ "version": "0.0.0",
+ "resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz",
+ "integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw=="
+ },
"node_modules/glob": {
"version": "8.1.0",
"resolved": "https://registry.npmjs.org/glob/-/glob-8.1.0.tgz",
@@ -3246,6 +3385,11 @@
"resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz",
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
},
+ "node_modules/ini": {
+ "version": "1.3.8",
+ "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
+ "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
+ },
"node_modules/ipaddr.js": {
"version": "1.9.1",
"resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",
@@ -3254,6 +3398,11 @@
"node": ">= 0.10"
}
},
+ "node_modules/is-arrayish": {
+ "version": "0.3.2",
+ "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.3.2.tgz",
+ "integrity": "sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ=="
+ },
"node_modules/is-binary-path": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz",
@@ -3780,6 +3929,17 @@
"node": ">= 0.6"
}
},
+ "node_modules/mimic-response": {
+ "version": "3.1.0",
+ "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz",
+ "integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==",
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
@@ -3810,6 +3970,11 @@
"node": ">=10"
}
},
+ "node_modules/mkdirp-classic": {
+ "version": "0.5.3",
+ "resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz",
+ "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A=="
+ },
"node_modules/ms": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz",
@@ -3860,6 +4025,11 @@
"node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1"
}
},
+ "node_modules/napi-build-utils": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-1.0.2.tgz",
+ "integrity": "sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg=="
+ },
"node_modules/negotiator": {
"version": "0.6.3",
"resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz",
@@ -3868,6 +4038,22 @@
"node": ">= 0.6"
}
},
+ "node_modules/node-abi": {
+ "version": "3.51.0",
+ "resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.51.0.tgz",
+ "integrity": "sha512-SQkEP4hmNWjlniS5zdnfIXTk1x7Ome85RDzHlTbBtzE97Gfwz/Ipw4v/Ryk20DWIy3yCNVLVlGKApCnmvYoJbA==",
+ "dependencies": {
+ "semver": "^7.3.5"
+ },
+ "engines": {
+ "node": ">=10"
+ }
+ },
+ "node_modules/node-addon-api": {
+ "version": "6.1.0",
+ "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz",
+ "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA=="
+ },
"node_modules/node-fetch": {
"version": "2.6.9",
"resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.9.tgz",
@@ -4217,6 +4403,70 @@
"node": "^10 || ^12 || >=14"
}
},
+ "node_modules/prebuild-install": {
+ "version": "7.1.1",
+ "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz",
+ "integrity": "sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==",
+ "dependencies": {
+ "detect-libc": "^2.0.0",
+ "expand-template": "^2.0.3",
+ "github-from-package": "0.0.0",
+ "minimist": "^1.2.3",
+ "mkdirp-classic": "^0.5.3",
+ "napi-build-utils": "^1.0.1",
+ "node-abi": "^3.3.0",
+ "pump": "^3.0.0",
+ "rc": "^1.2.7",
+ "simple-get": "^4.0.0",
+ "tar-fs": "^2.0.0",
+ "tunnel-agent": "^0.6.0"
+ },
+ "bin": {
+ "prebuild-install": "bin.js"
+ },
+ "engines": {
+ "node": ">=10"
+ }
+ },
+ "node_modules/prebuild-install/node_modules/readable-stream": {
+ "version": "3.6.2",
+ "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz",
+ "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==",
+ "dependencies": {
+ "inherits": "^2.0.3",
+ "string_decoder": "^1.1.1",
+ "util-deprecate": "^1.0.1"
+ },
+ "engines": {
+ "node": ">= 6"
+ }
+ },
+ "node_modules/prebuild-install/node_modules/tar-fs": {
+ "version": "2.1.1",
+ "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz",
+ "integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==",
+ "dependencies": {
+ "chownr": "^1.1.1",
+ "mkdirp-classic": "^0.5.2",
+ "pump": "^3.0.0",
+ "tar-stream": "^2.1.4"
+ }
+ },
+ "node_modules/prebuild-install/node_modules/tar-stream": {
+ "version": "2.2.0",
+ "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz",
+ "integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==",
+ "dependencies": {
+ "bl": "^4.0.3",
+ "end-of-stream": "^1.4.1",
+ "fs-constants": "^1.0.0",
+ "inherits": "^2.0.3",
+ "readable-stream": "^3.1.1"
+ },
+ "engines": {
+ "node": ">=6"
+ }
+ },
"node_modules/prettier": {
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz",
@@ -4352,7 +4602,6 @@
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz",
"integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==",
- "dev": true,
"dependencies": {
"end-of-stream": "^1.1.0",
"once": "^1.3.1"
@@ -4372,6 +4621,11 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/queue-tick": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz",
+ "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag=="
+ },
"node_modules/quick-format-unescaped": {
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/quick-format-unescaped/-/quick-format-unescaped-4.0.4.tgz",
@@ -4407,6 +4661,28 @@
"node": ">= 0.8"
}
},
+ "node_modules/rc": {
+ "version": "1.2.8",
+ "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz",
+ "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==",
+ "dependencies": {
+ "deep-extend": "^0.6.0",
+ "ini": "~1.3.0",
+ "minimist": "^1.2.0",
+ "strip-json-comments": "~2.0.1"
+ },
+ "bin": {
+ "rc": "cli.js"
+ }
+ },
+ "node_modules/rc/node_modules/strip-json-comments": {
+ "version": "2.0.1",
+ "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz",
+ "integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==",
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
"node_modules/readable-stream": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.3.0.tgz",
@@ -4615,9 +4891,9 @@
"dev": true
},
"node_modules/semver": {
- "version": "7.5.3",
- "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.3.tgz",
- "integrity": "sha512-QBlUtyVk/5EeHbi7X0fw6liDZc7BBmEaSYn01fMU1OUYbf6GPsbTtd8WmnqbI20SeycoHSeiybkE/q1Q+qlThQ==",
+ "version": "7.5.4",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
+ "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
"dependencies": {
"lru-cache": "^6.0.0"
},
@@ -4675,6 +4951,28 @@
"resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz",
"integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw=="
},
+ "node_modules/sharp": {
+ "version": "0.32.6",
+ "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz",
+ "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==",
+ "hasInstallScript": true,
+ "dependencies": {
+ "color": "^4.2.3",
+ "detect-libc": "^2.0.2",
+ "node-addon-api": "^6.1.0",
+ "prebuild-install": "^7.1.1",
+ "semver": "^7.5.4",
+ "simple-get": "^4.0.1",
+ "tar-fs": "^3.0.4",
+ "tunnel-agent": "^0.6.0"
+ },
+ "engines": {
+ "node": ">=14.15.0"
+ },
+ "funding": {
+ "url": "https://opencollective.com/libvips"
+ }
+ },
"node_modules/shell-quote": {
"version": "1.8.1",
"resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz",
@@ -4712,6 +5010,57 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/simple-concat": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz",
+ "integrity": "sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==",
+ "funding": [
+ {
+ "type": "github",
+ "url": "https://github.com/sponsors/feross"
+ },
+ {
+ "type": "patreon",
+ "url": "https://www.patreon.com/feross"
+ },
+ {
+ "type": "consulting",
+ "url": "https://feross.org/support"
+ }
+ ]
+ },
+ "node_modules/simple-get": {
+ "version": "4.0.1",
+ "resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz",
+ "integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==",
+ "funding": [
+ {
+ "type": "github",
+ "url": "https://github.com/sponsors/feross"
+ },
+ {
+ "type": "patreon",
+ "url": "https://www.patreon.com/feross"
+ },
+ {
+ "type": "consulting",
+ "url": "https://feross.org/support"
+ }
+ ],
+ "dependencies": {
+ "decompress-response": "^6.0.0",
+ "once": "^1.3.1",
+ "simple-concat": "^1.0.0"
+ }
+ },
+ "node_modules/simple-swizzle": {
+ "version": "0.2.2",
+ "resolved": "https://registry.npmjs.org/simple-swizzle/-/simple-swizzle-0.2.2.tgz",
+ "integrity": "sha512-JA//kQgZtbuY83m+xT+tXJkmJncGMTFT+C+g2h2R9uxkYIrE2yy9sgmcLhCnw57/WSD+Eh3J97FPEDFnbXnDUg==",
+ "dependencies": {
+ "is-arrayish": "^0.3.1"
+ }
+ },
"node_modules/simple-update-notifier": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/simple-update-notifier/-/simple-update-notifier-2.0.0.tgz",
@@ -4809,11 +5158,19 @@
"node": ">=10.0.0"
}
},
+ "node_modules/streamx": {
+ "version": "2.15.4",
+ "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.4.tgz",
+ "integrity": "sha512-uSXKl88bibiUCQ1eMpItRljCzDENcDx18rsfDmV79r0e/ThfrAwxG4Y2FarQZ2G4/21xcOKmFFd1Hue+ZIDwHw==",
+ "dependencies": {
+ "fast-fifo": "^1.1.0",
+ "queue-tick": "^1.0.1"
+ }
+ },
"node_modules/string_decoder": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz",
"integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==",
- "devOptional": true,
"dependencies": {
"safe-buffer": "~5.2.0"
}
@@ -4872,6 +5229,26 @@
"node": ">=4"
}
},
+ "node_modules/tar-fs": {
+ "version": "3.0.4",
+ "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz",
+ "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==",
+ "dependencies": {
+ "mkdirp-classic": "^0.5.2",
+ "pump": "^3.0.0",
+ "tar-stream": "^3.1.5"
+ }
+ },
+ "node_modules/tar-stream": {
+ "version": "3.1.6",
+ "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz",
+ "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==",
+ "dependencies": {
+ "b4a": "^1.6.4",
+ "fast-fifo": "^1.2.0",
+ "streamx": "^2.15.0"
+ }
+ },
"node_modules/teeny-request": {
"version": "8.0.3",
"resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-8.0.3.tgz",
@@ -5047,6 +5424,17 @@
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
},
+ "node_modules/tunnel-agent": {
+ "version": "0.6.0",
+ "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz",
+ "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==",
+ "dependencies": {
+ "safe-buffer": "^5.0.1"
+ },
+ "engines": {
+ "node": "*"
+ }
+ },
"node_modules/type-is": {
"version": "1.6.18",
"resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz",
diff --git a/package.json b/package.json
index e765c18..6fe95f7 100644
--- a/package.json
+++ b/package.json
@@ -23,6 +23,7 @@
"@smithy/signature-v4": "^2.0.10",
"@smithy/types": "^2.3.4",
"axios": "^1.3.5",
+ "check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
"copyfiles": "^2.4.1",
"cors": "^2.8.5",
@@ -41,6 +42,7 @@
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"sanitize-html": "^2.11.0",
+ "sharp": "^0.32.6",
"showdown": "^2.1.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
diff --git a/src/config.ts b/src/config.ts
index 72f555b..916d78d 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -1,12 +1,17 @@
import dotenv from "dotenv";
import type firebase from "firebase-admin";
+import path from "path";
import pino from "pino";
import type { ModelFamily } from "./shared/models";
+import { MODEL_FAMILIES } from "./shared/models";
dotenv.config();
const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
const isDev = process.env.NODE_ENV !== "production";
+export const DATA_DIR = path.join(__dirname, "..", "data");
+export const USER_ASSETS_DIR = path.join(DATA_DIR, "user-files");
+
type Config = {
/** The port the proxy server will listen on. */
port: number;
@@ -75,8 +80,10 @@ type Config = {
* `maxIpsPerUser` limit, or if only connections from new IPs are to be rejected.
*/
maxIpsAutoBan: boolean;
- /** Per-IP limit for requests per minute to OpenAI's completions endpoint. */
- modelRateLimit: number;
+ /** Per-IP limit for requests per minute to text and chat models. */
+ textModelRateLimit: number;
+ /** Per-IP limit for requests per minute to image generation models. */
+ imageModelRateLimit: number;
/**
* For OpenAI, the maximum number of context tokens (prompt + max output) a
* user can request before their request is rejected.
@@ -157,6 +164,8 @@ type Config = {
quotaRefreshPeriod?: "hourly" | "daily" | string;
/** Whether to allow users to change their own nicknames via the UI. */
allowNicknameChanges: boolean;
+ /** Whether to show recent DALL-E image generations on the homepage. */
+ showRecentImages: boolean;
/**
* If true, cookies will be set without the `Secure` attribute, allowing
* the admin UI to be used over HTTP.
@@ -180,7 +189,8 @@ export const config: Config = {
maxIpsAutoBan: getEnvWithDefault("MAX_IPS_AUTO_BAN", true),
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
- modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
+ textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4),
+ imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384),
maxContextTokensAnthropic: getEnvWithDefault(
"MAX_CONTEXT_TOKENS_ANTHROPIC",
@@ -225,17 +235,19 @@ export const config: Config = {
"You must be over the age of majority in your country to use this service."
),
blockRedirect: getEnvWithDefault("BLOCK_REDIRECT", "https://www.9gag.com"),
- tokenQuota: {
- turbo: getEnvWithDefault("TOKEN_QUOTA_TURBO", 0),
- gpt4: getEnvWithDefault("TOKEN_QUOTA_GPT4", 0),
- "gpt4-32k": getEnvWithDefault("TOKEN_QUOTA_GPT4_32K", 0),
- "gpt4-turbo": getEnvWithDefault("TOKEN_QUOTA_GPT4_TURBO", 0),
- claude: getEnvWithDefault("TOKEN_QUOTA_CLAUDE", 0),
- bison: getEnvWithDefault("TOKEN_QUOTA_BISON", 0),
- "aws-claude": getEnvWithDefault("TOKEN_QUOTA_AWS_CLAUDE", 0),
- },
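+  // Reads TOKEN_QUOTA_<FAMILY> for each model family, e.g. `gpt4-32k` -> TOKEN_QUOTA_GPT4_32K.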
+ tokenQuota: MODEL_FAMILIES.reduce(
+ (acc, family: ModelFamily) => {
+ acc[family] = getEnvWithDefault(
+ `TOKEN_QUOTA_${family.toUpperCase().replace(/-/g, "_")}`,
+ 0
+ ) as number;
+ return acc;
+ },
+ {} as { [key in ModelFamily]: number }
+ ),
quotaRefreshPeriod: getEnvWithDefault("QUOTA_REFRESH_PERIOD", undefined),
allowNicknameChanges: getEnvWithDefault("ALLOW_NICKNAME_CHANGES", true),
+ showRecentImages: getEnvWithDefault("SHOW_RECENT_IMAGES", true),
useInsecureCookies: getEnvWithDefault("USE_INSECURE_COOKIES", isDev),
} as const;
@@ -252,6 +264,19 @@ function generateCookieSecret() {
export const COOKIE_SECRET = generateCookieSecret();
export async function assertConfigIsValid() {
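+  // Backwards compatibility: derive the new split rate limits from the deprecated MODEL_RATE_LIMIT.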
+ if (process.env.MODEL_RATE_LIMIT !== undefined) {
+ const limit =
+ parseInt(process.env.MODEL_RATE_LIMIT, 10) || config.textModelRateLimit;
+
+ config.textModelRateLimit = limit;
+ config.imageModelRateLimit = Math.max(Math.floor(limit / 2), 1);
+
+ startupLogger.warn(
+ { textLimit: limit, imageLimit: config.imageModelRateLimit },
+ "MODEL_RATE_LIMIT is deprecated. Use TEXT_MODEL_RATE_LIMIT and IMAGE_MODEL_RATE_LIMIT instead."
+ );
+ }
+
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
throw new Error(
`Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.`
@@ -332,6 +357,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"blockMessage",
"blockRedirect",
"allowNicknameChanges",
+ "showRecentImages",
"useInsecureCookies",
];
@@ -428,5 +454,5 @@ function parseCsv(val: string): string[] {
const regex = /(".*?"|[^",]+)(?=\s*,|\s*$)/g;
const matches = val.match(regex) || [];
- return matches.map(item => item.replace(/^"|"$/g, '').trim());
+ return matches.map((item) => item.replace(/^"|"$/g, "").trim());
}
diff --git a/src/info-page.ts b/src/info-page.ts
index caa38e4..07602a5 100644
--- a/src/info-page.ts
+++ b/src/info-page.ts
@@ -14,6 +14,7 @@ import { getUniqueIps } from "./proxy/rate-limit";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { getTokenCostUsd, prettyTokens } from "./shared/stats";
import { assertNever } from "./shared/utils";
+import { getLastNImages } from "./shared/file-storage/image-history";
const INFO_PAGE_TTL = 2000;
let infoPageHtml: string | undefined;
@@ -94,6 +95,8 @@ function cacheInfoPageHtml(baseUrl: string) {
const tokens = serviceStats.get("tokens") || 0;
const tokenCost = serviceStats.get("tokenCost") || 0;
+ const allowDalle = config.allowedModelFamilies.includes("dall-e");
+
const info = {
uptime: Math.floor(process.uptime()),
endpoints: {
@@ -101,13 +104,16 @@ function cacheInfoPageHtml(baseUrl: string) {
...(openaiKeys
? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" }
: {}),
+ ...(openaiKeys && allowDalle
+ ? { ["openai-image"]: baseUrl + "/proxy/openai-image" }
+ : {}),
...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}),
...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}),
},
proompts,
tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
- ...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
+ ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
openaiKeys,
anthropicKeys,
palmKeys,
@@ -287,7 +293,6 @@ function getOpenAIInfo() {
// Don't show trial/revoked keys for non-turbo families.
// Generally those stats only make sense for the lowest-tier model.
if (f !== "turbo") {
- console.log("deleting", f);
delete info[f]!.trialKeys;
delete info[f]!.revokedKeys;
}
@@ -457,6 +462,9 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t
if (customGreeting) {
infoBody += `\n## Server Greeting\n${customGreeting}`;
}
+
+ infoBody += buildRecentImageSection();
+
return converter.makeHtml(infoBody);
}
@@ -499,6 +507,43 @@ function getServerTitle() {
return "OAI Reverse Proxy";
}
+function buildRecentImageSection() {
+ if (
+ !config.allowedModelFamilies.includes("dall-e") ||
+ !config.showRecentImages
+ ) {
+ return "";
+ }
+
+  let html = `<h2 id="recent-images">Recent DALL-E Generations</h2>`;
+  const recentImages = getLastNImages(12).reverse();
+  if (recentImages.length === 0) {
+    html += `<p>No images yet.</p>`;
+    return html;
+  }
+
+  html += `<div class="recent-images">`;
+  for (const { url, prompt } of recentImages) {
+    const thumbUrl = url.replace(/\.png$/, "_t.jpg");
+    const escapedPrompt = escapeHtml(prompt);
+    // Link the full-size PNG; display the JPEG thumbnail with the prompt as tooltip/alt text.
+    html += `<a href="${url}" target="_blank">
+  <img src="${thumbUrl}" title="${escapedPrompt}" alt="${escapedPrompt}" />
+</a>`;
+  }
+  html += `</div>`;
+
+ return html;
+}
+
+function escapeHtml(unsafe: string) {
+ return unsafe
+    .replace(/&/g, '&amp;')
+    .replace(/</g, '&lt;')
+    .replace(/>/g, '&gt;')
+    .replace(/"/g, '&quot;')
+    .replace(/'/g, '&#039;');
+}
+
function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
// Huggingface broke their amazon elb config and no longer sends the
// x-forwarded-host header. This is a workaround.
diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts
index 98cc98b..ed1fffc 100644
--- a/src/proxy/anthropic.ts
+++ b/src/proxy/anthropic.ts
@@ -87,9 +87,8 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
body = transformAnthropicResponse(body, req);
}
- // TODO: Remove once tokenization is stable
- if (req.debug) {
- body.proxy_tokenizer_debug_info = req.debug;
+ if (req.tokenizerInfo) {
+ body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts
index c80aa6a..2f43762 100644
--- a/src/proxy/aws.ts
+++ b/src/proxy/aws.ts
@@ -73,9 +73,8 @@ const awsResponseHandler: ProxyResHandlerWithBody = async (
body = transformAwsResponse(body, req);
}
- // TODO: Remove once tokenization is stable
- if (req.debug) {
- body.proxy_tokenizer_debug_info = req.debug;
+ if (req.tokenizerInfo) {
+ body.proxy_tokenizer = req.tokenizerInfo;
}
// AWS does not confirm the model in the response, so we have to add it
diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts
index 99f8281..24cfef9 100644
--- a/src/proxy/middleware/common.ts
+++ b/src/proxy/middleware/common.ts
@@ -9,11 +9,10 @@ import { QuotaExceededError } from "./request/apply-quota-limits";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";
const OPENAI_EMBEDDINGS_ENDPOINT = "/v1/embeddings";
+const OPENAI_IMAGE_COMPLETION_ENDPOINT = "/v1/images/generations";
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
-/** Returns true if we're making a request to a completion endpoint. */
-export function isCompletionRequest(req: Request) {
- // 99% sure this function is not needed anymore
+export function isTextGenerationRequest(req: Request) {
return (
req.method === "POST" &&
[
@@ -24,6 +23,13 @@ export function isCompletionRequest(req: Request) {
);
}
+export function isImageGenerationRequest(req: Request) {
+ return (
+ req.method === "POST" &&
+ req.path.startsWith(OPENAI_IMAGE_COMPLETION_ENDPOINT)
+ );
+}
+
export function isEmbeddingsRequest(req: Request) {
return (
req.method === "POST" && req.path.startsWith(OPENAI_EMBEDDINGS_ENDPOINT)
@@ -53,8 +59,8 @@ export function writeErrorResponse(
res.write(`data: [DONE]\n\n`);
res.end();
} else {
- if (req.debug && errorPayload.error) {
- errorPayload.error.proxy_tokenizer_debug_info = req.debug;
+ if (req.tokenizerInfo && errorPayload.error) {
+ errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
}
res.status(statusCode).json(errorPayload);
}
@@ -103,7 +109,7 @@ function classifyError(err: Error): {
code: { enabled: false },
maxErrors: 3,
transform: ({ issue, ...rest }) => {
- return `At '${rest.pathComponent}', ${issue.message}`;
+ return `At '${rest.pathComponent}': ${issue.message}`;
},
});
return { status: 400, userMessage, type: "proxy_validation_error" };
@@ -173,6 +179,8 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
return body.completion.trim();
case "google-palm":
return body.candidates[0].output;
+ case "openai-image":
+ return body.data?.map((item: any) => item.url).join("\n");
default:
assertNever(format);
}
@@ -184,6 +192,8 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
case "openai":
case "openai-text":
return body.model;
+ case "openai-image":
+ return req.body.model;
case "anthropic":
// Anthropic confirms the model in the response, but AWS Claude doesn't.
return body.model || req.body.model;
diff --git a/src/proxy/middleware/request/add-anthropic-preamble.ts b/src/proxy/middleware/request/add-anthropic-preamble.ts
index 35f3602..cdab4f2 100644
--- a/src/proxy/middleware/request/add-anthropic-preamble.ts
+++ b/src/proxy/middleware/request/add-anthropic-preamble.ts
@@ -1,5 +1,5 @@
import { AnthropicKey, Key } from "../../../shared/key-management";
-import { isCompletionRequest } from "../common";
+import { isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/**
@@ -11,7 +11,7 @@ export const addAnthropicPreamble: ProxyRequestMiddleware = (
_proxyReq,
req
) => {
- if (!isCompletionRequest(req) || req.key?.service !== "anthropic") {
+ if (!isTextGenerationRequest(req) || req.key?.service !== "anthropic") {
return;
}
diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/add-key.ts
index cb7f9e3..bbf38b7 100644
--- a/src/proxy/middleware/request/add-key.ts
+++ b/src/proxy/middleware/request/add-key.ts
@@ -1,5 +1,5 @@
import { Key, OpenAIKey, keyPool } from "../../../shared/key-management";
-import { isCompletionRequest, isEmbeddingsRequest } from "../common";
+import { isEmbeddingsRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
import { assertNever } from "../../../shared/utils";
@@ -7,18 +7,6 @@ import { assertNever } from "../../../shared/utils";
export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
let assignedKey: Key;
- if (!isCompletionRequest(req)) {
- // Horrible, horrible hack to stop the proxy from complaining about clients
- // not sending a model when they are requesting the list of models (which
- // requires a key, but obviously not a model).
-
- // I don't think this is needed anymore since models requests are no longer
- // proxied to the upstream API. Everything going through this is either a
- // completion request or a special case like OpenAI embeddings.
- req.log.warn({ path: req.path }, "addKey called on non-completion request");
- req.body.model = "gpt-3.5-turbo";
- }
-
if (!req.inboundApi || !req.outboundApi) {
const err = new Error(
"Request API format missing. Did you forget to add the request preprocessor to your router?"
@@ -54,6 +42,9 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
throw new Error(
"OpenAI Chat as an API translation target is not supported"
);
+ case "openai-image":
+ assignedKey = keyPool.get("dall-e-3");
+ break;
default:
assertNever(req.outboundApi);
}
diff --git a/src/proxy/middleware/request/apply-quota-limits.ts b/src/proxy/middleware/request/apply-quota-limits.ts
index 581de23..e7a637b 100644
--- a/src/proxy/middleware/request/apply-quota-limits.ts
+++ b/src/proxy/middleware/request/apply-quota-limits.ts
@@ -1,5 +1,5 @@
import { hasAvailableQuota } from "../../../shared/users/user-store";
-import { isCompletionRequest } from "../common";
+import { isImageGenerationRequest, isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
export class QuotaExceededError extends Error {
@@ -12,12 +12,19 @@ export class QuotaExceededError extends Error {
}
export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
- if (!isCompletionRequest(req) || !req.user) {
- return;
- }
+ const subjectToQuota =
+ isTextGenerationRequest(req) || isImageGenerationRequest(req);
+ if (!subjectToQuota || !req.user) return;
const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
- if (!hasAvailableQuota(req.user.token, req.body.model, requestedTokens)) {
+ if (
+ !hasAvailableQuota({
+ userToken: req.user.token,
+ model: req.body.model,
+ api: req.outboundApi,
+ requested: requestedTokens,
+ })
+ ) {
throw new QuotaExceededError(
"You have exceeded your proxy token quota for this model.",
{
diff --git a/src/proxy/middleware/request/block-zoomer-origins.ts b/src/proxy/middleware/request/block-zoomer-origins.ts
index 93f4c87..9efa404 100644
--- a/src/proxy/middleware/request/block-zoomer-origins.ts
+++ b/src/proxy/middleware/request/block-zoomer-origins.ts
@@ -1,4 +1,3 @@
-import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");
@@ -15,10 +14,6 @@ class ForbiddenError extends Error {
* stop getting emails asking for tech support.
*/
export const blockZoomerOrigins: ProxyRequestMiddleware = (_proxyReq, req) => {
- if (!isCompletionRequest(req)) {
- return;
- }
-
const origin = req.headers.origin || req.headers.referer;
if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) {
// Venus-derivatives send a test prompt to check if the proxy is working.
diff --git a/src/proxy/middleware/request/count-prompt-tokens.ts b/src/proxy/middleware/request/count-prompt-tokens.ts
index 5cb0dd4..19f1cca 100644
--- a/src/proxy/middleware/request/count-prompt-tokens.ts
+++ b/src/proxy/middleware/request/count-prompt-tokens.ts
@@ -35,14 +35,18 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service });
break;
}
+ case "openai-image": {
+ req.outputTokens = 1;
+ result = await countTokens({ req, service });
+ break;
+ }
default:
assertNever(service);
}
req.promptTokens = result.token_count;
- // TODO: Remove once token counting is stable
req.log.debug({ result: result }, "Counted prompt tokens.");
- req.debug = req.debug ?? {};
- req.debug = { ...req.debug, ...result };
-};
\ No newline at end of file
+ req.tokenizerInfo = req.tokenizerInfo ?? {};
+ req.tokenizerInfo = { ...req.tokenizerInfo, ...result };
+};
diff --git a/src/proxy/middleware/request/finalize-body.ts b/src/proxy/middleware/request/finalize-body.ts
index bc62bf5..ac90e96 100644
--- a/src/proxy/middleware/request/finalize-body.ts
+++ b/src/proxy/middleware/request/finalize-body.ts
@@ -4,6 +4,11 @@ import type { ProxyRequestMiddleware } from ".";
/** Finalize the rewritten request body. Must be the last rewriter. */
export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => {
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
+ // For image generation requests, remove stream flag.
+ if (req.outboundApi === "openai-image") {
+ delete req.body.stream;
+ }
+
const updatedBody = JSON.stringify(req.body);
proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody));
(req as any).rawBody = Buffer.from(updatedBody);
diff --git a/src/proxy/middleware/request/language-filter.ts b/src/proxy/middleware/request/language-filter.ts
index 4bb5d25..32802bc 100644
--- a/src/proxy/middleware/request/language-filter.ts
+++ b/src/proxy/middleware/request/language-filter.ts
@@ -58,6 +58,7 @@ function getPromptFromRequest(req: Request) {
)
.join("\n\n");
case "openai-text":
+ case "openai-image":
return body.prompt;
case "google-palm":
return body.prompt.text;
diff --git a/src/proxy/middleware/request/limit-completions.ts b/src/proxy/middleware/request/limit-completions.ts
index c61fee3..44f583b 100644
--- a/src/proxy/middleware/request/limit-completions.ts
+++ b/src/proxy/middleware/request/limit-completions.ts
@@ -1,12 +1,12 @@
-import { isCompletionRequest } from "../common";
+import { isTextGenerationRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
/**
- * Don't allow multiple completions to be requested to prevent abuse.
+ * Don't allow multiple text completions to be requested to prevent abuse.
* OpenAI-only, Anthropic provides no such parameter.
**/
export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => {
- if (isCompletionRequest(req) && req.outboundApi === "openai") {
+ if (isTextGenerationRequest(req) && req.outboundApi === "openai") {
const originalN = req.body?.n || 1;
req.body.n = 1;
if (originalN !== req.body.n) {
diff --git a/src/proxy/middleware/request/rewrite.ts b/src/proxy/middleware/request/rewrite.ts
index 3b62ff2..8cc078d 100644
--- a/src/proxy/middleware/request/rewrite.ts
+++ b/src/proxy/middleware/request/rewrite.ts
@@ -17,6 +17,7 @@ export const createOnProxyReqHandler = ({
// The streaming flag must be set before any other middleware runs, because
// it may influence which other middleware a particular API pipeline wants
// to run.
+ // Image generation requests can't be streamed.
req.isStreaming = req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;
diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/transform-outbound-payload.ts
index 5025158..f4de2a7 100644
--- a/src/proxy/middleware/request/transform-outbound-payload.ts
+++ b/src/proxy/middleware/request/transform-outbound-payload.ts
@@ -2,7 +2,7 @@ import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { OpenAIPromptMessage } from "../../../shared/tokenization";
-import { isCompletionRequest } from "../common";
+import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
import { RequestPreprocessor } from ".";
import { APIFormat } from "../../../shared/key-management";
@@ -88,6 +88,21 @@ const OpenAIV1TextCompletionSchema = z
})
.merge(OpenAIV1ChatCompletionSchema.omit({ messages: true }));
+// https://platform.openai.com/docs/api-reference/images/create
+const OpenAIV1ImagesGenerationSchema = z.object({
+ prompt: z.string().max(4000),
+ model: z.string().optional(),
+ quality: z.enum(["standard", "hd"]).optional().default("standard"),
+ n: z.number().int().min(1).max(4).optional().default(1),
+ response_format: z.enum(["url", "b64_json"]).optional(),
+ size: z
+ .enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
+ .optional()
+ .default("1024x1024"),
+ style: z.enum(["vivid", "natural"]).optional().default("vivid"),
+ user: z.string().optional(),
+});
+
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
const PalmV1GenerateTextSchema = z.object({
model: z.string(),
@@ -110,6 +125,7 @@ const VALIDATORS: Record<APIFormat, z.ZodType<any>> = {
anthropic: AnthropicV1CompleteSchema,
openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema,
+ "openai-image": OpenAIV1ImagesGenerationSchema,
"google-palm": PalmV1GenerateTextSchema,
};
@@ -117,11 +133,10 @@ const VALIDATORS: Record<APIFormat, z.ZodType<any>> = {
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
const sameService = req.inboundApi === req.outboundApi;
const alreadyTransformed = req.retryCount > 0;
- const notTransformable = !isCompletionRequest(req);
+ const notTransformable =
+ !isTextGenerationRequest(req) && !isImageGenerationRequest(req);
- if (alreadyTransformed || notTransformable) {
- return;
- }
+ if (alreadyTransformed || notTransformable) return;
if (sameService) {
const result = VALIDATORS[req.inboundApi].safeParse(req.body);
@@ -151,6 +166,11 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
return;
}
+ if (req.inboundApi === "openai" && req.outboundApi === "openai-image") {
+ req.body = openaiToOpenaiImage(req);
+ return;
+ }
+
throw new Error(
`'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
);
@@ -226,6 +246,49 @@ function openaiToOpenaiText(req: Request) {
return OpenAIV1TextCompletionSchema.parse(transformed);
}
+// Takes the last chat message and uses it verbatim as the image prompt.
+function openaiToOpenaiImage(req: Request) {
+ const { body } = req;
+ const result = OpenAIV1ChatCompletionSchema.safeParse(body);
+ if (!result.success) {
+ req.log.warn(
+ { issues: result.error.issues, body },
+ "Invalid OpenAI-to-OpenAI-image request"
+ );
+ throw result.error;
+ }
+
+ const { messages } = result.data;
+ const prompt = messages.filter((m) => m.role === "user").pop()?.content;
+
+ if (body.stream) {
+ throw new Error(
+ "Streaming is not supported for image generation requests."
+ );
+ }
+
+ // Some frontends do weird things with the prompt, like prefixing it with a
+ // character name or wrapping the entire thing in quotes. We will look for
+ // the index of "Image:" and use everything after that as the prompt.
+
+ const index = prompt?.toLowerCase().indexOf("image:");
+ if (index === -1 || !prompt) {
+ throw new Error(
+ `Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`
+ );
+ }
+
+ // TODO: Add some way to specify parameters via chat message
+ const transformed = {
+ model: body.model.includes("dall-e") ? body.model : "dall-e-3",
+ quality: "standard",
+ size: "1024x1024",
+ response_format: "url",
+ prompt: prompt.slice(index! + 6).trim(),
+ };
+ return OpenAIV1ImagesGenerationSchema.parse(transformed);
+}
+
function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({
diff --git a/src/proxy/middleware/request/validate-context-size.ts b/src/proxy/middleware/request/validate-context-size.ts
index 8883ccb..ee661c9 100644
--- a/src/proxy/middleware/request/validate-context-size.ts
+++ b/src/proxy/middleware/request/validate-context-size.ts
@@ -34,6 +34,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
case "google-palm":
proxyMax = BISON_MAX_CONTEXT;
break;
+ case "openai-image":
+ return;
default:
assertNever(req.outboundApi);
}
@@ -81,10 +83,10 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
"Prompt size validated"
);
- req.debug.prompt_tokens = promptTokens;
- req.debug.completion_tokens = outputTokens;
- req.debug.max_model_tokens = modelMax;
- req.debug.max_proxy_tokens = proxyMax;
+ req.tokenizerInfo.prompt_tokens = promptTokens;
+ req.tokenizerInfo.completion_tokens = outputTokens;
+ req.tokenizerInfo.max_model_tokens = modelMax;
+ req.tokenizerInfo.max_proxy_tokens = proxyMax;
};
function assertRequestHasTokenCounts(
diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts
index b398975..c68d347 100644
--- a/src/proxy/middleware/response/index.ts
+++ b/src/proxy/middleware/response/index.ts
@@ -13,13 +13,16 @@ import {
incrementTokenCount,
} from "../../../shared/users/user-store";
import { assertNever } from "../../../shared/utils";
+import { refundLastAttempt } from "../../rate-limit";
import {
getCompletionFromBody,
- isCompletionRequest,
+ isImageGenerationRequest,
+ isTextGenerationRequest,
writeErrorResponse,
} from "../common";
import { handleStreamedResponse } from "./handle-streamed-response";
import { logPrompt } from "./log-prompt";
+import { saveImage } from "./save-image";
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
@@ -106,6 +109,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
countResponseTokens,
incrementUsage,
copyHttpHeaders,
+ saveImage,
logPrompt,
...apiMiddleware
);
@@ -285,7 +289,16 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
switch (service) {
case "openai":
case "google-palm":
- errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
+ if (errorPayload.error?.code === "content_policy_violation") {
+ errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`;
+ refundLastAttempt(req);
+ } else if (errorPayload.error?.code === "billing_hard_limit_reached") {
+ // For some reason, some models return this 400 error instead of the
+ // same 429 billing error that other models return.
+ handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
+ } else {
+ errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
+ }
break;
case "anthropic":
case "aws":
@@ -453,6 +466,7 @@ function handleOpenAIRateLimitError(
const type = errorPayload.error?.type;
switch (type) {
case "insufficient_quota":
+ case "invalid_request_error": // this is the billing_hard_limit_reached error seen in some cases
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
@@ -487,13 +501,22 @@ function handleOpenAIRateLimitError(
}
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
- if (isCompletionRequest(req)) {
+ if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
const model = req.body.model;
const tokensUsed = req.promptTokens! + req.outputTokens!;
+ req.log.debug(
+ {
+ model,
+ tokensUsed,
+ promptTokens: req.promptTokens,
+ outputTokens: req.outputTokens,
+ },
+ `Incrementing usage for model`
+ );
keyPool.incrementUsage(req.key!, model, tokensUsed);
if (req.user) {
incrementPromptCount(req.user.token);
- incrementTokenCount(req.user.token, model, tokensUsed);
+ incrementTokenCount(req.user.token, model, req.outboundApi, tokensUsed);
}
}
};
@@ -504,6 +527,12 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
_res,
body
) => {
+ if (req.outboundApi === "openai-image") {
+ req.outputTokens = req.promptTokens;
+ req.promptTokens = 0;
+ return;
+ }
+
// This function is prone to breaking if the upstream API makes even minor
// changes to the response format, especially for SSE responses. If you're
// seeing errors in this function, check the reassembled response body from
@@ -518,8 +547,8 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
{ service, tokens, prevOutputTokens: req.outputTokens },
`Counted tokens for completion`
);
- if (req.debug) {
- req.debug.completion_tokens = tokens;
+ if (req.tokenizerInfo) {
+ req.tokenizerInfo.completion_tokens = tokens;
}
req.outputTokens = tokens.token_count;
diff --git a/src/proxy/middleware/response/log-prompt.ts b/src/proxy/middleware/response/log-prompt.ts
index 0d634d3..7f3dc65 100644
--- a/src/proxy/middleware/response/log-prompt.ts
+++ b/src/proxy/middleware/response/log-prompt.ts
@@ -4,7 +4,8 @@ import { logQueue } from "../../../shared/prompt-logging";
import {
getCompletionFromBody,
getModelFromBody,
- isCompletionRequest,
+ isImageGenerationRequest,
+ isTextGenerationRequest,
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
@@ -23,11 +24,11 @@ export const logPrompt: ProxyResHandlerWithBody = async (
throw new Error("Expected body to be an object");
}
- if (!isCompletionRequest(req)) {
- return;
- }
+ const loggable =
+ isTextGenerationRequest(req) || isImageGenerationRequest(req);
+ if (!loggable) return;
- const promptPayload = getPromptForRequest(req);
+ const promptPayload = getPromptForRequest(req, responseBody);
const promptFlattened = flattenMessages(promptPayload);
const response = getCompletionFromBody(req, responseBody);
const model = getModelFromBody(req, responseBody);
@@ -46,7 +47,18 @@ type OaiMessage = {
content: string;
};
-const getPromptForRequest = (req: Request): string | OaiMessage[] => {
+type OaiImageResult = {
+ prompt: string;
+ size: string;
+ style: string;
+ quality: string;
+ revisedPrompt?: string;
+};
+
+const getPromptForRequest = (
+ req: Request,
+  responseBody: Record<string, any>
+): string | OaiMessage[] | OaiImageResult => {
// Since the prompt logger only runs after the request has been proxied, we
// can assume the body has already been transformed to the target API's
// format.
@@ -55,6 +67,14 @@ const getPromptForRequest = (req: Request): string | OaiMessage[] => {
return req.body.messages;
case "openai-text":
return req.body.prompt;
+ case "openai-image":
+ return {
+ prompt: req.body.prompt,
+ size: req.body.size,
+ style: req.body.style,
+ quality: req.body.quality,
+ revisedPrompt: responseBody.data[0].revised_prompt,
+ };
case "anthropic":
return req.body.prompt;
case "google-palm":
@@ -64,9 +84,14 @@ const getPromptForRequest = (req: Request): string | OaiMessage[] => {
}
};
-const flattenMessages = (messages: string | OaiMessage[]): string => {
- if (typeof messages === "string") {
- return messages.trim();
+const flattenMessages = (
+ val: string | OaiMessage[] | OaiImageResult
+): string => {
+ if (typeof val === "string") {
+ return val.trim();
}
- return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
+ if (Array.isArray(val)) {
+ return val.map((m) => `${m.role}: ${m.content}`).join("\n");
+ }
+ return val.prompt.trim();
};
diff --git a/src/proxy/middleware/response/save-image.ts b/src/proxy/middleware/response/save-image.ts
new file mode 100644
index 0000000..937e9fc
--- /dev/null
+++ b/src/proxy/middleware/response/save-image.ts
@@ -0,0 +1,27 @@
+import { ProxyResHandlerWithBody } from "./index";
+import { mirrorGeneratedImage, OpenAIImageGenerationResult } from "../../../shared/file-storage/mirror-generated-image";
+
+export const saveImage: ProxyResHandlerWithBody = async (
+ _proxyRes,
+ req,
+ _res,
+ body,
+) => {
+ if (req.outboundApi !== "openai-image") {
+ return;
+ }
+
+ if (typeof body !== "object") {
+ throw new Error("Expected body to be an object");
+ }
+
+ if (body.data) {
+ const baseUrl = req.protocol + "://" + req.get("host");
+ const prompt = body.data[0].revised_prompt ?? req.body.prompt;
+ await mirrorGeneratedImage(
+ baseUrl,
+ prompt,
+ body as OpenAIImageGenerationResult
+ );
+ }
+};
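+
+// mirrorGeneratedImage mutates the response body in place, so handlers that
+// run after saveImage (e.g. logPrompt) see the mirrored URLs rather than the
+// short-lived upstream ones.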
diff --git a/src/proxy/middleware/response/streaming/event-aggregator.ts b/src/proxy/middleware/response/streaming/event-aggregator.ts
index 55f0fb3..8db3da2 100644
--- a/src/proxy/middleware/response/streaming/event-aggregator.ts
+++ b/src/proxy/middleware/response/streaming/event-aggregator.ts
@@ -33,9 +33,10 @@ export class EventAggregator {
case "anthropic":
return mergeEventsForAnthropic(this.events);
case "google-palm":
- throw new Error("Google PaLM API does not support streaming responses");
+ case "openai-image":
+ throw new Error(`SSE aggregation not supported for ${this.format}`);
default:
assertNever(this.format);
}
}
-}
\ No newline at end of file
+}
diff --git a/src/proxy/middleware/response/streaming/sse-message-transformer.ts b/src/proxy/middleware/response/streaming/sse-message-transformer.ts
index 5bd0b8e..6da55b9 100644
--- a/src/proxy/middleware/response/streaming/sse-message-transformer.ts
+++ b/src/proxy/middleware/response/streaming/sse-message-transformer.ts
@@ -99,7 +99,8 @@ function getTransformer(
? anthropicV1ToOpenAI
: anthropicV2ToOpenAI;
case "google-palm":
- throw new Error("Google PaLM does not support streaming responses");
+ case "openai-image":
+ throw new Error(`SSE transformation not supported for ${responseApi}`);
default:
assertNever(responseApi);
}
diff --git a/src/proxy/openai-image.ts b/src/proxy/openai-image.ts
new file mode 100644
index 0000000..2c5a63c
--- /dev/null
+++ b/src/proxy/openai-image.ts
@@ -0,0 +1,153 @@
+import { RequestHandler, Router, Request } from "express";
+import { createProxyMiddleware } from "http-proxy-middleware";
+import { config } from "../config";
+import { logger } from "../logger";
+import { createQueueMiddleware } from "./queue";
+import { ipLimiter } from "./rate-limit";
+import { handleProxyError } from "./middleware/common";
+import {
+ addKey,
+ applyQuotaLimits,
+ blockZoomerOrigins,
+ createPreprocessorMiddleware,
+ finalizeBody,
+ stripHeaders,
+ createOnProxyReqHandler,
+} from "./middleware/request";
+import {
+ createOnProxyResHandler,
+ ProxyResHandlerWithBody,
+} from "./middleware/response";
+import { generateModelList } from "./openai";
+import {
+ mirrorGeneratedImage,
+ OpenAIImageGenerationResult,
+} from "../shared/file-storage/mirror-generated-image";
+
+const KNOWN_MODELS = ["dall-e-2", "dall-e-3"];
+
+let modelListCache: any = null;
+let modelListValid = 0;
+const handleModelRequest: RequestHandler = (_req, res) => {
+  // Serve the cached list if it is less than a minute old. Note that the
+  // handler must send the response; returning the cache object alone would
+  // leave the request hanging.
+  if (new Date().getTime() - modelListValid < 1000 * 60) {
+    return res.status(200).json(modelListCache);
+  }
+ const result = generateModelList(KNOWN_MODELS);
+ modelListCache = { object: "list", data: result };
+ modelListValid = new Date().getTime();
+ res.status(200).json(modelListCache);
+};
+
+const openaiImagesResponseHandler: ProxyResHandlerWithBody = async (
+ _proxyRes,
+ req,
+ res,
+ body
+) => {
+ if (typeof body !== "object") {
+ throw new Error("Expected body to be an object");
+ }
+
+ if (config.promptLogging) {
+ const host = req.get("host");
+ body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
+ }
+
+ if (req.inboundApi === "openai") {
+ req.log.info("Transforming OpenAI image response to OpenAI chat format");
+ body = transformResponseForChat(body as OpenAIImageGenerationResult, req);
+ }
+
+ if (req.tokenizerInfo) {
+ body.proxy_tokenizer = req.tokenizerInfo;
+ }
+
+ res.status(200).json(body);
+};
+
+/**
+ * Transforms a DALL-E image generation response into a chat response, simply
+ * embedding the image URL into the chat message as a Markdown image.
+ */
+function transformResponseForChat(
+ imageBody: OpenAIImageGenerationResult,
+ req: Request
+): Record<string, any> {
+ const prompt = imageBody.data[0].revised_prompt ?? req.body.prompt;
+ const content = imageBody.data
+ .map((item) => {
+ const { url, b64_json } = item;
+ if (b64_json) {
+ return `![${prompt}](data:image/png;base64,${b64_json})`;
+ } else {
+ return `![${prompt}](${url})`;
+ }
+ })
+ .join("\n\n");
+
+ return {
+ id: "dalle-" + req.id,
+ object: "chat.completion",
+    created: Math.floor(Date.now() / 1000), // OpenAI timestamps are Unix seconds
+ model: req.body.model,
+ usage: {
+ prompt_tokens: 0,
+ completion_tokens: req.outputTokens,
+ total_tokens: req.outputTokens,
+ },
+ choices: [
+ {
+ message: { role: "assistant", content },
+ finish_reason: "stop",
+ index: 0,
+ },
+ ],
+ };
+}
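+
+// A transformed response looks roughly like (URL is illustrative):
+//   { "object": "chat.completion",
+//     "choices": [{ "message": { "role": "assistant",
+//       "content": "![prompt](https://<host>/user_content/<uuid>.png)" } }] }
+// so chat-only frontends can render the generated image inline.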
+
+const openaiImagesProxy = createQueueMiddleware({
+ proxyMiddleware: createProxyMiddleware({
+ target: "https://api.openai.com",
+ changeOrigin: true,
+ selfHandleResponse: true,
+ logger,
+ pathRewrite: {
+ "^/v1/chat/completions": "/v1/images/generations",
+ },
+ on: {
+ proxyReq: createOnProxyReqHandler({
+ pipeline: [
+ applyQuotaLimits,
+ addKey,
+ blockZoomerOrigins,
+ stripHeaders,
+ finalizeBody,
+ ],
+ }),
+ proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]),
+ error: handleProxyError,
+ },
+ }),
+});
+
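+// The image router accepts native OpenAI image requests and, via the
+// pathRewrite above, chat-completion requests that are forwarded to
+// /v1/images/generations. A direct call looks roughly like:
+//   POST /proxy/openai-image/v1/images/generations
+//   { "model": "dall-e-3", "prompt": "a red fox", "size": "1024x1024" }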
+const openaiImagesRouter = Router();
+openaiImagesRouter.get("/v1/models", handleModelRequest);
+openaiImagesRouter.post(
+ "/v1/images/generations",
+ ipLimiter,
+ createPreprocessorMiddleware({
+ inApi: "openai-image",
+ outApi: "openai-image",
+ service: "openai",
+ }),
+ openaiImagesProxy
+);
+openaiImagesRouter.post(
+ "/v1/chat/completions",
+ ipLimiter,
+ createPreprocessorMiddleware({
+ inApi: "openai",
+ outApi: "openai-image",
+ service: "openai",
+ }),
+ openaiImagesProxy
+);
+export const openaiImage = openaiImagesRouter;
diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts
index 68f78ce..e779d40 100644
--- a/src/proxy/openai.ts
+++ b/src/proxy/openai.ts
@@ -2,61 +2,50 @@ import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
-import {
- ModelFamily,
- OpenAIModelFamily,
- getOpenAIModelFamily,
-} from "../shared/models";
+import { getOpenAIModelFamily, ModelFamily, OpenAIModelFamily } from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
- RequestPreprocessor,
addKey,
addKeyForEmbeddingsRequest,
applyQuotaLimits,
blockZoomerOrigins,
createEmbeddingsPreprocessorMiddleware,
+ createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
limitCompletions,
+ RequestPreprocessor,
stripHeaders,
- createOnProxyReqHandler,
} from "./middleware/request";
-import {
- createOnProxyResHandler,
- ProxyResHandlerWithBody,
-} from "./middleware/response";
+import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";
+
+// https://platform.openai.com/docs/models/overview
+const KNOWN_MODELS = [
+ "gpt-4-1106-preview",
+ "gpt-4",
+ "gpt-4-0613",
+ "gpt-4-0314", // EOL 2024-06-13
+ "gpt-4-32k",
+ "gpt-4-32k-0613",
+ "gpt-4-32k-0314", // EOL 2024-06-13
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-0301", // EOL 2024-06-13
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k",
+ "gpt-3.5-turbo-16k-0613",
+ "gpt-3.5-turbo-instruct",
+ "gpt-3.5-turbo-instruct-0914",
+ "text-embedding-ada-002",
+];
let modelsCache: any = null;
let modelsCacheTime = 0;
-function getModelsResponse() {
- if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
- return modelsCache;
- }
-
- // https://platform.openai.com/docs/models/overview
- const knownModels = [
- "gpt-4-1106-preview",
- "gpt-4",
- "gpt-4-0613",
- "gpt-4-0314", // EOL 2024-06-13
- "gpt-4-32k",
- "gpt-4-32k-0613",
- "gpt-4-32k-0314", // EOL 2024-06-13
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-0301", // EOL 2024-06-13
- "gpt-3.5-turbo-0613",
- "gpt-3.5-turbo-16k",
- "gpt-3.5-turbo-16k-0613",
- "gpt-3.5-turbo-instruct",
- "gpt-3.5-turbo-instruct-0914",
- "text-embedding-ada-002",
- ];
-
+export function generateModelList(models = KNOWN_MODELS) {
let available = new Set<OpenAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "openai") continue;
@@ -67,7 +56,7 @@ function getModelsResponse() {
const allowed = new Set(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));
- const models = knownModels
+ return models
.map((id) => ({
id,
object: "model",
@@ -87,15 +76,14 @@ function getModelsResponse() {
parent: null,
}))
.filter((model) => available.has(getOpenAIModelFamily(model.id)));
-
- modelsCache = { object: "list", data: models };
- modelsCacheTime = new Date().getTime();
-
- return modelsCache;
}
const handleModelRequest: RequestHandler = (_req, res) => {
- res.status(200).json(getModelsResponse());
+  // Send the cached list if it is less than a minute old; returning the
+  // cache object alone would leave the request without a response.
+  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
+    return res.status(200).json(modelsCache);
+  }
+ const result = generateModelList();
+ modelsCache = { object: "list", data: result };
+ modelsCacheTime = new Date().getTime();
+ res.status(200).json(modelsCache);
};
/** Handles some turbo-instruct special cases. */
@@ -137,9 +125,8 @@ const openaiResponseHandler: ProxyResHandlerWithBody = async (
body = transformTurboInstructResponse(body);
}
- // TODO: Remove once tokenization is stable
- if (req.debug) {
- body.proxy_tokenizer_debug_info = req.debug;
+ if (req.tokenizerInfo) {
+ body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
diff --git a/src/proxy/palm.ts b/src/proxy/palm.ts
index 0c9d67b..0137fd3 100644
--- a/src/proxy/palm.ts
+++ b/src/proxy/palm.ts
@@ -75,9 +75,8 @@ const palmResponseHandler: ProxyResHandlerWithBody = async (
body = transformPalmResponse(body, req);
}
- // TODO: Remove once tokenization is stable
- if (req.debug) {
- body.proxy_tokenizer_debug_info = req.debug;
+ if (req.tokenizerInfo) {
+ body.proxy_tokenizer = req.tokenizerInfo;
}
// TODO: PaLM has no streaming capability which will pose a problem here if
diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts
index f0aefab..0c92a72 100644
--- a/src/proxy/queue.ts
+++ b/src/proxy/queue.ts
@@ -12,11 +12,11 @@
*/
import type { Handler, Request } from "express";
-import { keyPool, SupportedModel } from "../shared/key-management";
+import { keyPool } from "../shared/key-management";
import {
getClaudeModelFamily,
getGooglePalmModelFamily,
- getOpenAIModelFamily,
+ getOpenAIModelFamily, MODEL_FAMILIES,
ModelFamily,
} from "../shared/models";
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
@@ -132,7 +132,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
// There is a single request queue, but it is partitioned by model family.
// Model families are typically separated on cost/rate limit boundaries so
// they should be treated as separate queues.
- const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
+ const model = req.body.model ?? "gpt-3.5-turbo";
// Weird special case for AWS because they serve multiple models from
// different vendors, even if currently only one is supported.
@@ -145,6 +145,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
return getClaudeModelFamily(model);
case "openai":
case "openai-text":
+ case "openai-image":
return getOpenAIModelFamily(model);
case "google-palm":
return getGooglePalmModelFamily(model);
@@ -207,40 +208,15 @@ export function dequeue(partition: ModelFamily): Request | undefined {
function processQueue() {
// This isn't completely correct, because a key can service multiple models.
// Currently if a key is locked out on one model it will also stop servicing
- // the others, because we only track one rate limit per key.
-
- // TODO: `getLockoutPeriod` uses model names instead of model families
- // TODO: genericize this it's really ugly
- const gpt4TurboLockout = keyPool.getLockoutPeriod("gpt-4-1106");
- const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
- const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
- const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
- const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
- const palmLockout = keyPool.getLockoutPeriod("text-bison-001");
- const awsClaudeLockout = keyPool.getLockoutPeriod("anthropic.claude-v2");
+ // the others, because we only track rate limits for the key as a whole.
const reqs: (Request | undefined)[] = [];
- if (gpt4TurboLockout === 0) {
- reqs.push(dequeue("gpt4-turbo"));
- }
- if (gpt432kLockout === 0) {
- reqs.push(dequeue("gpt4-32k"));
- }
- if (gpt4Lockout === 0) {
- reqs.push(dequeue("gpt4"));
- }
- if (turboLockout === 0) {
- reqs.push(dequeue("turbo"));
- }
- if (claudeLockout === 0) {
- reqs.push(dequeue("claude"));
- }
- if (palmLockout === 0) {
- reqs.push(dequeue("bison"));
- }
- if (awsClaudeLockout === 0) {
- reqs.push(dequeue("aws-claude"));
- }
+ MODEL_FAMILIES.forEach((modelFamily) => {
+ const lockout = keyPool.getLockoutPeriod(modelFamily);
+ if (lockout === 0) {
+ reqs.push(dequeue(modelFamily));
+ }
+ });
reqs.filter(Boolean).forEach((req) => {
if (req?.proceed) {
diff --git a/src/proxy/rate-limit.ts b/src/proxy/rate-limit.ts
index 5ee0251..c8acf8f 100644
--- a/src/proxy/rate-limit.ts
+++ b/src/proxy/rate-limit.ts
@@ -9,8 +9,6 @@ export const SHARED_IP_ADDRESSES = new Set([
"209.97.162.44",
]);
-const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit);
-const RATE_LIMIT = Math.max(1, config.modelRateLimit);
const ONE_MINUTE_MS = 60 * 1000;
type Timestamp = number;
@@ -22,12 +20,15 @@ const exemptedRequests: Timestamp[] = [];
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
attempt > now - ONE_MINUTE_MS;
-const getTryAgainInMs = (ip: string) => {
+const getTryAgainInMs = (ip: string, type: "text" | "image") => {
const now = Date.now();
const attempts = lastAttempts.get(ip) || [];
const validAttempts = attempts.filter(isRecentAttempt(now));
- if (validAttempts.length >= RATE_LIMIT) {
+ const limit =
+ type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
+
+ if (validAttempts.length >= limit) {
return validAttempts[0] - now + ONE_MINUTE_MS;
} else {
lastAttempts.set(ip, [...validAttempts, now]);
@@ -35,12 +36,16 @@ const getTryAgainInMs = (ip: string) => {
}
};
-const getStatus = (ip: string) => {
+const getStatus = (ip: string, type: "text" | "image") => {
const now = Date.now();
const attempts = lastAttempts.get(ip) || [];
const validAttempts = attempts.filter(isRecentAttempt(now));
+
+ const limit =
+ type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
+
return {
- remaining: Math.max(0, RATE_LIMIT - validAttempts.length),
+ remaining: Math.max(0, limit - validAttempts.length),
reset: validAttempts.length > 0 ? validAttempts[0] + ONE_MINUTE_MS : now,
};
};
@@ -69,12 +74,26 @@ setInterval(clearOldExemptions, 10 * 1000);
export const getUniqueIps = () => lastAttempts.size;
+/**
+ * Can be used to manually remove the most recent attempt from an IP address,
+ * e.g. in case a prompt triggered OpenAI's content filter and therefore did not
+ * result in a generation.
+ */
+export const refundLastAttempt = (req: Request) => {
+ const key = req.user?.token || req.risuToken || req.ip;
+ const attempts = lastAttempts.get(key) || [];
+ attempts.pop();
+};
+
export const ipLimiter = async (
req: Request,
res: Response,
next: NextFunction
) => {
- if (!RATE_LIMIT_ENABLED) return next();
+ const imageLimit = config.imageModelRateLimit;
+ const textLimit = config.textModelRateLimit;
+
+ if (!textLimit && !imageLimit) return next();
if (req.user?.type === "special") return next();
// Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
@@ -90,24 +109,25 @@ export const ipLimiter = async (
return next();
}
+  // Image requests are served from the /openai-image router; everything else
+  // is a text request.
+  const path = req.baseUrl + req.path;
+  const type = path.includes("openai-image") ? "image" : "text";
+ const limit = type === "image" ? imageLimit : textLimit;
+
// If user is authenticated, key rate limiting by their token. Otherwise, key
// rate limiting by their IP address. Mitigates key sharing.
const rateLimitKey = req.user?.token || req.risuToken || req.ip;
- const { remaining, reset } = getStatus(rateLimitKey);
- res.set("X-RateLimit-Limit", config.modelRateLimit.toString());
+ const { remaining, reset } = getStatus(rateLimitKey, type);
+ res.set("X-RateLimit-Limit", limit.toString());
res.set("X-RateLimit-Remaining", remaining.toString());
res.set("X-RateLimit-Reset", reset.toString());
- const tryAgainInMs = getTryAgainInMs(rateLimitKey);
+ const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
if (tryAgainInMs > 0) {
res.set("Retry-After", tryAgainInMs.toString());
res.status(429).json({
error: {
type: "proxy_rate_limited",
- message: `This proxy is rate limited to ${
- config.modelRateLimit
- } prompts per minute. Please try again in ${Math.ceil(
+ message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${Math.ceil(
tryAgainInMs / 1000
)} seconds.`,
},
diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts
index 424f104..dc76048 100644
--- a/src/proxy/routes.ts
+++ b/src/proxy/routes.ts
@@ -2,6 +2,7 @@ import express, { Request, Response, NextFunction } from "express";
import { gatekeeper } from "./gatekeeper";
import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
+import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { aws } from "./aws";
@@ -27,6 +28,7 @@ proxyRouter.use((req, _res, next) => {
next();
});
proxyRouter.use("/openai", addV1, openai);
+proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/aws/claude", addV1, aws);
diff --git a/src/server.ts b/src/server.ts
index 55aa95f..2b9e708 100644
--- a/src/server.ts
+++ b/src/server.ts
@@ -1,11 +1,14 @@
-import { assertConfigIsValid, config } from "./config";
+import { assertConfigIsValid, config, USER_ASSETS_DIR } from "./config";
import "source-map-support/register";
+import checkDiskSpace from "check-disk-space";
import express from "express";
import cors from "cors";
import path from "path";
import pinoHttp from "pino-http";
+import os from "os";
import childProcess from "child_process";
import { logger } from "./logger";
+import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir";
import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
@@ -58,6 +61,8 @@ app.set("views", [
path.join(__dirname, "shared/views"),
]);
+app.use("/user_content", express.static(USER_ASSETS_DIR));
+
app.get("/health", (_req, res) => res.sendStatus(200));
app.use(cors());
app.use(checkOrigin);
@@ -99,13 +104,17 @@ async function start() {
await initTokenizers();
+ if (config.allowedModelFamilies.includes("dall-e")) {
+ await setupAssetsDir();
+ }
+
if (config.gatekeeper === "user_token") {
await initUserStore();
}
if (config.promptLogging) {
logger.info("Starting prompt logging...");
- logQueue.start();
+ await logQueue.start();
}
logger.info("Starting request queue...");
@@ -116,8 +125,12 @@ async function start() {
registerUncaughtExceptionHandler();
});
+ const diskSpace = await checkDiskSpace(
+ __dirname.startsWith("/app") ? "/app" : os.homedir()
+ );
+
logger.info(
- { build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV },
+ { build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV, diskSpace },
"Startup complete."
);
}
diff --git a/src/shared/file-storage/image-history.ts b/src/shared/file-storage/image-history.ts
new file mode 100644
index 0000000..6e8c87a
--- /dev/null
+++ b/src/shared/file-storage/image-history.ts
@@ -0,0 +1,35 @@
+
+type ImageHistory = {
+ url: string;
+ prompt: string;
+}
+
+const IMAGE_HISTORY_SIZE = 30;
+const imageHistory = new Array<ImageHistory>(IMAGE_HISTORY_SIZE);
+let imageHistoryIndex = 0;
+
+export function getImageHistory() {
+  return imageHistory.filter((image) => image);
+}
+
+export function addToImageHistory(image: ImageHistory) {
+ imageHistory[imageHistoryIndex] = image;
+ imageHistoryIndex = (imageHistoryIndex + 1) % IMAGE_HISTORY_SIZE;
+}
+
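+// Returns up to the last `n` images, oldest first; empty ring buffer slots
+// are skipped. For example, after three generations getLastNImages(2) yields
+// the second and third images.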
+export function getLastNImages(n: number) {
+ const result: ImageHistory[] = [];
+ let currentIndex = (imageHistoryIndex - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE;
+
+ for (let i = 0; i < n; i++) {
+ // Check if the current index is valid (not undefined).
+ if (imageHistory[currentIndex]) {
+ result.unshift(imageHistory[currentIndex]);
+ }
+
+ // Move to the previous item, wrapping around if necessary.
+ currentIndex = (currentIndex - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE;
+ }
+
+ return result;
+}
diff --git a/src/shared/file-storage/mirror-generated-image.ts b/src/shared/file-storage/mirror-generated-image.ts
new file mode 100644
index 0000000..74fd93a
--- /dev/null
+++ b/src/shared/file-storage/mirror-generated-image.ts
@@ -0,0 +1,75 @@
+import axios from "axios";
+import { promises as fs } from "fs";
+import path from "path";
+import { v4 } from "uuid";
+import { USER_ASSETS_DIR } from "../../config";
+import { logger } from "../../logger";
+import { addToImageHistory } from "./image-history";
+import sharp from "sharp";
+
+const log = logger.child({ module: "file-storage" });
+
+export type OpenAIImageGenerationResult = {
+ created: number;
+ data: {
+ revised_prompt?: string;
+ url: string;
+ b64_json: string;
+ }[];
+};
+
+async function downloadImage(url: string) {
+ const { data } = await axios.get(url, { responseType: "arraybuffer" });
+ const buffer = Buffer.from(data, "binary");
+ const newFilename = `${v4()}.png`;
+
+ const filepath = path.join(USER_ASSETS_DIR, newFilename);
+ await fs.writeFile(filepath, buffer);
+ return filepath;
+}
+
+async function saveB64Image(b64: string) {
+ const buffer = Buffer.from(b64, "base64");
+ const newFilename = `${v4()}.png`;
+
+ const filepath = path.join(USER_ASSETS_DIR, newFilename);
+ await fs.writeFile(filepath, buffer);
+ return filepath;
+}
+
+async function createThumbnail(filepath: string) {
+ const thumbnailPath = filepath.replace(/(\.[\wd_-]+)$/i, "_t.jpg");
+
+ await sharp(filepath)
+ .resize(150, 150, {
+ fit: "inside",
+ withoutEnlargement: true,
+ })
+ .toFormat("jpeg")
+ .toFile(thumbnailPath);
+
+ return thumbnailPath;
+}
+
+/**
+ * Downloads generated images and mirrors them to the user_content directory.
+ * Mutates the result object.
+ */
+export async function mirrorGeneratedImage(
+ host: string,
+ prompt: string,
+ result: OpenAIImageGenerationResult
+): Promise<OpenAIImageGenerationResult> {
+ for (const item of result.data) {
+ let mirror: string;
+ if (item.b64_json) {
+ mirror = await saveB64Image(item.b64_json);
+ } else {
+ mirror = await downloadImage(item.url);
+ }
+ item.url = `${host}/user_content/${path.basename(mirror)}`;
+ await createThumbnail(mirror);
+ addToImageHistory({ url: item.url, prompt });
+ }
+ return result;
+}
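+
+// After mirroring, each item's upstream URL (which expires shortly after
+// generation) is replaced with a stable proxy URL such as
+// https://<host>/user_content/<uuid>.png, and a 150px JPEG thumbnail
+// (<uuid>_t.jpg) is written alongside the full-size image.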
diff --git a/src/shared/file-storage/setup-assets-dir.ts b/src/shared/file-storage/setup-assets-dir.ts
new file mode 100644
index 0000000..2958810
--- /dev/null
+++ b/src/shared/file-storage/setup-assets-dir.ts
@@ -0,0 +1,20 @@
+import { promises as fs } from "fs";
+import { logger } from "../../logger";
+import { USER_ASSETS_DIR } from "../../config";
+
+const log = logger.child({ module: "file-storage" });
+
+export async function setupAssetsDir() {
+ try {
+ log.info({ dir: USER_ASSETS_DIR }, "Setting up user assets directory");
+ await fs.mkdir(USER_ASSETS_DIR, { recursive: true });
+ const stats = await fs.stat(USER_ASSETS_DIR);
+ const mode = stats.mode | 0o666;
+ if (stats.mode !== mode) {
+ await fs.chmod(USER_ASSETS_DIR, mode);
+ }
+ } catch (e) {
+ log.error(e);
+ throw new Error("Could not create user assets directory for DALL-E image generation. You may need to update your Dockerfile to `chown` the working directory to user 1000. See the proxy docs for more information.");
+ }
+}
diff --git a/src/shared/key-management/anthropic/provider.ts b/src/shared/key-management/anthropic/provider.ts
index a50131f..7cebe03 100644
--- a/src/shared/key-management/anthropic/provider.ts
+++ b/src/shared/key-management/anthropic/provider.ts
@@ -6,14 +6,12 @@ import type { AnthropicModelFamily } from "../../models";
import { AnthropicKeyChecker } from "./checker";
// https://docs.anthropic.com/claude/reference/selecting-a-model
-export const ANTHROPIC_SUPPORTED_MODELS = [
- "claude-instant-v1",
- "claude-instant-v1-100k",
- "claude-v1",
- "claude-v1-100k",
- "claude-2",
-] as const;
-export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
+export type AnthropicModel =
+ | "claude-instant-v1"
+ | "claude-instant-v1-100k"
+ | "claude-v1"
+ | "claude-v1-100k"
+ | "claude-2";
export type AnthropicKeyUpdate = Omit<
Partial<AnthropicKey>,
@@ -180,7 +178,7 @@ export class AnthropicKeyProvider implements KeyProvider {
key.claudeTokens += tokens;
}
- public getLockoutPeriod(_model: AnthropicModel) {
+ public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts
index 31db723..96ceed5 100644
--- a/src/shared/key-management/aws/provider.ts
+++ b/src/shared/key-management/aws/provider.ts
@@ -6,12 +6,10 @@ import type { AwsBedrockModelFamily } from "../../models";
import { AwsKeyChecker } from "./checker";
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
-export const AWS_BEDROCK_SUPPORTED_MODELS = [
- "anthropic.claude-v1",
- "anthropic.claude-v2",
- "anthropic.claude-instant-v1",
-] as const;
-export type AwsBedrockModel = (typeof AWS_BEDROCK_SUPPORTED_MODELS)[number];
+export type AwsBedrockModel =
+ | "anthropic.claude-v1"
+ | "anthropic.claude-v2"
+ | "anthropic.claude-instant-v1";
type AwsBedrockKeyUsage = {
[K in AwsBedrockModelFamily as `${K}Tokens`]: number;
@@ -158,7 +156,7 @@ export class AwsBedrockKeyProvider implements KeyProvider {
key["aws-claudeTokens"] += tokens;
}
- public getLockoutPeriod(_model: AwsBedrockModel) {
+ public getLockoutPeriod() {
// TODO: same exact behavior for three providers, should be refactored
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts
index 3fe4597..c647a14 100644
--- a/src/shared/key-management/index.ts
+++ b/src/shared/key-management/index.ts
@@ -1,15 +1,17 @@
-import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider";
-import {
- ANTHROPIC_SUPPORTED_MODELS,
- AnthropicModel,
-} from "./anthropic/provider";
-import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider";
-import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider";
+import { OpenAIModel } from "./openai/provider";
+import { AnthropicModel } from "./anthropic/provider";
+import { GooglePalmModel } from "./palm/provider";
+import { AwsBedrockModel } from "./aws/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "../models";
/** The request and response format used by a model's API. */
-export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";
+export type APIFormat =
+ | "openai"
+ | "anthropic"
+ | "google-palm"
+ | "openai-text"
+ | "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
export type Model =
@@ -60,23 +62,12 @@ export interface KeyProvider {
update(hash: string, update: Partial<T>): void;
available(): number;
incrementUsage(hash: string, model: string, tokens: number): void;
- getLockoutPeriod(model: Model): number;
+ getLockoutPeriod(model: ModelFamily): number;
markRateLimited(hash: string): void;
recheck(): void;
}
export const keyPool = new KeyPool();
-export const SUPPORTED_MODELS = [
- ...OPENAI_SUPPORTED_MODELS,
- ...ANTHROPIC_SUPPORTED_MODELS,
-] as const;
-export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
-export {
- OPENAI_SUPPORTED_MODELS,
- ANTHROPIC_SUPPORTED_MODELS,
- GOOGLE_PALM_SUPPORTED_MODELS,
- AWS_BEDROCK_SUPPORTED_MODELS,
-};
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts
index 07a7b4c..95e7f26 100644
--- a/src/shared/key-management/key-pool.ts
+++ b/src/shared/key-management/key-pool.ts
@@ -9,6 +9,8 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
+import { ModelFamily } from "../models";
+import { assertNever } from "../utils";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -37,7 +39,7 @@ export class KeyPool {
}
public get(model: Model): Key {
- const service = this.getService(model);
+ const service = this.getServiceForModel(model);
return this.getKeyProvider(service).get(model);
}
@@ -67,7 +69,7 @@ export class KeyPool {
public available(model: Model | "all" = "all"): number {
return this.keyProviders.reduce((sum, provider) => {
const includeProvider =
- model === "all" || this.getService(model) === provider.service;
+ model === "all" || this.getServiceForModel(model) === provider.service;
return sum + (includeProvider ? provider.available() : 0);
}, 0);
}
@@ -77,9 +79,9 @@ export class KeyPool {
provider.incrementUsage(key.hash, model, tokens);
}
- public getLockoutPeriod(model: Model): number {
- const service = this.getService(model);
- return this.getKeyProvider(service).getLockoutPeriod(model);
+ public getLockoutPeriod(family: ModelFamily): number {
+ const service = this.getServiceForModelFamily(family);
+ return this.getKeyProvider(service).getLockoutPeriod(family);
}
public markRateLimited(key: Key): void {
@@ -104,8 +106,12 @@ export class KeyPool {
provider.recheck();
}
- private getService(model: Model): LLMService {
- if (model.startsWith("gpt") || model.startsWith("text-embedding-ada")) {
+ private getServiceForModel(model: Model): LLMService {
+ if (
+ model.startsWith("gpt") ||
+ model.startsWith("text-embedding-ada") ||
+ model.startsWith("dall-e")
+ ) {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
return "openai";
} else if (model.startsWith("claude-")) {
@@ -122,6 +128,25 @@ export class KeyPool {
throw new Error(`Unknown service for model '${model}'`);
}
+ private getServiceForModelFamily(modelFamily: ModelFamily): LLMService {
+ switch (modelFamily) {
+ case "gpt4":
+ case "gpt4-32k":
+ case "gpt4-turbo":
+ case "turbo":
+ case "dall-e":
+ return "openai";
+ case "claude":
+ return "anthropic";
+ case "bison":
+ return "google-palm";
+ case "aws-claude":
+ return "aws";
+ default:
+ assertNever(modelFamily);
+ }
+ }
+
private getKeyProvider(service: LLMService): KeyProvider {
return this.keyProviders.find((provider) => provider.service === service)!;
}
diff --git a/src/shared/key-management/openai/checker.ts b/src/shared/key-management/openai/checker.ts
index e7f28b8..88d7b59 100644
--- a/src/shared/key-management/openai/checker.ts
+++ b/src/shared/key-management/openai/checker.ts
@@ -95,10 +95,15 @@ export class OpenAIKeyChecker extends KeyCheckerBase {
const { data } = await axios.get(GET_MODELS_URL, opts);
const models = data.data;
- // const families: OpenAIModelFamily[] = [];
const families = new Set<OpenAIModelFamily>();
models.forEach(({ id }) => families.add(getOpenAIModelFamily(id, "turbo")));
+ // For now we remove dall-e from the list of provisioned models if only
+ // dall-e-2 is available.
+ if (families.has("dall-e") && !models.find(({ id }) => id === "dall-e-3")) {
+ families.delete("dall-e");
+ }
+
// We want to update the key's model families here, but we don't want to
// update its `lastChecked` timestamp because we need to let the liveness
// check run before we can consider the key checked.
diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts
index 4e527dd..d98bd39 100644
--- a/src/shared/key-management/openai/provider.ts
+++ b/src/shared/key-management/openai/provider.ts
@@ -15,12 +15,9 @@ export type OpenAIModel =
| "gpt-4"
| "gpt-4-32k"
| "gpt-4-1106"
- | "text-embedding-ada-002";
-export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-instruct",
- "gpt-4",
-] as const;
+ | "text-embedding-ada-002"
+ | "dall-e-2"
+  | "dall-e-3";
// Flattening model families instead of using a nested object for easier
// cloning.
@@ -127,6 +124,7 @@ export class OpenAIKeyProvider implements KeyProvider {
gpt4Tokens: 0,
"gpt4-32kTokens": 0,
"gpt4-turboTokens": 0,
+ "dall-eTokens": 0,
gpt4Rpm: 0,
};
this.keys.push(newKey);
@@ -284,10 +282,9 @@ export class OpenAIKeyProvider implements KeyProvider {
* Given a model, returns the period until a key will be available to service
* the request, or returns 0 if a key is ready immediately.
*/
- public getLockoutPeriod(model: Model = "gpt-4"): number {
- const neededFamily = getOpenAIModelFamily(model);
+ public getLockoutPeriod(family: OpenAIModelFamily): number {
const activeKeys = this.keys.filter(
- (key) => !key.isDisabled && key.modelFamilies.includes(neededFamily)
+ (key) => !key.isDisabled && key.modelFamilies.includes(family)
);
if (activeKeys.length === 0) {
@@ -335,6 +332,10 @@ export class OpenAIKeyProvider implements KeyProvider {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
key.rateLimitedAt = Date.now();
+ // DALL-E requests do not send headers telling us when the rate limit will
+ // be reset so we need to set a fallback value here. Other models will have
+ // this overwritten by the `updateRateLimits` method.
+ key.rateLimitRequestsReset = 5000;
}
public incrementUsage(keyHash: string, model: string, tokens: number) {
diff --git a/src/shared/key-management/palm/provider.ts b/src/shared/key-management/palm/provider.ts
index 5b3e3d7..dccfa08 100644
--- a/src/shared/key-management/palm/provider.ts
+++ b/src/shared/key-management/palm/provider.ts
@@ -5,11 +5,7 @@ import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
// https://developers.generativeai.google.com/models/language
-export const GOOGLE_PALM_SUPPORTED_MODELS = [
- "text-bison-001",
- // "chat-bison-001", no adjustable safety settings, so it's useless
-] as const;
-export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number];
+export type GooglePalmModel = "text-bison-001";
export type GooglePalmKeyUpdate = Omit<
Partial<GooglePalmKey>,
@@ -149,7 +145,7 @@ export class GooglePalmKeyProvider implements KeyProvider {
key.bisonTokens += tokens;
}
- public getLockoutPeriod(_model: GooglePalmModel) {
+ public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
diff --git a/src/shared/models.ts b/src/shared/models.ts
index f89a266..861c151 100644
--- a/src/shared/models.ts
+++ b/src/shared/models.ts
@@ -1,6 +1,8 @@
-import { logger } from "../logger";
+// Don't import any project modules here; this file is imported by config.ts.
-export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo";
+import pino from "pino";
+
+export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type AwsBedrockModelFamily = "aws-claude";
@@ -17,6 +19,7 @@ export const MODEL_FAMILIES = ((
"gpt4",
"gpt4-32k",
"gpt4-turbo",
+ "dall-e",
"claude",
"bison",
"aws-claude",
@@ -30,8 +33,11 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^gpt-4$": "gpt4",
"^gpt-3.5-turbo": "turbo",
"^text-embedding-ada-002$": "turbo",
+ "^dall-e-\\d{1}$": "dall-e",
};
+const modelLogger = pino({ level: "debug" }).child({ module: "startup" });
+
export function getOpenAIModelFamily(
model: string,
defaultFamily: OpenAIModelFamily = "gpt4"
@@ -42,14 +48,14 @@ export function getOpenAIModelFamily(
return defaultFamily;
}
-export function getClaudeModelFamily(_model: string): ModelFamily {
+export function getClaudeModelFamily(model: string): ModelFamily {
+ if (model.startsWith("anthropic.")) return getAwsBedrockModelFamily(model);
return "claude";
}
export function getGooglePalmModelFamily(model: string): ModelFamily {
if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
- const stack = new Error().stack;
- logger.warn({ model, stack }, "Unmapped PaLM model family");
+ modelLogger.warn({ model }, "Could not determine Google PaLM model family");
return "bison";
}
diff --git a/src/shared/prompt-logging/backends/sheets.ts b/src/shared/prompt-logging/backends/sheets.ts
index e8c60e1..6b2450f 100644
--- a/src/shared/prompt-logging/backends/sheets.ts
+++ b/src/shared/prompt-logging/backends/sheets.ts
@@ -396,7 +396,7 @@ export const init = async (onStop: () => void) => {
await loadIndexSheet(false);
await writeIndexSheet();
} catch (e) {
- log.info("Creating new index sheet.");
+ log.warn(e, "Could not load index sheet. Creating a new one.");
await createIndexSheet();
}
};
diff --git a/src/shared/stats.ts b/src/shared/stats.ts
index aeb816e..7729a3c 100644
--- a/src/shared/stats.ts
+++ b/src/shared/stats.ts
@@ -17,6 +17,9 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
case "turbo":
cost = 0.000001;
break;
+ case "dall-e":
+ cost = 0.00001;
+ break;
case "aws-claude":
case "claude":
cost = 0.00001102;
diff --git a/src/shared/streaming.ts b/src/shared/streaming.ts
index 6611521..f3943b2 100644
--- a/src/shared/streaming.ts
+++ b/src/shared/streaming.ts
@@ -79,7 +79,8 @@ export function buildFakeSse(
};
break;
case "google-palm":
- throw new Error("PaLM not supported as an inbound API format");
+ case "openai-image":
+ throw new Error(`SSE not supported for ${req.inboundApi} requests`);
default:
assertNever(req.inboundApi);
}
@@ -92,4 +93,4 @@ export function buildFakeSse(
}
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
-}
\ No newline at end of file
+}
diff --git a/src/shared/tokenization/openai.ts b/src/shared/tokenization/openai.ts
index b71585d..a984299 100644
--- a/src/shared/tokenization/openai.ts
+++ b/src/shared/tokenization/openai.ts
@@ -78,3 +78,63 @@ export type OpenAIPromptMessage = {
content: string;
role: string;
};
+
+// Model Resolution Price
+// DALL·E 3 1024×1024 $0.040 / image
+// 1024×1792, 1792×1024 $0.080 / image
+// DALL·E 3 HD 1024×1024 $0.080 / image
+// 1024×1792, 1792×1024 $0.120 / image
+// DALL·E 2 1024×1024 $0.020 / image
+// 512×512 $0.018 / image
+// 256×256 $0.016 / image
+
+export const DALLE_TOKENS_PER_DOLLAR = 100000;
+
+/**
+ * OpenAI image generation with DALL-E doesn't use tokens but everything else
+ * in the application does. There is a fixed cost for each image generation
+ * request depending on the model and selected quality/resolution parameters,
+ * which we convert to tokens at a rate of 100000 tokens per dollar.
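+ * For example, a standard-quality 1024x1024 DALL-E 3 image costs US$0.04 and
+ * is billed as 4000 tokens; two HD 1792x1024 images cost US$0.24, or 24000
+ * tokens.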
+ */
+export function getOpenAIImageCost(params: {
+ model: "dall-e-2" | "dall-e-3";
+ quality: "standard" | "hd";
+ resolution: "512x512" | "256x256" | "1024x1024" | "1024x1792" | "1792x1024";
+ n: number | null;
+}) {
+ const { model, quality, resolution, n } = params;
+ const usd = (() => {
+ switch (model) {
+ case "dall-e-2":
+ switch (resolution) {
+ case "512x512":
+ return 0.018;
+ case "256x256":
+ return 0.016;
+ case "1024x1024":
+ return 0.02;
+ default:
+ throw new Error("Invalid resolution");
+ }
+ case "dall-e-3":
+ switch (resolution) {
+ case "1024x1024":
+ return quality === "standard" ? 0.04 : 0.08;
+ case "1024x1792":
+ case "1792x1024":
+ return quality === "standard" ? 0.08 : 0.12;
+ default:
+ throw new Error("Invalid resolution");
+ }
+ default:
+ throw new Error("Invalid image generation model");
+ }
+ })();
+
+ const tokens = (n ?? 1) * (usd * DALLE_TOKENS_PER_DOLLAR);
+
+ return {
+ tokenizer: `openai-image cost`,
+ token_count: Math.ceil(tokens),
+ };
+}
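+
+// e.g. getOpenAIImageCost({ model: "dall-e-2", quality: "standard",
+// resolution: "512x512", n: 4 }) yields a token_count of 7200
+// (4 images at US$0.018 each, times 100000 tokens per dollar).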
diff --git a/src/shared/tokenization/tokenizer.ts b/src/shared/tokenization/tokenizer.ts
index 6b5491c..9d30b4e 100644
--- a/src/shared/tokenization/tokenizer.ts
+++ b/src/shared/tokenization/tokenizer.ts
@@ -8,6 +8,7 @@ import {
init as initOpenAi,
getTokenCount as getOpenAITokenCount,
OpenAIPromptMessage,
+ getOpenAIImageCost,
} from "./openai";
import { APIFormat } from "../key-management";
@@ -26,6 +27,7 @@ type TokenCountRequest = { req: Request } & (
service: "openai-text" | "anthropic" | "google-palm";
}
| { prompt?: never; completion: string; service: APIFormat }
+ | { prompt?: never; completion?: never; service: "openai-image" }
);
type TokenCountResult = {
@@ -53,6 +55,16 @@ export async function countTokens({
...getOpenAITokenCount(prompt ?? completion, req.body.model),
tokenization_duration_ms: getElapsedMs(time),
};
+ case "openai-image":
+ return {
+ ...getOpenAIImageCost({
+ model: req.body.model,
+ quality: req.body.quality,
+ resolution: req.body.size,
+ n: parseInt(req.body.n, 10) || null,
+ }),
+ tokenization_duration_ms: getElapsedMs(time),
+ };
case "google-palm":
// TODO: Can't find a tokenization library for PaLM. There is an API
// endpoint for it but it adds significant latency to the request.
diff --git a/src/shared/users/schema.ts b/src/shared/users/schema.ts
index eb9df44..492b4b8 100644
--- a/src/shared/users/schema.ts
+++ b/src/shared/users/schema.ts
@@ -7,6 +7,7 @@ export const tokenCountsSchema: ZodType = z.object({
gpt4: z.number().optional().default(0),
"gpt4-32k": z.number().optional().default(0),
"gpt4-turbo": z.number().optional().default(0),
+ "dall-e": z.number().optional().default(0),
claude: z.number().optional().default(0),
bison: z.number().optional().default(0),
"aws-claude": z.number().optional().default(0),
diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts
index 556c513..8fbdb58 100644
--- a/src/shared/users/user-store.ts
+++ b/src/shared/users/user-store.ts
@@ -11,9 +11,17 @@ import admin from "firebase-admin";
import schedule from "node-schedule";
import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
-import { MODEL_FAMILIES, ModelFamily } from "../models";
+import {
+ getClaudeModelFamily,
+ getGooglePalmModelFamily,
+ getOpenAIModelFamily,
+ MODEL_FAMILIES,
+ ModelFamily,
+} from "../models";
import { logger } from "../../logger";
import { User, UserTokenCounts, UserUpdate } from "./schema";
+import { APIFormat } from "../key-management";
+import { assertNever } from "../utils";
const log = logger.child({ module: "users" });
@@ -22,6 +30,7 @@ const INITIAL_TOKENS: Required = {
gpt4: 0,
"gpt4-32k": 0,
"gpt4-turbo": 0,
+ "dall-e": 0,
claude: 0,
bison: 0,
"aws-claude": 0,
@@ -166,11 +175,12 @@ export function incrementPromptCount(token: string) {
export function incrementTokenCount(
token: string,
model: string,
+ api: APIFormat,
consumption: number
) {
const user = users.get(token);
if (!user) return;
- const modelFamily = getModelFamilyForQuotaUsage(model);
+ const modelFamily = getModelFamilyForQuotaUsage(model, api);
const existing = user.tokenCounts[modelFamily] ?? 0;
user.tokenCounts[modelFamily] = existing + consumption;
usersToFlush.add(token);
@@ -181,9 +191,10 @@ export function incrementTokenCount(
* to the user's list of IPs. Returns the user if they exist and are not
* disabled, otherwise returns undefined.
*/
-export function authenticate(token: string, ip: string):
- { user?: User; result: "success" | "disabled" | "not_found" | "limited" }
- {
+export function authenticate(
+ token: string,
+ ip: string
+): { user?: User; result: "success" | "disabled" | "not_found" | "limited" } {
const user = users.get(token);
if (!user) return { result: "not_found" };
if (user.disabledAt) return { result: "disabled" };
@@ -210,16 +221,22 @@ export function authenticate(token: string, ip: string):
return { user, result: "success" };
}
-export function hasAvailableQuota(
- token: string,
- model: string,
- requested: number
-) {
- const user = users.get(token);
+export function hasAvailableQuota({
+ userToken,
+ model,
+ api,
+ requested,
+}: {
+ userToken: string;
+ model: string;
+ api: APIFormat;
+ requested: number;
+}) {
+ const user = users.get(userToken);
if (!user) return false;
if (user.type === "special") return true;
- const modelFamily = getModelFamilyForQuotaUsage(model);
+ const modelFamily = getModelFamilyForQuotaUsage(model, api);
const { tokenCounts, tokenLimits } = user;
const tokenLimit = tokenLimits[modelFamily];
@@ -361,30 +378,22 @@ async function flushUsers() {
);
}
-// TODO: use key-management/models.ts for family mapping
-function getModelFamilyForQuotaUsage(model: string): ModelFamily {
- if (model.startsWith("gpt-4-1106")) {
- return "gpt4-turbo";
+function getModelFamilyForQuotaUsage(
+ model: string,
+ api: APIFormat
+): ModelFamily {
+ switch (api) {
+ case "openai":
+ case "openai-text":
+ case "openai-image":
+ return getOpenAIModelFamily(model);
+ case "anthropic":
+ return getClaudeModelFamily(model);
+ case "google-palm":
+ return getGooglePalmModelFamily(model);
+ default:
+ assertNever(api);
}
- if (model.includes("32k")) {
- return "gpt4-32k";
- }
- if (model.startsWith("gpt-4")) {
- return "gpt4";
- }
- if (model.startsWith("gpt-3.5")) {
- return "turbo";
- }
- if (model.includes("bison")) {
- return "bison";
- }
- if (model.startsWith("claude")) {
- return "claude";
- }
- if (model.startsWith("anthropic.claude")) {
- return "aws-claude";
- }
- throw new Error(`Unknown quota model family for model ${model}`);
}
function getRefreshCrontab() {
diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts
index 503bf76..bbafc84 100644
--- a/src/types/custom.d.ts
+++ b/src/types/custom.d.ts
@@ -24,8 +24,7 @@ declare global {
heartbeatInterval?: NodeJS.Timeout;
promptTokens?: number;
outputTokens?: number;
- // TODO: remove later
- debug: Record<string, any>;
+ tokenizerInfo: Record<string, any>;
signedRequest: HttpRequest;
}
}