From fa4bf468d2c99bb26b7a9aebb3e8baae2dfcc397 Mon Sep 17 00:00:00 2001 From: khanon Date: Sun, 1 Oct 2023 01:40:18 +0000 Subject: [PATCH] Implement AWS Bedrock support (khanon/oai-reverse-proxy!45) --- .env.example | 2 + .prettierrc | 3 +- package-lock.json | 273 +++++++++++++++++- package.json | 18 +- src/config.ts | 20 +- src/info-page.ts | 43 ++- src/proxy/anthropic.ts | 33 ++- src/proxy/aws.ts | 211 ++++++++++++++ src/proxy/middleware/common.ts | 67 +++-- src/proxy/middleware/request/add-key.ts | 5 +- .../middleware/request/count-prompt-tokens.ts | 48 +++ .../request/finalize-aws-request.ts | 26 ++ src/proxy/middleware/request/index.ts | 13 +- src/proxy/middleware/request/preprocess.ts | 33 ++- .../middleware/request/set-api-format.ts | 4 +- .../middleware/request/sign-aws-request.ts | 93 ++++++ .../request/transform-outbound-payload.ts | 21 +- ...ntext-size.ts => validate-context-size.ts} | 75 +---- .../response/handle-streamed-response.ts | 146 +++++----- src/proxy/middleware/response/index.ts | 190 +++++++----- src/proxy/middleware/response/log-prompt.ts | 16 +- .../middleware/response/sse-stream-adapter.ts | 85 ++++++ src/proxy/openai.ts | 39 ++- src/proxy/palm.ts | 20 +- src/proxy/queue.ts | 21 +- src/proxy/routes.ts | 11 + .../key-management/anthropic/provider.ts | 11 +- src/shared/key-management/aws/provider.ts | 180 ++++++++++++ src/shared/key-management/index.ts | 17 +- src/shared/key-management/key-pool.ts | 37 +-- src/shared/key-management/openai/checker.ts | 13 +- src/shared/key-management/openai/provider.ts | 13 +- src/shared/key-management/palm/provider.ts | 11 +- src/shared/models.ts | 17 +- src/shared/tokenization/tokenizer.ts | 9 +- src/shared/users/schema.ts | 1 + src/shared/users/user-store.ts | 18 +- src/types/custom.d.ts | 5 +- 38 files changed, 1438 insertions(+), 410 deletions(-) create mode 100644 src/proxy/aws.ts create mode 100644 src/proxy/middleware/request/count-prompt-tokens.ts create mode 100644 src/proxy/middleware/request/finalize-aws-request.ts create mode 100644 src/proxy/middleware/request/sign-aws-request.ts rename src/proxy/middleware/request/{check-context-size.ts => validate-context-size.ts} (57%) create mode 100644 src/proxy/middleware/response/sse-stream-adapter.ts create mode 100644 src/shared/key-management/aws/provider.ts diff --git a/.env.example b/.env.example index 31001e1..3f65447 100644 --- a/.env.example +++ b/.env.example @@ -84,8 +84,10 @@ # For Render, create a "secret file" called .env using the Environment tab. # You can add multiple API keys by separating them with a comma. +# For AWS credentials, separate the access key ID, secret key, and region with a colon. OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx +AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2 # With proxy_key gatekeeper, the password users must provide to access the API. 
# PROXY_KEY=your-secret-key diff --git a/.prettierrc b/.prettierrc index f5c153e..73be315 100644 --- a/.prettierrc +++ b/.prettierrc @@ -9,5 +9,6 @@ "bracketSameLine": true } } - ] + ], + "trailingComma": "es5" } diff --git a/package-lock.json b/package-lock.json index fd9a1ba..45d88cf 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,6 +10,10 @@ "license": "MIT", "dependencies": { "@anthropic-ai/tokenizer": "^0.0.4", + "@aws-crypto/sha256-js": "^5.1.0", + "@smithy/protocol-http": "^3.0.6", + "@smithy/signature-v4": "^2.0.10", + "@smithy/types": "^2.3.4", "axios": "^1.3.5", "cookie-parser": "^1.4.6", "copyfiles": "^2.4.1", @@ -22,6 +26,7 @@ "firebase-admin": "^11.10.1", "googleapis": "^122.0.0", "http-proxy-middleware": "^3.0.0-beta.1", + "lifion-aws-event-stream": "^1.0.7", "memorystore": "^1.6.7", "multer": "^1.4.5-lts.1", "node-schedule": "^2.1.1", @@ -49,9 +54,10 @@ "esbuild-register": "^3.4.2", "husky": "^8.0.3", "nodemon": "^3.0.1", + "prettier": "^3.0.3", "source-map-support": "^0.5.21", "ts-node": "^10.9.1", - "typescript": "^5.0.4" + "typescript": "^5.1.3" }, "engines": { "node": ">=18.0.0" @@ -66,6 +72,79 @@ "tiktoken": "^1.0.10" } }, + "node_modules/@aws-crypto/crc32": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/crc32/-/crc32-3.0.0.tgz", + "integrity": "sha512-IzSgsrxUcsrejQbPVilIKy16kAT52EwB6zSaI+M3xxIhKh5+aldEyvI+z6erM7TCLB2BJsFrtHjp6/4/sr+3dA==", + "dependencies": { + "@aws-crypto/util": "^3.0.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^1.11.1" + } + }, + "node_modules/@aws-crypto/crc32/node_modules/tslib": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", + "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==" + }, + "node_modules/@aws-crypto/sha256-js": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/sha256-js/-/sha256-js-5.1.0.tgz", + "integrity": "sha512-VeDxEzCJZUNikoRD7DMFZj/aITgt2VL8tf37nEJqFjUf6DU202Vf3u07W5Ip8lVDs2Pdqg2AbdoWPyjtmHU8nw==", + "dependencies": { + "@aws-crypto/util": "^5.1.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/@aws-crypto/sha256-js/node_modules/@aws-crypto/util": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/util/-/util-5.1.0.tgz", + "integrity": "sha512-TRSydv/0a4RTZYnCmbpx1F6fOfVlTostBFvLr9GCGPww2WhuIgMg5ZmWN35Wi/Cy6HuvZf82wfUN1F9gQkJ1mQ==", + "dependencies": { + "@aws-sdk/types": "^3.222.0", + "@smithy/util-utf8": "^2.0.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/util": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/util/-/util-3.0.0.tgz", + "integrity": "sha512-2OJlpeJpCR48CC8r+uKVChzs9Iungj9wkZrl8Z041DWEWvyIHILYKCPNzJghKsivj+S3mLo6BVc7mBNzdxA46w==", + "dependencies": { + "@aws-sdk/types": "^3.222.0", + "@aws-sdk/util-utf8-browser": "^3.0.0", + "tslib": "^1.11.1" + } + }, + "node_modules/@aws-crypto/util/node_modules/tslib": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", + "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==" + }, + "node_modules/@aws-sdk/types": { + "version": "3.418.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/types/-/types-3.418.0.tgz", + "integrity": "sha512-y4PQSH+ulfFLY0+FYkaK4qbIaQI9IJNMO2xsxukW6/aNoApNymN1D2FSi2la8Qbp/iPjNDKsG8suNPm9NtsWXQ==", + "dependencies": { + 
"@smithy/types": "^2.3.3", + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-sdk/util-utf8-browser": { + "version": "3.259.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz", + "integrity": "sha512-UvFa/vR+e19XookZF8RzFZBrw2EUkQWxiBW0yYQAhvk3C+QVGl0H3ouca8LDBlBfQKXwmW3huo/59H8rwb1wJw==", + "dependencies": { + "tslib": "^2.3.1" + } + }, "node_modules/@babel/parser": { "version": "7.22.7", "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.22.7.tgz", @@ -751,6 +830,127 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "optional": true }, + "node_modules/@smithy/eventstream-codec": { + "version": "2.0.10", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-2.0.10.tgz", + "integrity": "sha512-3SSDgX2nIsFwif6m+I4+ar4KDcZX463Noes8ekBgQHitULiWvaDZX8XqPaRQSQ4bl1vbeVXHklJfv66MnVO+lw==", + "dependencies": { + "@aws-crypto/crc32": "3.0.0", + "@smithy/types": "^2.3.4", + "@smithy/util-hex-encoding": "^2.0.0", + "tslib": "^2.5.0" + } + }, + "node_modules/@smithy/is-array-buffer": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.0.0.tgz", + "integrity": "sha512-z3PjFjMyZNI98JFRJi/U0nGoLWMSJlDjAW4QUX2WNZLas5C0CmVV6LJ01JI0k90l7FvpmixjWxPFmENSClQ7ug==", + "dependencies": { + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/protocol-http": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@smithy/protocol-http/-/protocol-http-3.0.6.tgz", + "integrity": "sha512-F0jAZzwznMmHaggiZgc7YoS08eGpmLvhVktY/Taz6+OAOHfyIqWSDNgFqYR+WHW9z5fp2XvY4mEUrQgYMQ71jw==", + "dependencies": { + "@smithy/types": "^2.3.4", + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/signature-v4": { + "version": "2.0.10", + "resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.0.10.tgz", + "integrity": "sha512-S6gcP4IXfO/VMswovrhxPpqvQvMal7ZRjM4NvblHSPpE5aNBYx67UkHFF3kg0hR3tJKqNpBGbxwq0gzpdHKLRA==", + "dependencies": { + "@smithy/eventstream-codec": "^2.0.10", + "@smithy/is-array-buffer": "^2.0.0", + "@smithy/types": "^2.3.4", + "@smithy/util-hex-encoding": "^2.0.0", + "@smithy/util-middleware": "^2.0.3", + "@smithy/util-uri-escape": "^2.0.0", + "@smithy/util-utf8": "^2.0.0", + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/types": { + "version": "2.3.4", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.3.4.tgz", + "integrity": "sha512-D7xlM9FOMFyFw7YnMXn9dK2KuN6+JhnrZwVt1fWaIu8hCk5CigysweeIT/H/nCo4YV+s8/oqUdLfexbkPZtvqw==", + "dependencies": { + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/util-buffer-from": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.0.0.tgz", + "integrity": "sha512-/YNnLoHsR+4W4Vf2wL5lGv0ksg8Bmk3GEGxn2vEQt52AQaPSCuaO5PM5VM7lP1K9qHRKHwrPGktqVoAHKWHxzw==", + "dependencies": { + "@smithy/is-array-buffer": "^2.0.0", + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/util-hex-encoding": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@smithy/util-hex-encoding/-/util-hex-encoding-2.0.0.tgz", + "integrity": "sha512-c5xY+NUnFqG6d7HFh1IFfrm3mGl29lC+vF+geHv4ToiuJCBmIfzx6IeHLg+OgRdPFKDXIw6pvi+p3CsscaMcMA==", 
+ "dependencies": { + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/util-middleware": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/@smithy/util-middleware/-/util-middleware-2.0.3.tgz", + "integrity": "sha512-+FOCFYOxd2HO7v/0hkFSETKf7FYQWa08wh/x/4KUeoVBnLR4juw8Qi+TTqZI6E2h5LkzD9uOaxC9lAjrpVzaaA==", + "dependencies": { + "@smithy/types": "^2.3.4", + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/util-uri-escape": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.0.0.tgz", + "integrity": "sha512-ebkxsqinSdEooQduuk9CbKcI+wheijxEb3utGXkCoYQkJnwTnLbH1JXGimJtUkQwNQbsbuYwG2+aFVyZf5TLaw==", + "dependencies": { + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/util-utf8": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@smithy/util-utf8/-/util-utf8-2.0.0.tgz", + "integrity": "sha512-rctU1VkziY84n5OXe3bPNpKR001ZCME2JCaBBFgtiM2hfKbHFudc/BkMuPab8hRbLd0j3vbnBTTZ1igBf0wgiQ==", + "dependencies": { + "@smithy/util-buffer-from": "^2.0.0", + "tslib": "^2.5.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@tootallnate/once": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/@tootallnate/once/-/once-2.0.0.tgz", @@ -1706,6 +1906,37 @@ "node": ">= 0.10" } }, + "node_modules/crc": { + "version": "3.8.0", + "resolved": "https://registry.npmjs.org/crc/-/crc-3.8.0.tgz", + "integrity": "sha512-iX3mfgcTMIq3ZKLIsVFAbv7+Mc10kxabAGQb8HvjA1o3T1PIYprbakQ65d3I+2HGHt6nSKkM9PYjgoJO2KcFBQ==", + "dependencies": { + "buffer": "^5.1.0" + } + }, + "node_modules/crc/node_modules/buffer": { + "version": "5.7.1", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz", + "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.1.13" + } + }, "node_modules/create-require": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz", @@ -3233,6 +3464,17 @@ "graceful-fs": "^4.1.9" } }, + "node_modules/lifion-aws-event-stream": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/lifion-aws-event-stream/-/lifion-aws-event-stream-1.0.7.tgz", + "integrity": "sha512-qI0O85OrV5A9rBE++oIaWFjNngk/BqjnJ+3/wdtIPLfFWhPtf+xNuWd/T8lr/wnEpKm/8HbdgYf8pKozk0dPAw==", + "dependencies": { + "crc": "^3.8.0" + }, + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/limiter": { "version": "1.1.5", "resolved": "https://registry.npmjs.org/limiter/-/limiter-1.1.5.tgz", @@ -3888,6 +4130,21 @@ "node": "^10 || ^12 || >=14" } }, + "node_modules/prettier": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz", + "integrity": "sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==", + "dev": true, + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, "node_modules/process": { "version": "0.11.10", "resolved": "https://registry.npmjs.org/process/-/process-0.11.10.tgz", @@ -4683,9 +4940,9 @@ 
} }, "node_modules/tslib": { - "version": "2.5.0", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.5.0.tgz", - "integrity": "sha512-336iVw3rtn2BUK7ORdIAHTyxHGRIHVReokCR3XjbckJMK7ms8FysBfhLR8IXnAgy7T0PTPNBWKiH514FOW/WSg==" + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" }, "node_modules/type-is": { "version": "1.6.18", @@ -4705,16 +4962,16 @@ "integrity": "sha512-/aCDEGatGvZ2BIk+HmLf4ifCJFwvKFNb9/JeZPMulfgFracn9QFcAf5GO8B/mweUjSoblS5In0cWhqpfs/5PQA==" }, "node_modules/typescript": { - "version": "5.0.4", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.0.4.tgz", - "integrity": "sha512-cW9T5W9xY37cc+jfEnaUvX91foxtHkza3Nw3wkoF4sSlKn0MONdkdEndig/qPBWXNkmplh3NzayQzCiHM4/hqw==", + "version": "5.1.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.1.3.tgz", + "integrity": "sha512-XH627E9vkeqhlZFQuL+UsyAXEnibT0kWR2FWONlr4sTjvxyJYnyefgrkyECLzM5NenmKzRAy2rR/OlYLA1HkZw==", "dev": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" }, "engines": { - "node": ">=12.20" + "node": ">=14.17" } }, "node_modules/uc.micro": { diff --git a/package.json b/package.json index 6e977fb..7d418f6 100644 --- a/package.json +++ b/package.json @@ -4,12 +4,12 @@ "description": "Reverse proxy for the OpenAI API", "scripts": { "build": "tsc && copyfiles -u 1 src/**/*.ejs build", - "start:dev": "nodemon --watch src --exec ts-node --transpile-only src/server.ts", - "start:watch": "nodemon --require source-map-support/register build/server.js", - "start:replit": "tsc && node build/server.js", + "prepare": "husky install", "start": "node build/server.js", - "type-check": "tsc --noEmit", - "prepare": "husky install" + "start:dev": "nodemon --watch src --exec ts-node --transpile-only src/server.ts", + "start:replit": "tsc && node build/server.js", + "start:watch": "nodemon --require source-map-support/register build/server.js", + "type-check": "tsc --noEmit" }, "engines": { "node": ">=18.0.0" @@ -18,6 +18,10 @@ "license": "MIT", "dependencies": { "@anthropic-ai/tokenizer": "^0.0.4", + "@aws-crypto/sha256-js": "^5.1.0", + "@smithy/protocol-http": "^3.0.6", + "@smithy/signature-v4": "^2.0.10", + "@smithy/types": "^2.3.4", "axios": "^1.3.5", "cookie-parser": "^1.4.6", "copyfiles": "^2.4.1", @@ -30,6 +34,7 @@ "firebase-admin": "^11.10.1", "googleapis": "^122.0.0", "http-proxy-middleware": "^3.0.0-beta.1", + "lifion-aws-event-stream": "^1.0.7", "memorystore": "^1.6.7", "multer": "^1.4.5-lts.1", "node-schedule": "^2.1.1", @@ -57,9 +62,10 @@ "esbuild-register": "^3.4.2", "husky": "^8.0.3", "nodemon": "^3.0.1", + "prettier": "^3.0.3", "source-map-support": "^0.5.21", "ts-node": "^10.9.1", - "typescript": "^5.0.4" + "typescript": "^5.1.3" }, "overrides": { "google-gax": "^3.6.1" diff --git a/src/config.ts b/src/config.ts index b3ca9a8..7936b96 100644 --- a/src/config.ts +++ b/src/config.ts @@ -20,6 +20,12 @@ type Config = { anthropicKey?: string; /** Comma-delimited list of Google PaLM API keys. */ googlePalmKey?: string; + /** + * Comma-delimited list of AWS credentials. Each credential item should be a + * colon-delimited list of access key, secret key, and AWS region. + * Example: `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2` + */ + awsCredentials?: string; /** * The proxy key to require for requests. 
Only applicable if the user * management mode is set to 'proxy_key', and required if so. @@ -110,7 +116,7 @@ type Config = { blockedOrigins?: string; /** Message to return when rejecting requests from blocked origins. */ blockMessage?: string; - /** Desination URL to redirect blocked requests to, for non-JSON requests. */ + /** Destination URL to redirect blocked requests to, for non-JSON requests. */ blockRedirect?: string; /** Which model families to allow requests for. Applies only to OpenAI. */ allowedModelFamilies: ModelFamily[]; @@ -142,6 +148,7 @@ export const config: Config = { openaiKey: getEnvWithDefault("OPENAI_KEY", ""), anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""), googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""), + awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""), adminKey: getEnvWithDefault("ADMIN_KEY", ""), gatekeeper: getEnvWithDefault("GATEKEEPER", "none"), @@ -168,6 +175,8 @@ export const config: Config = { "gpt4", "gpt4-32k", "claude", + "bison", + "aws-claude", ]), rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false), rejectMessage: getEnvWithDefault( @@ -196,6 +205,7 @@ export const config: Config = { "gpt4-32k": getEnvWithDefault("TOKEN_QUOTA_GPT4_32K", 0), claude: getEnvWithDefault("TOKEN_QUOTA_CLAUDE", 0), bison: getEnvWithDefault("TOKEN_QUOTA_BISON", 0), + "aws-claude": getEnvWithDefault("TOKEN_QUOTA_AWS_CLAUDE", 0), }, quotaRefreshPeriod: getEnvWithDefault("QUOTA_REFRESH_PERIOD", undefined), allowNicknameChanges: getEnvWithDefault("ALLOW_NICKNAME_CHANGES", true), @@ -288,6 +298,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [ "openaiKey", "anthropicKey", "googlePalmKey", + "awsCredentials", "proxyKey", "adminKey", "checkKeys", @@ -344,7 +355,12 @@ function getEnvWithDefault(env: string | string[], defaultValue: T): T { } try { if ( - ["OPENAI_KEY", "ANTHROPIC_KEY", "GOOGLE_PALM_KEY"].includes(String(env)) + [ + "OPENAI_KEY", + "ANTHROPIC_KEY", + "GOOGLE_PALM_KEY", + "AWS_CREDENTIALS", + ].includes(String(env)) ) { return value as unknown as T; } diff --git a/src/info-page.ts b/src/info-page.ts index fe1b337..bcab45b 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -6,6 +6,7 @@ import { AnthropicKey, GooglePalmKey, OpenAIKey, + AwsBedrockKey, keyPool, } from "./shared/key-management"; import { ModelFamily, OpenAIModelFamily } from "./shared/models"; @@ -25,6 +26,8 @@ const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey => k.service === "anthropic"; const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey => k.service === "google-palm"; +const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => + k.service === "aws"; type ModelAggregates = { active: number; @@ -43,6 +46,7 @@ type ServiceAggregates = { openaiOrgs?: number; anthropicKeys?: number; palmKeys?: number; + awsKeys?: number; proompts: number; tokens: number; tokenCost: number; @@ -85,6 +89,7 @@ function cacheInfoPageHtml(baseUrl: string) { const openaiKeys = serviceStats.get("openaiKeys") || 0; const anthropicKeys = serviceStats.get("anthropicKeys") || 0; const palmKeys = serviceStats.get("palmKeys") || 0; + const awsKeys = serviceStats.get("awsKeys") || 0; const proompts = serviceStats.get("proompts") || 0; const tokens = serviceStats.get("tokens") || 0; const tokenCost = serviceStats.get("tokenCost") || 0; @@ -98,6 +103,7 @@ function cacheInfoPageHtml(baseUrl: string) { : {}), ...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}), ...(palmKeys ? 
{ "google-palm": baseUrl + "/proxy/google-palm" } : {}), + ...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}), }, proompts, tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`, @@ -105,9 +111,11 @@ function cacheInfoPageHtml(baseUrl: string) { openaiKeys, anthropicKeys, palmKeys, + awsKeys, ...(openaiKeys ? getOpenAIInfo() : {}), ...(anthropicKeys ? getAnthropicInfo() : {}), ...(palmKeys ? { "palm-bison": getPalmInfo() } : {}), + ...(awsKeys ? { "aws-claude": getAwsInfo() } : {}), config: listConfig(), build: process.env.BUILD_INFO || "dev", }; @@ -157,6 +165,7 @@ function addKeyToAggregates(k: KeyPoolKey) { increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0); increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0); increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0); + increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0); let sumTokens = 0; let sumCost = 0; @@ -167,7 +176,6 @@ function addKeyToAggregates(k: KeyPoolKey) { switch (k.service) { case "openai": - case "openai-text": if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type"); increment( serviceStats, @@ -212,6 +220,13 @@ function addKeyToAggregates(k: KeyPoolKey) { sumCost += getTokenCostUsd(family, k.bisonTokens); increment(modelStats, `${family}__tokens`, k.bisonTokens); break; + case "aws": + if (!keyIsAwsKey(k)) throw new Error("Invalid key type"); + family = "aws-claude"; + sumTokens += k["aws-claudeTokens"]; + sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]); + increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]); + break; default: assertNever(k.service); } @@ -339,6 +354,26 @@ function getPalmInfo() { }; } +function getAwsInfo() { + const awsInfo: Partial = { + active: modelStats.get("aws-claude__active") || 0, + } + + const queue = getQueueInformation("aws-claude"); + awsInfo.queued = queue.proomptersInQueue; + awsInfo.queueTime = queue.estimatedQueueTime; + + const tokens = modelStats.get("aws-claude__tokens") || 0; + const cost = getTokenCostUsd("aws-claude", tokens); + + return { + usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`, + activeKeys: awsInfo.active, + proomptersInQueue: awsInfo.queued, + estimatedQueueTime: awsInfo.queueTime, + } +} + const customGreeting = fs.existsSync("greeting.md") ? fs.readFileSync("greeting.md", "utf8") : null; @@ -389,6 +424,12 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t const claudeWait = getQueueInformation("claude").estimatedQueueTime; waits.push(`**Claude:** ${claudeWait}`); } + + if (config.awsCredentials) { + const awsClaudeWait = getQueueInformation("aws-claude").estimatedQueueTime; + waits.push(`**Claude (AWS):** ${awsClaudeWait}`); + } + infoBody += "\n\n" + waits.join(" / "); if (customGreeting) { diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts index 6002a2a..2b3881a 100644 --- a/src/proxy/anthropic.ts +++ b/src/proxy/anthropic.ts @@ -185,24 +185,37 @@ anthropicRouter.get("/v1/models", handleModelRequest); anthropicRouter.post( "/v1/complete", ipLimiter, - createPreprocessorMiddleware({ inApi: "anthropic", outApi: "anthropic" }), + createPreprocessorMiddleware({ + inApi: "anthropic", + outApi: "anthropic", + service: "anthropic", + }), anthropicProxy ); // OpenAI-to-Anthropic compatibility endpoint. 
anthropicRouter.post( "/v1/chat/completions", ipLimiter, - createPreprocessorMiddleware({ inApi: "openai", outApi: "anthropic" }), + createPreprocessorMiddleware( + { inApi: "openai", outApi: "anthropic", service: "anthropic" }, + { afterTransform: [maybeReassignModel] } + ), anthropicProxy ); -// Redirect browser requests to the homepage. -anthropicRouter.get("*", (req, res, next) => { - const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); - if (isBrowser) { - res.redirect("/"); - } else { - next(); + +function maybeReassignModel(req: Request) { + const model = req.body.model; + if (!model.startsWith("gpt-")) return; + + const bigModel = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k"; + const contextSize = req.promptTokens! + req.outputTokens!; + if (contextSize > 8500) { + req.log.debug( + { model: bigModel, contextSize }, + "Using Claude 100k model for OpenAI-to-Anthropic request" + ); + req.body.model = bigModel; } -}); +} export const anthropic = anthropicRouter; diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts new file mode 100644 index 0000000..3829407 --- /dev/null +++ b/src/proxy/aws.ts @@ -0,0 +1,211 @@ +import { Request, RequestHandler, Router } from "express"; +import * as http from "http"; +import { createProxyMiddleware } from "http-proxy-middleware"; +import { config } from "../config"; +import { logger } from "../logger"; +import { createQueueMiddleware } from "./queue"; +import { ipLimiter } from "./rate-limit"; +import { handleProxyError } from "./middleware/common"; +import { + applyQuotaLimits, + createPreprocessorMiddleware, + stripHeaders, + signAwsRequest, + finalizeAwsRequest, +} from "./middleware/request"; +import { + ProxyResHandlerWithBody, + createOnProxyResHandler, +} from "./middleware/response"; +import { v4 } from "uuid"; + +let modelsCache: any = null; +let modelsCacheTime = 0; + +const getModelsResponse = () => { + if (new Date().getTime() - modelsCacheTime < 1000 * 60) { + return modelsCache; + } + + if (!config.awsCredentials) return { object: "list", data: [] }; + + const variants = ["anthropic.claude-v1", "anthropic.claude-v2"]; + + const models = variants.map((id) => ({ + id, + object: "model", + created: new Date().getTime(), + owned_by: "anthropic", + permission: [], + root: "claude", + parent: null, + })); + + modelsCache = { object: "list", data: models }; + modelsCacheTime = new Date().getTime(); + + return modelsCache; +}; + +const handleModelRequest: RequestHandler = (_req, res) => { + res.status(200).json(getModelsResponse()); +}; + +const rewriteAwsRequest = ( + proxyReq: http.ClientRequest, + req: Request, + res: http.ServerResponse +) => { + // `addKey` is not used here because AWS requests have to be signed. The + // signing is an async operation so we can't do it in an http-proxy-middleware + // handler. It is instead done in the `signAwsRequest` preprocessor. + const rewriterPipeline = [applyQuotaLimits, stripHeaders, finalizeAwsRequest]; + + try { + for (const rewriter of rewriterPipeline) { + rewriter(proxyReq, req, res, {}); + } + } catch (error) { + req.log.error(error, "Error while executing proxy rewriter"); + proxyReq.destroy(error as Error); + } +}; + +/** Only used for non-streaming requests. */ +const awsResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + if (config.promptLogging) { + const host = req.get("host"); + body.proxy_note = `Prompts are logged on this proxy instance. 
See ${host} for more information.`;
+  }
+
+  if (req.inboundApi === "openai") {
+    req.log.info("Transforming AWS Claude response to OpenAI format");
+    body = transformAwsResponse(body, req);
+  }
+
+  // TODO: Remove once tokenization is stable
+  if (req.debug) {
+    body.proxy_tokenizer_debug_info = req.debug;
+  }
+
+  // AWS does not confirm the model in the response, so we have to add it
+  body.model = req.body.model;
+
+  res.status(200).json(body);
+};
+
+/**
+ * Transforms a model response from the Anthropic API to match those from the
+ * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
+ * is only used for non-streaming requests as streaming requests are handled
+ * on-the-fly.
+ */
+function transformAwsResponse(
+  awsBody: Record<string, any>,
+  req: Request
+): Record<string, any> {
+  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
+  return {
+    id: "aws-" + v4(),
+    object: "chat.completion",
+    created: Date.now(),
+    model: req.body.model,
+    usage: {
+      prompt_tokens: req.promptTokens,
+      completion_tokens: req.outputTokens,
+      total_tokens: totalTokens,
+    },
+    choices: [
+      {
+        message: {
+          role: "assistant",
+          content: awsBody.completion?.trim(),
+        },
+        finish_reason: awsBody.stop_reason,
+        index: 0,
+      },
+    ],
+  };
+}
+
+const awsProxy = createQueueMiddleware(
+  createProxyMiddleware({
+    target: "bad-target-will-be-rewritten",
+    router: ({ signedRequest }) => {
+      if (!signedRequest) {
+        throw new Error("AWS requests must go through signAwsRequest first");
+      }
+      return `${signedRequest.protocol}//${signedRequest.hostname}`;
+    },
+    changeOrigin: true,
+    on: {
+      proxyReq: rewriteAwsRequest,
+      proxyRes: createOnProxyResHandler([awsResponseHandler]),
+      error: handleProxyError,
+    },
+    selfHandleResponse: true,
+    logger,
+  })
+);
+
+const awsRouter = Router();
+// Fix paths because clients don't consistently use the /v1 prefix.
+awsRouter.use((req, _res, next) => {
+  if (!req.path.startsWith("/v1/")) {
+    req.url = `/v1${req.url}`;
+  }
+  next();
+});
+awsRouter.get("/v1/models", handleModelRequest);
+awsRouter.post(
+  "/v1/complete",
+  ipLimiter,
+  createPreprocessorMiddleware(
+    { inApi: "anthropic", outApi: "anthropic", service: "aws" },
+    { afterTransform: [maybeReassignModel, signAwsRequest] }
+  ),
+  awsProxy
+);
+// OpenAI-to-AWS Anthropic compatibility endpoint.
+awsRouter.post( + "/v1/chat/completions", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "openai", outApi: "anthropic", service: "aws" }, + { afterTransform: [maybeReassignModel, signAwsRequest] } + ), + awsProxy +); + +/** + * Tries to deal with: + * - frontends sending AWS model names even when they want to use the OpenAI- + * compatible endpoint + * - frontends sending Anthropic model names that AWS doesn't recognize + * - frontends sending OpenAI model names because they expect the proxy to + * translate them + */ +function maybeReassignModel(req: Request) { + const model = req.body.model; + // User's client sent an AWS model already + if (model.includes("anthropic.claude")) return; + // User's client is sending Anthropic-style model names, check for v1 + if (model.match(/^claude-v?1/)) { + req.body.model = "anthropic.claude-v1"; + } else { + // User's client requested v2 or possibly some OpenAI model, default to v2 + req.body.model = "anthropic.claude-v2"; + } + // TODO: Handle claude-instant +} + +export const aws = awsRouter; diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index 0f588d8..f5c8bd1 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -1,7 +1,6 @@ import { Request, Response } from "express"; import httpProxy from "http-proxy"; import { ZodError } from "zod"; -import { APIFormat } from "../../shared/key-management"; import { assertNever } from "../../shared/utils"; import { QuotaExceededError } from "./request/apply-quota-limits"; @@ -59,7 +58,7 @@ export function writeErrorResponse( res.write(`data: [DONE]\n\n`); res.end(); } else { - if (req.debug) { + if (req.debug && errorPayload.error) { errorPayload.error.proxy_tokenizer_debug_info = req.debug; } res.status(statusCode).json(errorPayload); @@ -132,10 +131,7 @@ export function buildFakeSseMessage( req: Request ) { let fakeEvent; - const useBackticks = !type.includes("403"); - const msgContent = useBackticks - ? 
`\`\`\`\n[${type}: ${string}]\n\`\`\`\n` - : `[${type}: ${string}]`; + const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`; switch (req.inboundApi) { case "openai": @@ -144,13 +140,7 @@ export function buildFakeSseMessage( object: "chat.completion.chunk", created: Date.now(), model: req.body?.model, - choices: [ - { - delta: { content: msgContent }, - index: 0, - finish_reason: type, - }, - ], + choices: [{ delta: { content }, index: 0, finish_reason: type }], }; break; case "openai-text": @@ -159,14 +149,14 @@ export function buildFakeSseMessage( object: "text_completion", created: Date.now(), choices: [ - { text: msgContent, index: 0, logprobs: null, finish_reason: type }, + { text: content, index: 0, logprobs: null, finish_reason: type }, ], model: req.body?.model, }; break; case "anthropic": fakeEvent = { - completion: msgContent, + completion: content, stop_reason: type, truncated: false, // I've never seen this be true stop: null, @@ -182,25 +172,42 @@ export function buildFakeSseMessage( return `data: ${JSON.stringify(fakeEvent)}\n\n`; } -export function getCompletionForService({ - service, - body, - req, -}: { - service: APIFormat; - body: Record; - req?: Request; -}): { completion: string; model: string } { - switch (service) { +export function getCompletionFromBody(req: Request, body: Record) { + const format = req.outboundApi; + switch (format) { case "openai": - return { completion: body.choices[0].message.content, model: body.model }; + return body.choices[0].message.content; case "openai-text": - return { completion: body.choices[0].text, model: body.model }; + return body.choices[0].text; case "anthropic": - return { completion: body.completion.trim(), model: body.model }; + if (!body.completion) { + req.log.error( + { body: JSON.stringify(body) }, + "Received empty Anthropic completion" + ); + return ""; + } + return body.completion.trim(); case "google-palm": - return { completion: body.candidates[0].output, model: req?.body.model }; + return body.candidates[0].output; default: - assertNever(service); + assertNever(format); + } +} + +export function getModelFromBody(req: Request, body: Record) { + const format = req.outboundApi; + switch (format) { + case "openai": + case "openai-text": + return body.model; + case "anthropic": + // Anthropic confirms the model in the response, but AWS Claude doesn't. + return body.model || req.body.model; + case "google-palm": + // Google doesn't confirm the model in the response. + return req.body.model; + default: + assertNever(format); } } diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/add-key.ts index 76050b8..a58a79e 100644 --- a/src/proxy/middleware/request/add-key.ts +++ b/src/proxy/middleware/request/add-key.ts @@ -80,7 +80,6 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { proxyReq.setHeader("X-API-Key", assignedKey.key); break; case "openai": - case "openai-text": const key: OpenAIKey = assignedKey as OpenAIKey; if (key.organizationId) { proxyReq.setHeader("OpenAI-Organization", key.organizationId); @@ -94,6 +93,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { `?key=${assignedKey.key}` ); break; + case "aws": + throw new Error( + "add-key should not be used for AWS security credentials. Use sign-aws-request instead." 
+ ); default: assertNever(assignedKey.service); } diff --git a/src/proxy/middleware/request/count-prompt-tokens.ts b/src/proxy/middleware/request/count-prompt-tokens.ts new file mode 100644 index 0000000..5cb0dd4 --- /dev/null +++ b/src/proxy/middleware/request/count-prompt-tokens.ts @@ -0,0 +1,48 @@ +import { RequestPreprocessor } from "./index"; +import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization"; +import { assertNever } from "../../../shared/utils"; + +/** + * Given a request with an already-transformed body, counts the number of + * tokens and assigns the count to the request. + */ +export const countPromptTokens: RequestPreprocessor = async (req) => { + const service = req.outboundApi; + let result; + + switch (service) { + case "openai": { + req.outputTokens = req.body.max_tokens; + const prompt: OpenAIPromptMessage[] = req.body.messages; + result = await countTokens({ req, prompt, service }); + break; + } + case "openai-text": { + req.outputTokens = req.body.max_tokens; + const prompt: string = req.body.prompt; + result = await countTokens({ req, prompt, service }); + break; + } + case "anthropic": { + req.outputTokens = req.body.max_tokens_to_sample; + const prompt: string = req.body.prompt; + result = await countTokens({ req, prompt, service }); + break; + } + case "google-palm": { + req.outputTokens = req.body.maxOutputTokens; + const prompt: string = req.body.prompt.text; + result = await countTokens({ req, prompt, service }); + break; + } + default: + assertNever(service); + } + + req.promptTokens = result.token_count; + + // TODO: Remove once token counting is stable + req.log.debug({ result: result }, "Counted prompt tokens."); + req.debug = req.debug ?? {}; + req.debug = { ...req.debug, ...result }; +}; \ No newline at end of file diff --git a/src/proxy/middleware/request/finalize-aws-request.ts b/src/proxy/middleware/request/finalize-aws-request.ts new file mode 100644 index 0000000..000a533 --- /dev/null +++ b/src/proxy/middleware/request/finalize-aws-request.ts @@ -0,0 +1,26 @@ +import type { ProxyRequestMiddleware } from "."; + +/** + * For AWS requests, the body is signed earlier in the request pipeline, before + * the proxy middleware. This function just assigns the path and headers to the + * proxy request. + */ +export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => { + if (!req.signedRequest) { + throw new Error("Expected req.signedRequest to be set"); + } + + // The path depends on the selected model and the assigned key's region. + proxyReq.path = req.signedRequest.path; + + // Amazon doesn't want extra headers, so we need to remove all of them and + // reassign only the ones specified in the signed request. + proxyReq.getRawHeaderNames().forEach(proxyReq.removeHeader.bind(proxyReq)); + Object.entries(req.signedRequest.headers).forEach(([key, value]) => { + proxyReq.setHeader(key, value); + }); + + // Don't use fixRequestBody here because it adds a content-length header. + // Amazon doesn't want that and it breaks the signature. 
+  proxyReq.write(req.signedRequest.body);
+};
diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts
index e35342b..f325bc0 100644
--- a/src/proxy/middleware/request/index.ts
+++ b/src/proxy/middleware/request/index.ts
@@ -2,14 +2,17 @@ import type { Request } from "express";
 import type { ClientRequest } from "http";
 import type { ProxyReqCallback } from "http-proxy";
 
-// Express middleware (runs before http-proxy-middleware, can be async)
-export { applyQuotaLimits } from "./apply-quota-limits";
 export {
   createPreprocessorMiddleware,
   createEmbeddingsPreprocessorMiddleware,
 } from "./preprocess";
-export { checkContextSize } from "./check-context-size";
+
+// Express middleware (runs before http-proxy-middleware, can be async)
+export { applyQuotaLimits } from "./apply-quota-limits";
+export { validateContextSize } from "./validate-context-size";
+export { countPromptTokens } from "./count-prompt-tokens";
 export { setApiFormat } from "./set-api-format";
+export { signAwsRequest } from "./sign-aws-request";
 export { transformOutboundPayload } from "./transform-outbound-payload";
 
 // HPM middleware (runs on onProxyReq, cannot be async)
@@ -17,6 +20,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
 export { addAnthropicPreamble } from "./add-anthropic-preamble";
 export { blockZoomerOrigins } from "./block-zoomer-origins";
 export { finalizeBody } from "./finalize-body";
+export { finalizeAwsRequest } from "./finalize-aws-request";
 export { languageFilter } from "./language-filter";
 export { limitCompletions } from "./limit-completions";
 export { stripHeaders } from "./strip-headers";
@@ -50,3 +54,6 @@ export type RequestPreprocessor = (req: Request) => void | Promise<void>;
  * request queue middleware.
  */
 export type ProxyRequestMiddleware = ProxyReqCallback<ClientRequest, Request>;
+
+export const forceModel = (model: string) => (req: Request) =>
+  void (req.body.model = model);
diff --git a/src/proxy/middleware/request/preprocess.ts b/src/proxy/middleware/request/preprocess.ts
index 65333b7..e1c2f6e 100644
--- a/src/proxy/middleware/request/preprocess.ts
+++ b/src/proxy/middleware/request/preprocess.ts
@@ -2,24 +2,42 @@ import { RequestHandler } from "express";
 import { handleInternalError } from "../common";
 import {
   RequestPreprocessor,
-  checkContextSize,
+  validateContextSize,
+  countPromptTokens,
   setApiFormat,
   transformOutboundPayload,
 } from ".";
 
+type RequestPreprocessorOptions = {
+  /**
+   * Functions to run before the request body is transformed between API
+   * formats. Use this to change the behavior of the transformation, such as for
+   * endpoints which can accept multiple API formats.
+   */
+  beforeTransform?: RequestPreprocessor[];
+  /**
+   * Functions to run after the request body is transformed and token counts are
+   * assigned. Use this to perform validation or other actions that depend on
+   * the request body being in the final API format.
+   */
+  afterTransform?: RequestPreprocessor[];
+};
+
 /**
  * Returns a middleware function that processes the request body into the given
  * API format, and then sequentially runs the given additional preprocessors.
  */
 export const createPreprocessorMiddleware = (
   apiFormat: Parameters<typeof setApiFormat>[0],
-  additionalPreprocessors?: RequestPreprocessor[]
+  { beforeTransform, afterTransform }: RequestPreprocessorOptions = {}
 ): RequestHandler => {
   const preprocessors: RequestPreprocessor[] = [
     setApiFormat(apiFormat),
-    ...(additionalPreprocessors ?? []),
+    ...(beforeTransform ?? []),
     transformOutboundPayload,
-    checkContextSize,
+    countPromptTokens,
+    ...(afterTransform ?? []),
+    validateContextSize,
   ];
   return async (...args) => executePreprocessors(preprocessors, args);
 };
@@ -29,13 +47,10 @@
  * OpenAI's embeddings API. Tokens are not counted because embeddings requests
  * are basically free.
  */
-export const createEmbeddingsPreprocessorMiddleware = (
-  additionalPreprocessors?: RequestPreprocessor[]
-): RequestHandler => {
+export const createEmbeddingsPreprocessorMiddleware = (): RequestHandler => {
   const preprocessors: RequestPreprocessor[] = [
-    setApiFormat({ inApi: "openai", outApi: "openai" }),
+    setApiFormat({ inApi: "openai", outApi: "openai", service: "openai" }),
     (req) => void (req.promptTokens = req.outputTokens = 0),
-    ...(additionalPreprocessors ?? []),
   ];
   return async (...args) => executePreprocessors(preprocessors, args);
 };
diff --git a/src/proxy/middleware/request/set-api-format.ts b/src/proxy/middleware/request/set-api-format.ts
index f3cd277..7346a6b 100644
--- a/src/proxy/middleware/request/set-api-format.ts
+++ b/src/proxy/middleware/request/set-api-format.ts
@@ -1,13 +1,15 @@
 import { Request } from "express";
-import { APIFormat } from "../../../shared/key-management";
+import { APIFormat, LLMService } from "../../../shared/key-management";
 import { RequestPreprocessor } from ".";
 
 export const setApiFormat = (api: {
   inApi: Request["inboundApi"];
   outApi: APIFormat;
+  service: LLMService,
 }): RequestPreprocessor => {
   return (req) => {
     req.inboundApi = api.inApi;
     req.outboundApi = api.outApi;
+    req.service = api.service;
   };
 };
diff --git a/src/proxy/middleware/request/sign-aws-request.ts b/src/proxy/middleware/request/sign-aws-request.ts
new file mode 100644
index 0000000..b19dd8b
--- /dev/null
+++ b/src/proxy/middleware/request/sign-aws-request.ts
@@ -0,0 +1,93 @@
+import express from "express";
+import { Sha256 } from "@aws-crypto/sha256-js";
+import { SignatureV4 } from "@smithy/signature-v4";
+import { HttpRequest } from "@smithy/protocol-http";
+import { keyPool } from "../../../shared/key-management";
+import { RequestPreprocessor } from ".";
+import { AnthropicV1CompleteSchema } from "./transform-outbound-payload";
+
+const AMZ_HOST =
+  process.env.AMZ_HOST || "invoke-bedrock.%REGION%.amazonaws.com";
+
+/**
+ * Signs an outgoing AWS request with the appropriate headers and modifies the
+ * request object in place to fix the path.
+ */
+export const signAwsRequest: RequestPreprocessor = async (req) => {
+  req.key = keyPool.get("anthropic.claude-v2");
+
+  const { model, stream } = req.body;
+  req.isStreaming = stream === true || stream === "true";
+
+  let preamble = req.body.prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:";
+  req.body.prompt = preamble + req.body.prompt;
+
+  // AWS supports only a subset of Anthropic's parameters and is more strict
+  // about unknown parameters.
+  // TODO: This should happen in transform-outbound-payload.ts
+  const strippedParams = AnthropicV1CompleteSchema.pick({
+    prompt: true,
+    max_tokens_to_sample: true,
+    stop_sequences: true,
+    temperature: true,
+    top_k: true,
+    top_p: true,
+  }).parse(req.body);
+
+  const credential = getCredentialParts(req);
+  const host = AMZ_HOST.replace("%REGION%", credential.region);
+
+  // Uses the AWS SDK to sign a request, then modifies our HPM proxy request
+  // with the headers generated by the SDK.
+ const newRequest = new HttpRequest({ + method: "POST", + protocol: "https:", + hostname: host, + path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`, + headers: { + ["Host"]: host, + ["content-type"]: "application/json", + }, + body: JSON.stringify(strippedParams), + }); + + if (stream) { + newRequest.headers["x-amzn-bedrock-accept"] = "application/json"; + } else { + newRequest.headers["accept"] = "*/*"; + } + + req.signedRequest = await sign(newRequest, getCredentialParts(req)); +}; + +type Credential = { + accessKeyId: string; + secretAccessKey: string; + region: string; +}; +function getCredentialParts(req: express.Request): Credential { + const [accessKeyId, secretAccessKey, region] = req.key!.key.split(":"); + + if (!accessKeyId || !secretAccessKey || !region) { + req.log.error( + { key: req.key!.hash }, + "AWS_CREDENTIALS isn't correctly formatted; refer to the docs" + ); + throw new Error("The key assigned to this request is invalid."); + } + + return { accessKeyId, secretAccessKey, region }; +} + +async function sign(request: HttpRequest, credential: Credential) { + const { accessKeyId, secretAccessKey, region } = credential; + + const signer = new SignatureV4({ + sha256: Sha256, + credentials: { accessKeyId, secretAccessKey }, + region, + service: "bedrock", + }); + + return signer.sign(request); +} diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/transform-outbound-payload.ts index 53edc45..7659ab6 100644 --- a/src/proxy/middleware/request/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/transform-outbound-payload.ts @@ -10,8 +10,8 @@ const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic; const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI; // https://console.anthropic.com/docs/api/reference#-v1-complete -const AnthropicV1CompleteSchema = z.object({ - model: z.string().regex(/^claude-/, "Model must start with 'claude-'"), +export const AnthropicV1CompleteSchema = z.object({ + model: z.string(), prompt: z.string({ required_error: "No prompt found. 
Are you sending an OpenAI-formatted request to the Claude endpoint?", @@ -23,14 +23,14 @@ const AnthropicV1CompleteSchema = z.object({ stop_sequences: z.array(z.string()).optional(), stream: z.boolean().optional().default(false), temperature: z.coerce.number().optional().default(1), - top_k: z.coerce.number().optional().default(-1), - top_p: z.coerce.number().optional().default(-1), + top_k: z.coerce.number().optional(), + top_p: z.coerce.number().optional(), metadata: z.any().optional(), }); // https://platform.openai.com/docs/api-reference/chat/create const OpenAIV1ChatCompletionSchema = z.object({ - model: z.string().regex(/^gpt/, "Model must start with 'gpt-'"), + model: z.string(), messages: z.array( z.object({ role: z.enum(["system", "user", "assistant"]), @@ -89,7 +89,7 @@ const OpenAIV1TextCompletionSchema = z // https://developers.generativeai.google/api/rest/generativelanguage/models/generateText const PalmV1GenerateTextSchema = z.object({ - model: z.string().regex(/^\w+-bison-\d{3}$/), + model: z.string(), prompt: z.object({ text: z.string() }), temperature: z.number().optional(), maxOutputTokens: z.coerce @@ -159,7 +159,7 @@ function openaiToAnthropic(req: Request) { const { body } = req; const result = OpenAIV1ChatCompletionSchema.safeParse(body); if (!result.success) { - req.log.error( + req.log.warn( { issues: result.error.issues, body }, "Invalid OpenAI-to-Anthropic request" ); @@ -208,7 +208,7 @@ function openaiToOpenaiText(req: Request) { const { body } = req; const result = OpenAIV1ChatCompletionSchema.safeParse(body); if (!result.success) { - req.log.error( + req.log.warn( { issues: result.error.issues, body }, "Invalid OpenAI-to-OpenAI-text request" ); @@ -227,8 +227,7 @@ function openaiToOpenaiText(req: Request) { stops = [...new Set(stops)]; const transformed = { ...rest, prompt: prompt, stop: stops }; - const validated = OpenAIV1TextCompletionSchema.parse(transformed); - return validated; + return OpenAIV1TextCompletionSchema.parse(transformed); } function openaiToPalm(req: Request): z.infer { @@ -238,7 +237,7 @@ function openaiToPalm(req: Request): z.infer { model: "gpt-3.5-turbo", }); if (!result.success) { - req.log.error( + req.log.warn( { issues: result.error.issues, body }, "Invalid OpenAI-to-Palm request" ); diff --git a/src/proxy/middleware/request/check-context-size.ts b/src/proxy/middleware/request/validate-context-size.ts similarity index 57% rename from src/proxy/middleware/request/check-context-size.ts rename to src/proxy/middleware/request/validate-context-size.ts index 094067c..203bf1b 100644 --- a/src/proxy/middleware/request/check-context-size.ts +++ b/src/proxy/middleware/request/validate-context-size.ts @@ -1,9 +1,8 @@ import { Request } from "express"; import { z } from "zod"; import { config } from "../../../config"; -import { OpenAIPromptMessage, countTokens } from "../../../shared/tokenization"; -import { RequestPreprocessor } from "."; import { assertNever } from "../../../shared/utils"; +import { RequestPreprocessor } from "."; const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic; const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI; @@ -16,51 +15,7 @@ const BISON_MAX_CONTEXT = 8100; * This preprocessor should run after any preprocessor that transforms the * request body. 
*/ -export const checkContextSize: RequestPreprocessor = async (req) => { - const service = req.outboundApi; - let result; - - switch (service) { - case "openai": { - req.outputTokens = req.body.max_tokens; - const prompt: OpenAIPromptMessage[] = req.body.messages; - result = await countTokens({ req, prompt, service }); - break; - } - case "openai-text": { - req.outputTokens = req.body.max_tokens; - const prompt: string = req.body.prompt; - result = await countTokens({ req, prompt, service }); - break; - } - case "anthropic": { - req.outputTokens = req.body.max_tokens_to_sample; - const prompt: string = req.body.prompt; - result = await countTokens({ req, prompt, service }); - break; - } - case "google-palm": { - req.outputTokens = req.body.maxOutputTokens; - const prompt: string = req.body.prompt.text; - result = await countTokens({ req, prompt, service }); - break; - } - default: - assertNever(service); - } - - req.promptTokens = result.token_count; - - // TODO: Remove once token counting is stable - req.log.debug({ result: result }, "Counted prompt tokens."); - req.debug = req.debug ?? {}; - req.debug = { ...req.debug, ...result }; - - maybeTranslateOpenAIModel(req); - validateContextSize(req); -}; - -function validateContextSize(req: Request) { +export const validateContextSize: RequestPreprocessor = async (req) => { assertRequestHasTokenCounts(req); const promptTokens = req.promptTokens; const outputTokens = req.outputTokens; @@ -125,7 +80,7 @@ function validateContextSize(req: Request) { req.debug.completion_tokens = outputTokens; req.debug.max_model_tokens = modelMax; req.debug.max_proxy_tokens = proxyMax; -} +}; function assertRequestHasTokenCounts( req: Request @@ -137,27 +92,3 @@ function assertRequestHasTokenCounts( .nonstrict() .parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens }); } - -/** - * For OpenAI-to-Anthropic requests, users can't specify the model, so we need - * to pick one based on the final context size. Ideally this would happen in - * the `transformOutboundPayload` preprocessor, but we don't have the context - * size at that point (and need a transformed body to calculate it). - */ -function maybeTranslateOpenAIModel(req: Request) { - if (req.inboundApi !== "openai" || req.outboundApi !== "anthropic") { - return; - } - - const bigModel = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k"; - const contextSize = req.promptTokens! 
+ req.outputTokens!; - - if (contextSize > 8500) { - req.log.debug( - { model: bigModel, contextSize }, - "Using Claude 100k model for OpenAI-to-Anthropic request" - ); - req.body.model = bigModel; - } - // Small model is the default already set in `transformOutboundPayload` -} diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index f2f0a75..605a5c5 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -3,6 +3,7 @@ import * as http from "http"; import { buildFakeSseMessage } from "../common"; import { RawResponseBodyHandler, decodeResponseBody } from "."; import { assertNever } from "../../../shared/utils"; +import { ServerSentEventStreamAdapter } from "./sse-stream-adapter"; type OpenAiChatCompletionResponse = { id: string; @@ -82,6 +83,11 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( return decodeResponseBody(proxyRes, req, res); } + req.log.debug( + { headers: proxyRes.headers, key: key.hash }, + `Received SSE headers.` + ); + return new Promise((resolve, reject) => { req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`); @@ -97,75 +103,50 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( res.flushHeaders(); } - const originalEvents: string[] = []; - let partialMessage = ""; + const adapter = new ServerSentEventStreamAdapter({ + isAwsStream: + proxyRes.headers["content-type"] === + "application/vnd.amazon.eventstream", + }); + + const events: string[] = []; let lastPosition = 0; let eventCount = 0; - type ProxyResHandler = (...args: T[]) => void; - function withErrorHandling(fn: ProxyResHandler) { - return (...args: T[]) => { - try { - fn(...args); - } catch (error) { - proxyRes.emit("error", error); - } - }; - } + proxyRes.pipe(adapter); - proxyRes.on( - "data", - withErrorHandling((chunk: Buffer) => { - // We may receive multiple (or partial) SSE messages in a single chunk, - // so we need to buffer and emit seperate stream events for full - // messages so we can parse/transform them properly. 
- const str = chunk.toString(); - - // Anthropic uses CRLF line endings (out-of-spec btw) - const fullMessages = (partialMessage + str).split(/\r?\n\r?\n/); - partialMessage = fullMessages.pop() || ""; - - for (const message of fullMessages) { - proxyRes.emit("full-sse-event", message); - } - }) - ); - - proxyRes.on( - "full-sse-event", - withErrorHandling((data) => { - originalEvents.push(data); + adapter.on("data", (chunk: any) => { + try { const { event, position } = transformEvent({ - data, + data: chunk.toString(), requestApi: req.inboundApi, responseApi: req.outboundApi, lastPosition, index: eventCount++, }); + events.push(event); lastPosition = position; res.write(event + "\n\n"); - }) - ); + } catch (err) { + adapter.emit("error", err); + } + }); - proxyRes.on( - "end", - withErrorHandling(() => { - let finalBody = convertEventsToFinalResponse(originalEvents, req); + adapter.on("end", () => { + try { req.log.info({ key: key.hash }, `Finished proxying SSE stream.`); + const finalBody = convertEventsToFinalResponse(events, req); res.end(); resolve(finalBody); - }) - ); + } catch (err) { + adapter.emit("error", err); + } + }); - proxyRes.on("error", (err) => { + adapter.on("error", (err) => { req.log.error({ error: err, key: key.hash }, `Mid-stream error.`); - const fakeErrorEvent = buildFakeSseMessage( - "mid-stream-error", - err.message, - req - ); - res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`); - res.write("data: [DONE]\n\n"); + const errorEvent = buildFakeSseMessage("stream-error", err.message, req); + res.write(`data: ${JSON.stringify(errorEvent)}\n\ndata: [DONE]\n\n`); res.end(); reject(err); }); @@ -197,8 +178,6 @@ function transformEvent(params: SSETransformationArgs) { case "openai->anthropic": // TODO: handle new anthropic streaming format return transformV1AnthropicEventToOpenAI(params); - case "openai->google-palm": - return transformPalmEventToOpenAI(params); default: throw new Error(`Unsupported streaming API transformation. ${trans}`); } @@ -288,11 +267,6 @@ function transformV1AnthropicEventToOpenAI(params: SSETransformationArgs) { }; } -function transformPalmEventToOpenAI({ data }: SSETransformationArgs) { - throw new Error("PaLM streaming not yet supported."); - return { position: -1, event: data }; -} - /** Copy headers, excluding ones we're already setting for the SSE response. 
*/ function copyHeaders(proxyRes: http.IncomingMessage, res: Response) { const toOmit = [ @@ -366,7 +340,7 @@ function convertEventsToFinalResponse(events: string[], req: Request) { choices: [], // TODO: merge logprobs }; - merged = events.reduce((acc, event, i) => { + merged = events.reduce((acc, event) => { if (!event.startsWith("data: ")) return acc; if (event === "data: [DONE]") return acc; @@ -390,16 +364,37 @@ function convertEventsToFinalResponse(events: string[], req: Request) { return merged; } case "anthropic": { - /* - * Full complete responses from Anthropic are conveniently just the same as - * the final SSE event before the "DONE" event, so we can reuse that - */ - const lastEvent = events[events.length - 2].toString(); - const data = JSON.parse( - lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length) - ); - const final: AnthropicCompletionResponse = { ...data, log_id: req.id }; - return final; + if (req.headers["anthropic-version"] === "2023-01-01") { + return convertAnthropicV1(events, req); + } + + let merged: AnthropicCompletionResponse = { + completion: "", + stop_reason: "", + truncated: false, + stop: null, + model: req.body.model, + log_id: "", + exception: null, + } + + merged = events.reduce((acc, event) => { + if (!event.startsWith("data: ")) return acc; + if (event === "data: [DONE]") return acc; + + const data = JSON.parse(event.slice("data: ".length)); + + return { + completion: acc.completion + data.completion, + stop_reason: data.stop_reason, + truncated: data.truncated, + stop: data.stop, + log_id: data.log_id, + exception: data.exception, + model: acc.model, + }; + }, merged); + return merged; } case "google-palm": { throw new Error("PaLM streaming not yet supported."); @@ -408,3 +403,16 @@ function convertEventsToFinalResponse(events: string[], req: Request) { assertNever(req.outboundApi); } } + +/** Older Anthropic streaming format which sent full completion each time. */ +function convertAnthropicV1( + events: string[], + req: Request +) { + const lastEvent = events[events.length - 2].toString(); + const data = JSON.parse( + lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length) + ); + const final: AnthropicCompletionResponse = { ...data, log_id: req.id }; + return final; +} diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index a432bdc..2a11427 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -12,7 +12,7 @@ import { incrementTokenCount, } from "../../../shared/users/user-store"; import { - getCompletionForService, + getCompletionFromBody, isCompletionRequest, writeErrorResponse, } from "../common"; @@ -173,7 +173,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( throw err; } - const promise = new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { let chunks: Buffer[] = []; proxyRes.on("data", (chunk) => chunks.push(chunk)); proxyRes.on("end", async () => { @@ -209,10 +209,14 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( } }); }); - return promise; }; -// TODO: This is too specific to OpenAI's error responses. +type ProxiedErrorPayload = { + error?: Record; + message?: string; + proxy_note?: string; +}; + /** * Handles non-2xx responses from the upstream service. 
If the proxied response * is an error, this will respond to the client with an error payload and throw @@ -233,27 +237,19 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( return; } - let errorPayload: Record; - // Subtract 1 from available keys because if this message is being shown, - // it's because the key is about to be disabled. - const availableKeys = keyPool.available(req.outboundApi) - 1; - const tryAgainMessage = Boolean(availableKeys) - ? `There are ${availableKeys} more keys available; try your request again.` - : "There are no more keys available."; + let errorPayload: ProxiedErrorPayload; + const tryAgainMessage = keyPool.available(req.body?.model) + ? `There may be more keys available for this model; try again in a few seconds.` + : "There are no more keys available for this model."; try { - if (typeof body === "object") { - errorPayload = body; - } else { - throw new Error("Received unparsable error response from upstream."); - } - } catch (parseError: any) { + assertJsonResponse(body); + errorPayload = body; + } catch (parseError) { + // Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy + const hash = req.key?.hash; const statusMessage = proxyRes.statusMessage || "Unknown error"; - // Likely Bad Gateway or Gateway Timeout from reverse proxy/load balancer - logger.warn( - { statusCode, statusMessage, key: req.key?.hash }, - parseError.message - ); + logger.warn({ statusCode, statusMessage, key: hash }, parseError.message); const errorObject = { statusCode, @@ -265,53 +261,76 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( throw new Error(parseError.message); } + const errorType = + errorPayload.error?.code || + errorPayload.error?.type || + getAwsErrorType(proxyRes.headers["x-amzn-errortype"]); + logger.warn( - { - statusCode, - type: errorPayload.error?.code, - errorPayload, - key: req.key?.hash, - }, + { statusCode, type: errorType, errorPayload, key: req.key?.hash }, `Received error response from upstream. (${proxyRes.statusMessage})` ); + const service = req.key!.service; + if (service === "aws") { + // Try to standardize the error format for AWS + errorPayload.error = { message: errorPayload.message, type: errorType }; + delete errorPayload.message; + } + if (statusCode === 400) { - // Bad request (likely prompt is too long) - switch (req.outboundApi) { + // Bad request. For OpenAI, this is usually due to prompt length. + // For Anthropic, this is usually due to missing preamble. + switch (service) { case "openai": - case "openai-text": case "google-palm": errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`; break; case "anthropic": + case "aws": maybeHandleMissingPreambleError(req, errorPayload); break; default: - assertNever(req.outboundApi); + assertNever(service); } } else if (statusCode === 401) { // Key is invalid or was revoked keyPool.disable(req.key!, "revoked"); errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`; + } else if (statusCode === 403) { + // Amazon is the only service that returns 403. + switch (errorType) { + case "UnrecognizedClientException": + // Key is invalid. + keyPool.disable(req.key!, "revoked"); + errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`; + break; + case "AccessDeniedException": + errorPayload.proxy_note = `API key doesn't have access to the requested resource.`; + break; + default: + errorPayload.proxy_note = `Received 403 error. 
Key may be invalid.`; + } } else if (statusCode === 429) { - switch (req.outboundApi) { + switch (service) { case "openai": - case "openai-text": handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload); break; case "anthropic": handleAnthropicRateLimitError(req, errorPayload); break; + case "aws": + handleAwsRateLimitError(req, errorPayload); + break; case "google-palm": throw new Error("Rate limit handling not implemented for PaLM"); default: - assertNever(req.outboundApi); + assertNever(service); } } else if (statusCode === 404) { // Most likely model not found - switch (req.outboundApi) { + switch (service) { case "openai": - case "openai-text": if (errorPayload.error?.code === "model_not_found") { const requestedModel = req.body.model; const modelFamily = getOpenAIModelFamily(requestedModel); @@ -328,8 +347,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( case "google-palm": errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`; break; + case "aws": + errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`; + break; default: - assertNever(req.outboundApi); + assertNever(service); } } else { errorPayload.proxy_note = `Unrecognized error from upstream service.`; @@ -368,7 +390,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( */ function maybeHandleMissingPreambleError( req: Request, - errorPayload: Record + errorPayload: ProxiedErrorPayload ) { if ( errorPayload.error?.type === "invalid_request_error" && @@ -388,7 +410,7 @@ function maybeHandleMissingPreambleError( function handleAnthropicRateLimitError( req: Request, - errorPayload: Record + errorPayload: ProxiedErrorPayload ) { if (errorPayload.error?.type === "rate_limit_error") { keyPool.markRateLimited(req.key!); @@ -399,35 +421,55 @@ function handleAnthropicRateLimitError( } } +function handleAwsRateLimitError( + req: Request, + errorPayload: ProxiedErrorPayload +) { + const errorType = errorPayload.error?.type; + switch (errorType) { + case "ThrottlingException": + keyPool.markRateLimited(req.key!); + reenqueueRequest(req); + throw new RetryableError("AWS rate-limited request re-enqueued."); + case "ModelNotReadyException": + errorPayload.proxy_note = `The requested model is overloaded. Try again in a few seconds.`; + break; + default: + errorPayload.proxy_note = `Unrecognized rate limit error from AWS. (${errorType})`; + } +} + function handleOpenAIRateLimitError( req: Request, tryAgainMessage: string, - errorPayload: Record + errorPayload: ProxiedErrorPayload ): Record { const type = errorPayload.error?.type; - if (type === "insufficient_quota") { - // Billing quota exceeded (key is dead, disable it) - keyPool.disable(req.key!, "quota"); - errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`; - } else if (type === "access_terminated") { - // Account banned (key is dead, disable it) - keyPool.disable(req.key!, "revoked"); - errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`; - } else if (type === "billing_not_active") { - // Billing is not active (key is dead, disable it) - keyPool.disable(req.key!, "revoked"); - errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. 
${tryAgainMessage}`; - } else if (type === "requests" || type === "tokens") { - // Per-minute request or token rate limit is exceeded, which we can retry - keyPool.markRateLimited(req.key!); - // I'm aware this is confusing -- throwing this class of error will cause - // the proxy response handler to return without terminating the request, - // so that it can be placed back in the queue. - reenqueueRequest(req); - throw new RetryableError("Rate-limited request re-enqueued."); - } else { - // OpenAI probably overloaded - errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`; + switch (type) { + case "insufficient_quota": + // Billing quota exceeded (key is dead, disable it) + keyPool.disable(req.key!, "quota"); + errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`; + break; + case "access_terminated": + // Account banned (key is dead, disable it) + keyPool.disable(req.key!, "revoked"); + errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`; + break; + case "billing_not_active": + // Key valid but account billing is delinquent + keyPool.disable(req.key!, "quota"); + errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. ${tryAgainMessage}`; + break; + case "requests": + case "tokens": + // Per-minute request or token rate limit is exceeded, which we can retry + keyPool.markRateLimited(req.key!); + reenqueueRequest(req); + throw new RetryableError("Rate-limited request re-enqueued."); + default: + errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`; + break; } return errorPayload; } @@ -455,12 +497,9 @@ const countResponseTokens: ProxyResHandlerWithBody = async ( // seeing errors in this function, check the reassembled response body from // handleStreamedResponse to see if the upstream API has changed. 
try { - if (typeof body !== "object") { - throw new Error("Expected body to be an object"); - } - + assertJsonResponse(body); const service = req.outboundApi; - const { completion } = getCompletionForService({ req, service, body }); + const completion = getCompletionFromBody(req, body); const tokens = await countTokens({ req, completion, service }); req.log.debug( @@ -473,7 +512,7 @@ const countResponseTokens: ProxyResHandlerWithBody = async ( req.outputTokens = tokens.token_count; } catch (error) { - req.log.error( + req.log.warn( error, "Error while counting completion tokens; assuming `max_output_tokens`" ); @@ -505,3 +544,14 @@ const copyHttpHeaders: ProxyResHandlerWithBody = async ( res.setHeader(key, proxyRes.headers[key] as string); }); }; + +function getAwsErrorType(header: string | string[] | undefined) { + const val = String(header).match(/^(\w+):?/)?.[1]; + return val || String(header); +} + +function assertJsonResponse(body: any): asserts body is Record { + if (typeof body !== "object") { + throw new Error("Expected response to be an object"); + } +} diff --git a/src/proxy/middleware/response/log-prompt.ts b/src/proxy/middleware/response/log-prompt.ts index 08dbafa..0d634d3 100644 --- a/src/proxy/middleware/response/log-prompt.ts +++ b/src/proxy/middleware/response/log-prompt.ts @@ -1,7 +1,11 @@ import { Request } from "express"; import { config } from "../../../config"; import { logQueue } from "../../../shared/prompt-logging"; -import { getCompletionForService, isCompletionRequest } from "../common"; +import { + getCompletionFromBody, + getModelFromBody, + isCompletionRequest, +} from "../common"; import { ProxyResHandlerWithBody } from "."; import { assertNever } from "../../../shared/utils"; @@ -25,17 +29,15 @@ export const logPrompt: ProxyResHandlerWithBody = async ( const promptPayload = getPromptForRequest(req); const promptFlattened = flattenMessages(promptPayload); - const response = getCompletionForService({ - service: req.outboundApi, - body: responseBody, - }); + const response = getCompletionFromBody(req, responseBody); + const model = getModelFromBody(req, responseBody); logQueue.enqueue({ endpoint: req.inboundApi, promptRaw: JSON.stringify(promptPayload), promptFlattened, - model: response.model, // may differ from the requested model - response: response.completion, + model, + response, }); }; diff --git a/src/proxy/middleware/response/sse-stream-adapter.ts b/src/proxy/middleware/response/sse-stream-adapter.ts new file mode 100644 index 0000000..bbe473c --- /dev/null +++ b/src/proxy/middleware/response/sse-stream-adapter.ts @@ -0,0 +1,85 @@ +import { Transform, TransformOptions } from "stream"; +// @ts-ignore +import { Parser } from "lifion-aws-event-stream"; +import { logger } from "../../../logger"; + +const log = logger.child({ module: "sse-stream-adapter" }); + +type SSEStreamAdapterOptions = TransformOptions & { isAwsStream?: boolean }; +type AwsEventStreamMessage = { + headers: { ":message-type": "event" | "exception" }; + payload: { message?: string /** base64 encoded */; bytes?: string }; +}; + +/** + * Receives either text chunks or AWS binary event stream chunks and emits + * full SSE events. 
+ */ +export class ServerSentEventStreamAdapter extends Transform { + private readonly isAwsStream; + private parser = new Parser(); + private partialMessage = ""; + + constructor(options?: SSEStreamAdapterOptions) { + super(options); + this.isAwsStream = options?.isAwsStream || false; + + this.parser.on("data", (data: AwsEventStreamMessage) => { + const message = this.processAwsEvent(data); + if (message) { + this.push(Buffer.from(message, "utf8")); + } + }); + } + + processAwsEvent(event: AwsEventStreamMessage): string | null { + const { payload, headers } = event; + if (headers[":message-type"] === "exception" || !payload.bytes) { + log.error( + { event: JSON.stringify(event) }, + "Received bad streaming event from AWS" + ); + const message = JSON.stringify(event); + return getFakeErrorCompletion("proxy AWS error", message); + } else { + return `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`; + } + } + + _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) { + try { + if (this.isAwsStream) { + this.parser.write(chunk); + } else { + // We may receive multiple (or partial) SSE messages in a single chunk, + // so we need to buffer and emit separate stream events for full + // messages so we can parse/transform them properly. + const str = chunk.toString("utf8"); + const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/); + this.partialMessage = fullMessages.pop() || ""; + + for (const message of fullMessages) { + this.push(message); + } + } + callback(); + } catch (error) { + this.emit("error", error); + callback(error); + } + } +} + +function getFakeErrorCompletion(type: string, message: string) { + const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`; + const fakeEvent = { + log_id: "aws-proxy-sse-message", + stop_reason: type, + completion: + "\nProxy encountered an error during streaming response.\n" + content, + truncated: false, + stop: null, + model: "", + }; + return `data: ${JSON.stringify(fakeEvent)}\n\n`; +} diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 773ccd3..21d0906 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -21,6 +21,7 @@ import { createEmbeddingsPreprocessorMiddleware, createPreprocessorMiddleware, finalizeBody, + forceModel, languageFilter, limitCompletions, stripHeaders, @@ -246,17 +247,25 @@ openaiRouter.get("/v1/models", handleModelRequest); openaiRouter.post( "/v1/completions", ipLimiter, - createPreprocessorMiddleware({ inApi: "openai-text", outApi: "openai-text" }), + createPreprocessorMiddleware({ + inApi: "openai-text", + outApi: "openai-text", + service: "openai", + }), openaiProxy ); // turbo-instruct compatibility endpoint, accepts either prompt or messages openaiRouter.post( - /\/v1\/turbo\-instruct\/(v1\/)?chat\/completions/, + /\/v1\/turbo-instruct\/(v1\/)?chat\/completions/, ipLimiter, - createPreprocessorMiddleware({ inApi: "openai", outApi: "openai-text" }, [ - rewriteForTurboInstruct, - ]), + createPreprocessorMiddleware( + { inApi: "openai", outApi: "openai-text", service: "openai" }, + { + beforeTransform: [rewriteForTurboInstruct], + afterTransform: [forceModel("gpt-3.5-turbo-instruct")], + } + ), openaiProxy ); @@ -264,7 +273,11 @@ openaiRouter.post( openaiRouter.post( "/v1/chat/completions", ipLimiter, - createPreprocessorMiddleware({ inApi: "openai", outApi: "openai" }), + createPreprocessorMiddleware({ + inApi: "openai", + outApi: "openai", + service: "openai", + }), openaiProxy ); @@ -276,18 +289,4 @@ openaiRouter.post( openaiEmbeddingsProxy ); -// Redirect 
browser requests to the homepage. -openaiRouter.get("*", (req, res, next) => { - const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); - if (isBrowser) { - res.redirect("/"); - } else { - next(); - } -}); -openaiRouter.use((req, res) => { - req.log.warn(`Blocked openai proxy request: ${req.method} ${req.path}`); - res.status(404).json({ error: "Not found" }); -}); - export const openai = openaiRouter; diff --git a/src/proxy/palm.ts b/src/proxy/palm.ts index ea5d393..cedf997 100644 --- a/src/proxy/palm.ts +++ b/src/proxy/palm.ts @@ -12,12 +12,13 @@ import { blockZoomerOrigins, createPreprocessorMiddleware, finalizeBody, + forceModel, languageFilter, stripHeaders, } from "./middleware/request"; import { - ProxyResHandlerWithBody, createOnProxyResHandler, + ProxyResHandlerWithBody, } from "./middleware/response"; import { v4 } from "uuid"; @@ -72,11 +73,10 @@ const rewritePalmRequest = ( // The chat api (generateMessage) is not very useful at this time as it has // few params and no adjustable safety settings. - const newProxyReqPath = proxyReq.path.replace( + proxyReq.path = proxyReq.path.replace( /^\/v1\/chat\/completions/, `/v1beta2/models/${req.body.model}:generateText` ); - proxyReq.path = newProxyReqPath; const rewriterPipeline = [ applyQuotaLimits, @@ -191,17 +191,11 @@ palmRouter.get("/v1/models", handleModelRequest); palmRouter.post( "/v1/chat/completions", ipLimiter, - createPreprocessorMiddleware({ inApi: "openai", outApi: "google-palm" }), + createPreprocessorMiddleware( + { inApi: "openai", outApi: "google-palm", service: "google-palm" }, + { afterTransform: [forceModel("text-bison-001")] } + ), googlePalmProxy ); -// Redirect browser requests to the homepage. -palmRouter.get("*", (req, res, next) => { - const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); - if (isBrowser) { - res.redirect("/"); - } else { - next(); - } -}); export const googlePalm = palmRouter; diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index de0936b..fc98ad3 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -93,7 +93,8 @@ export function enqueue(req: Request) { // If the request opted into streaming, we need to register a heartbeat // handler to keep the connection alive while it waits in the queue. We // deregister the handler when the request is dequeued. - if (req.body.stream === "true" || req.body.stream === true) { + const { stream } = req.body; + if (stream === "true" || stream === true || req.isStreaming) { const res = req.res!; if (!res.headersSent) { initStreaming(req); @@ -138,9 +139,15 @@ function getPartitionForRequest(req: Request): ModelFamily { // There is a single request queue, but it is partitioned by model family. // Model families are typically separated on cost/rate limit boundaries so // they should be treated as separate queues. - const provider = req.outboundApi; const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo"; - switch (provider) { + + // Weird special case for AWS because they serve multiple models from + // different vendors, even if currently only one is supported. 
+ if (req.service === "aws") { + return "aws-claude"; + } + + switch (req.outboundApi) { case "anthropic": return getClaudeModelFamily(model); case "openai": @@ -149,7 +156,7 @@ function getPartitionForRequest(req: Request): ModelFamily { case "google-palm": return getGooglePalmModelFamily(model); default: - assertNever(provider); + assertNever(req.outboundApi); } } @@ -198,12 +205,13 @@ function processQueue() { // the others, because we only track one rate limit per key. // TODO: `getLockoutPeriod` uses model names instead of model families - // TODO: genericize this + // TODO: genericize this it's really ugly const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k"); const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4"); const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo"); const claudeLockout = keyPool.getLockoutPeriod("claude-v1"); const palmLockout = keyPool.getLockoutPeriod("text-bison-001"); + const awsClaudeLockout = keyPool.getLockoutPeriod("anthropic.claude-v2"); const reqs: (Request | undefined)[] = []; if (gpt432kLockout === 0) { @@ -221,6 +229,9 @@ function processQueue() { if (palmLockout === 0) { reqs.push(dequeue("bison")); } + if (awsClaudeLockout === 0) { + reqs.push(dequeue("aws-claude")); + } reqs.filter(Boolean).forEach((req) => { if (req?.proceed) { diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index b50e82b..4776f7b 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -10,6 +10,7 @@ import { checkRisuToken } from "./check-risu-token"; import { openai } from "./openai"; import { anthropic } from "./anthropic"; import { googlePalm } from "./palm"; +import { aws } from "./aws"; const proxyRouter = express.Router(); proxyRouter.use( @@ -26,4 +27,14 @@ proxyRouter.use((req, _res, next) => { proxyRouter.use("/openai", openai); proxyRouter.use("/anthropic", anthropic); proxyRouter.use("/google-palm", googlePalm); +proxyRouter.use("/aws/claude", aws); +// Redirect browser requests to the homepage. +proxyRouter.get("*", (req, res, next) => { + const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); + if (isBrowser) { + res.redirect("/"); + } else { + next(); + } +}); export { proxyRouter as proxyRouter }; diff --git a/src/shared/key-management/anthropic/provider.ts b/src/shared/key-management/anthropic/provider.ts index 28650d0..bf1617b 100644 --- a/src/shared/key-management/anthropic/provider.ts +++ b/src/shared/key-management/anthropic/provider.ts @@ -177,10 +177,6 @@ export class AnthropicKeyProvider implements KeyProvider { return this.keys.filter((k) => !k.isDisabled).length; } - public anyUnchecked() { - return this.keys.some((k) => k.lastChecked === 0); - } - public incrementUsage(hash: string, _model: string, tokens: number) { const key = this.keys.find((k) => k.hash === hash); if (!key) return; @@ -202,10 +198,7 @@ export class AnthropicKeyProvider implements KeyProvider { // If all keys are rate-limited, return the time until the first key is // ready. - const timeUntilFirstReady = Math.min( - ...activeKeys.map((k) => k.rateLimitedUntil - now) - ); - return timeUntilFirstReady; + return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); } /** @@ -216,7 +209,7 @@ export class AnthropicKeyProvider implements KeyProvider { * retrying in order to give the other requests a chance to finish. 
*/ public markRateLimited(keyHash: string) { - this.log.warn({ key: keyHash }, "Key rate limited"); + this.log.debug({ key: keyHash }, "Key rate limited"); const key = this.keys.find((k) => k.hash === keyHash)!; const now = Date.now(); key.rateLimitedAt = now; diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts new file mode 100644 index 0000000..efc012e --- /dev/null +++ b/src/shared/key-management/aws/provider.ts @@ -0,0 +1,180 @@ +import crypto from "crypto"; +import { Key, KeyProvider } from ".."; +import { config } from "../../../config"; +import { logger } from "../../../logger"; +import type { AwsBedrockModelFamily } from "../../models"; + +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html +export const AWS_BEDROCK_SUPPORTED_MODELS = [ + "anthropic.claude-v1", + "anthropic.claude-v2", + "anthropic.claude-instant-v1", +] as const; +export type AwsBedrockModel = (typeof AWS_BEDROCK_SUPPORTED_MODELS)[number]; + +type AwsBedrockKeyUsage = { + [K in AwsBedrockModelFamily as `${K}Tokens`]: number; +}; + +export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage { + readonly service: "aws"; + readonly modelFamilies: AwsBedrockModelFamily[]; + /** The time at which this key was last rate limited. */ + rateLimitedAt: number; + /** The time until which this key is rate limited. */ + rateLimitedUntil: number; +} + +/** + * Upon being rate limited, a key will be locked out for this many milliseconds + * while we wait for other concurrent requests to finish. + */ +const RATE_LIMIT_LOCKOUT = 300; +/** + * Upon assigning a key, we will wait this many milliseconds before allowing it + * to be used again. This is to prevent the queue from flooding a key with too + * many requests while we wait to learn whether previous ones succeeded. + */ +const KEY_REUSE_DELAY = 500; + +export class AwsBedrockKeyProvider implements KeyProvider { + readonly service = "aws"; + + private keys: AwsBedrockKey[] = []; + private log = logger.child({ module: "key-provider", service: this.service }); + + constructor() { + const keyConfig = config.awsCredentials?.trim(); + if (!keyConfig) { + this.log.warn( + "AWS_CREDENTIALS is not set. AWS Bedrock API will not be available." + ); + return; + } + let bareKeys: string[]; + bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; + for (const key of bareKeys) { + const newKey: AwsBedrockKey = { + key, + service: this.service, + modelFamilies: ["aws-claude"], + isTrial: false, + isDisabled: false, + promptCount: 0, + lastUsed: 0, + rateLimitedAt: 0, + rateLimitedUntil: 0, + hash: `aws-${crypto + .createHash("sha256") + .update(key) + .digest("hex") + .slice(0, 8)}`, + lastChecked: 0, + ["aws-claudeTokens"]: 0, + }; + this.keys.push(newKey); + } + this.log.info({ keyCount: this.keys.length }, "Loaded AWS Bedrock keys."); + } + + public init() {} + + public list() { + return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); + } + + public get(_model: AwsBedrockModel) { + const availableKeys = this.keys.filter((k) => !k.isDisabled); + if (availableKeys.length === 0) { + throw new Error("No AWS Bedrock keys available"); + } + + // (largely copied from the OpenAI provider, without trial key support) + // Select a key, from highest priority to lowest priority: + // 1. Keys which are not rate limited + // a. If all keys were rate limited recently, select the least-recently + // rate limited key. + // 3. 
Keys which have not been used in the longest time + + const now = Date.now(); + + const keysByPriority = availableKeys.sort((a, b) => { + const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; + const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; + + if (aRateLimited && !bRateLimited) return 1; + if (!aRateLimited && bRateLimited) return -1; + if (aRateLimited && bRateLimited) { + return a.rateLimitedAt - b.rateLimitedAt; + } + + return a.lastUsed - b.lastUsed; + }); + + const selectedKey = keysByPriority[0]; + selectedKey.lastUsed = now; + selectedKey.rateLimitedAt = now; + // Intended to throttle the queue processor as otherwise it will just + // flood the API with requests and we want to wait a sec to see if we're + // going to get a rate limit error on this key. + selectedKey.rateLimitedUntil = now + KEY_REUSE_DELAY; + return { ...selectedKey }; + } + + public disable(key: AwsBedrockKey) { + const keyFromPool = this.keys.find((k) => k.hash === key.hash); + if (!keyFromPool || keyFromPool.isDisabled) return; + keyFromPool.isDisabled = true; + this.log.warn({ key: key.hash }, "Key disabled"); + } + + public update(hash: string, update: Partial) { + const keyFromPool = this.keys.find((k) => k.hash === hash)!; + Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); + } + + public available() { + return this.keys.filter((k) => !k.isDisabled).length; + } + + public incrementUsage(hash: string, _model: string, tokens: number) { + const key = this.keys.find((k) => k.hash === hash); + if (!key) return; + key.promptCount++; + key["aws-claudeTokens"] += tokens; + } + + public getLockoutPeriod(_model: AwsBedrockModel) { + // TODO: same exact behavior for three providers, should be refactored + const activeKeys = this.keys.filter((k) => !k.isDisabled); + // Don't lock out if there are no keys available or the queue will stall. + // Just let it through so the add-key middleware can throw an error. + if (activeKeys.length === 0) return 0; + + const now = Date.now(); + const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); + const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; + + if (anyNotRateLimited) return 0; + + // If all keys are rate-limited, return time until the first key is ready. + return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); + } + + /** + * This is called when we receive a 429, which means there are already five + * concurrent requests running on this key. We don't have any information on + * when these requests will resolve, so all we can do is wait a bit and try + * again. We will lock the key for 2 seconds after getting a 429 before + * retrying in order to give the other requests a chance to finish. 
+ */ + public markRateLimited(keyHash: string) { + this.log.debug({ key: keyHash }, "Key rate limited"); + const key = this.keys.find((k) => k.hash === keyHash)!; + const now = Date.now(); + key.rateLimitedAt = now; + key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT; + } + + public recheck() {} +} diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index d5bfbe0..de356d1 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -4,17 +4,25 @@ import { AnthropicModel, } from "./anthropic/provider"; import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider"; +import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider"; import { KeyPool } from "./key-pool"; import type { ModelFamily } from "../models"; +/** The request and response format used by a model's API. */ export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text"; -export type Model = OpenAIModel | AnthropicModel | GooglePalmModel; +/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */ +export type LLMService = "openai" | "anthropic" | "google-palm" | "aws"; +export type Model = + | OpenAIModel + | AnthropicModel + | GooglePalmModel + | AwsBedrockModel; export interface Key { /** The API key itself. Never log this, use `hash` instead. */ readonly key: string; /** The service that this key is for. */ - service: APIFormat; + service: LLMService; /** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */ isTrial: boolean; /** The model families that this key has access to. */ @@ -44,14 +52,13 @@ for service-agnostic functionality. 
*/ export interface KeyProvider { - readonly service: APIFormat; + readonly service: LLMService; init(): void; get(model: Model): T; list(): Omit[]; disable(key: T): void; update(hash: string, update: Partial): void; available(): number; - anyUnchecked(): boolean; incrementUsage(hash: string, model: string, tokens: number): void; getLockoutPeriod(model: Model): number; markRateLimited(hash: string): void; @@ -68,7 +75,9 @@ export { OPENAI_SUPPORTED_MODELS, ANTHROPIC_SUPPORTED_MODELS, GOOGLE_PALM_SUPPORTED_MODELS, + AWS_BEDROCK_SUPPORTED_MODELS, }; export { AnthropicKey } from "./anthropic/provider"; export { OpenAIKey } from "./openai/provider"; export { GooglePalmKey } from "./palm/provider"; +export { AwsBedrockKey } from "./aws/provider"; \ No newline at end of file diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index 11eeeb6..7c330d0 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -4,16 +4,17 @@ import os from "os"; import schedule from "node-schedule"; import { config } from "../../config"; import { logger } from "../../logger"; -import { Key, Model, KeyProvider, APIFormat } from "./index"; +import { Key, Model, KeyProvider, LLMService } from "./index"; import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; import { GooglePalmKeyProvider } from "./palm/provider"; +import { AwsBedrockKeyProvider } from "./aws/provider"; type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate; export class KeyPool { private keyProviders: KeyProvider[] = []; - private recheckJobs: Partial> = { + private recheckJobs: Partial> = { openai: null, }; @@ -21,6 +22,7 @@ export class KeyPool { this.keyProviders.push(new OpenAIKeyProvider()); this.keyProviders.push(new AnthropicKeyProvider()); this.keyProviders.push(new GooglePalmKeyProvider()); + this.keyProviders.push(new AwsBedrockKeyProvider()); } public init() { @@ -28,7 +30,7 @@ export class KeyPool { const availableKeys = this.available("all"); if (availableKeys === 0) { throw new Error( - "No keys loaded. Ensure OPENAI_KEY, ANTHROPIC_KEY, or GOOGLE_PALM_KEY are set." + "No keys loaded. Ensure that at least one key is configured." ); } this.scheduleRecheck(); @@ -43,6 +45,11 @@ export class KeyPool { return this.keyProviders.flatMap((provider) => provider.list()); } + /** + * Marks a key as disabled for a specific reason. `revoked` should be used + * to indicate a key that can never be used again, while `quota` should be + * used to indicate a key that is still valid but has exceeded its quota. + */ public disable(key: Key, reason: "quota" | "revoked"): void { const service = this.getKeyProvider(key.service); service.disable(key); @@ -59,17 +66,14 @@ export class KeyPool { service.update(key.hash, props); } - public available(service: APIFormat | "all" = "all"): number { + public available(model: Model | "all" = "all"): number { return this.keyProviders.reduce((sum, provider) => { - const includeProvider = service === "all" || service === provider.service; + const includeProvider = + model === "all" || this.getService(model) === provider.service; return sum + (includeProvider ? 
provider.available() : 0); }, 0); } - public anyUnchecked(): boolean { - return this.keyProviders.some((provider) => provider.anyUnchecked()); - } - public incrementUsage(key: Key, model: string, tokens: number): void { const provider = this.getKeyProvider(key.service); provider.incrementUsage(key.hash, model, tokens); @@ -92,7 +96,7 @@ export class KeyPool { } } - public recheck(service: APIFormat): void { + public recheck(service: LLMService): void { if (!config.checkKeys) { logger.info("Skipping key recheck because key checking is disabled"); return; @@ -102,7 +106,7 @@ export class KeyPool { provider.recheck(); } - private getService(model: Model): APIFormat { + private getService(model: Model): LLMService { if (model.startsWith("gpt") || model.startsWith("text-embedding-ada")) { // https://platform.openai.com/docs/models/model-endpoint-compatibility return "openai"; @@ -112,16 +116,15 @@ export class KeyPool { } else if (model.includes("bison")) { // https://developers.generativeai.google.com/models/language return "google-palm"; + } else if (model.startsWith("anthropic.claude")) { + // AWS offers models from a few providers + // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html + return "aws"; } throw new Error(`Unknown service for model '${model}'`); } - private getKeyProvider(service: APIFormat): KeyProvider { - // The "openai-text" service is a special case handled by OpenAIKeyProvider. - if (service === "openai-text") { - service = "openai"; - } - + private getKeyProvider(service: LLMService): KeyProvider { return this.keyProviders.find((provider) => provider.service === service)!; } diff --git a/src/shared/key-management/openai/checker.ts b/src/shared/key-management/openai/checker.ts index e58f234..f2f397d 100644 --- a/src/shared/key-management/openai/checker.ts +++ b/src/shared/key-management/openai/checker.ts @@ -33,10 +33,10 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update; export class OpenAIKeyChecker { private readonly keys: OpenAIKey[]; - private log = logger.child({ module: "key-checker", service: "openai" }); - private timeout?: NodeJS.Timeout; private cloneKey: CloneFn; private updateKey: UpdateFn; + private log = logger.child({ module: "key-checker", service: "openai" }); + private timeout?: NodeJS.Timeout; private lastCheck = 0; constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) { @@ -248,10 +248,10 @@ export class OpenAIKeyChecker { } else if (status === 429) { switch (data.error.type) { case "insufficient_quota": - case "access_terminated": case "billing_not_active": - const isOverQuota = data.error.type === "insufficient_quota"; - const isRevoked = !isOverQuota; + case "access_terminated": + const isRevoked = data.error.type === "access_terminated"; + const isOverQuota = !isRevoked; const modelFamilies: OpenAIModelFamily[] = isRevoked ? ["turbo"] : key.modelFamilies; @@ -392,10 +392,9 @@ export class OpenAIKeyChecker { } static getHeaders(key: OpenAIKey) { - const headers = { + return { Authorization: `Bearer ${key.key}`, ...(key.organizationId && { "OpenAI-Organization": key.organizationId }), }; - return headers; } } diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts index 9eae686..0df2a98 100644 --- a/src/shared/key-management/openai/provider.ts +++ b/src/shared/key-management/openai/provider.ts @@ -3,11 +3,11 @@ round-robin access to keys. Keys are stored in the OPENAI_KEY environment variable as a comma-separated list of keys. 
*/ import crypto from "crypto"; import http from "http"; -import { KeyProvider, Key, Model } from "../index"; +import { Key, KeyProvider, Model } from "../index"; import { config } from "../../../config"; import { logger } from "../../../logger"; import { OpenAIKeyChecker } from "./checker"; -import { OpenAIModelFamily, getOpenAIModelFamily } from "../../models"; +import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models"; export type OpenAIModel = | "gpt-3.5-turbo" @@ -276,10 +276,6 @@ export class OpenAIKeyProvider implements KeyProvider { return this.keys.filter((k) => !k.isDisabled).length; } - public anyUnchecked() { - return !!config.checkKeys && this.keys.some((key) => !key.lastChecked); - } - /** * Given a model, returns the period until a key will be available to service * the request, or returns 0 if a key is ready immediately. @@ -318,7 +314,7 @@ export class OpenAIKeyProvider implements KeyProvider { // If all keys are rate-limited, return the time until the first key is // ready. - const timeUntilFirstReady = Math.min( + return Math.min( ...activeKeys.map((key) => { const resetTime = Math.max( key.rateLimitRequestsReset, @@ -327,11 +323,10 @@ export class OpenAIKeyProvider implements KeyProvider { return key.rateLimitedAt + resetTime - now; }) ); - return timeUntilFirstReady; } public markRateLimited(keyHash: string) { - this.log.warn({ key: keyHash }, "Key rate limited"); + this.log.debug({ key: keyHash }, "Key rate limited"); const key = this.keys.find((k) => k.hash === keyHash)!; key.rateLimitedAt = Date.now(); } diff --git a/src/shared/key-management/palm/provider.ts b/src/shared/key-management/palm/provider.ts index b2a3a76..94627df 100644 --- a/src/shared/key-management/palm/provider.ts +++ b/src/shared/key-management/palm/provider.ts @@ -146,10 +146,6 @@ export class GooglePalmKeyProvider implements KeyProvider { return this.keys.filter((k) => !k.isDisabled).length; } - public anyUnchecked() { - return false; - } - public incrementUsage(hash: string, _model: string, tokens: number) { const key = this.keys.find((k) => k.hash === hash); if (!key) return; @@ -171,10 +167,7 @@ export class GooglePalmKeyProvider implements KeyProvider { // If all keys are rate-limited, return the time until the first key is // ready. - const timeUntilFirstReady = Math.min( - ...activeKeys.map((k) => k.rateLimitedUntil - now) - ); - return timeUntilFirstReady; + return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); } /** @@ -185,7 +178,7 @@ export class GooglePalmKeyProvider implements KeyProvider { * retrying in order to give the other requests a chance to finish. */ public markRateLimited(keyHash: string) { - this.log.warn({ key: keyHash }, "Key rate limited"); + this.log.debug({ key: keyHash }, "Key rate limited"); const key = this.keys.find((k) => k.hash === keyHash)!; const now = Date.now(); key.rateLimitedAt = now; diff --git a/src/shared/models.ts b/src/shared/models.ts index 1dc15d3..07af27b 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -3,14 +3,23 @@ import { logger } from "../logger"; export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k"; export type AnthropicModelFamily = "claude"; export type GooglePalmModelFamily = "bison"; +export type AwsBedrockModelFamily = "aws-claude"; export type ModelFamily = | OpenAIModelFamily | AnthropicModelFamily - | GooglePalmModelFamily; + | GooglePalmModelFamily + | AwsBedrockModelFamily; export const MODEL_FAMILIES = (( arr: A & ([ModelFamily] extends [A[number]] ? 
unknown : never) -) => arr)(["turbo", "gpt4", "gpt4-32k", "claude", "bison"] as const); +) => arr)([ + "turbo", + "gpt4", + "gpt4-32k", + "claude", + "bison", + "aws-claude", +] as const); export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = { "^gpt-4-32k-\\d{4}$": "gpt4-32k", @@ -41,6 +50,10 @@ export function getGooglePalmModelFamily(model: string): ModelFamily { return "bison"; } +export function getAwsBedrockModelFamily(_model: string): ModelFamily { + return "aws-claude"; +} + export function assertIsKnownModelFamily( modelFamily: string ): asserts modelFamily is ModelFamily { diff --git a/src/shared/tokenization/tokenizer.ts b/src/shared/tokenization/tokenizer.ts index 121ad43..6b5491c 100644 --- a/src/shared/tokenization/tokenizer.ts +++ b/src/shared/tokenization/tokenizer.ts @@ -1,5 +1,4 @@ import { Request } from "express"; -import { config } from "../../config"; import { assertNever } from "../utils"; import { init as initClaude, @@ -13,12 +12,8 @@ import { import { APIFormat } from "../key-management"; export async function init() { - if (config.anthropicKey) { - initClaude(); - } - if (config.openaiKey || config.googlePalmKey) { - initOpenAi(); - } + initClaude(); + initOpenAi(); } /** Tagged union via `service` field of the different types of requests that can diff --git a/src/shared/users/schema.ts b/src/shared/users/schema.ts index 1d6efe2..cefcae1 100644 --- a/src/shared/users/schema.ts +++ b/src/shared/users/schema.ts @@ -8,6 +8,7 @@ export const tokenCountsSchema: ZodType = z.object({ "gpt4-32k": z.number().optional().default(0), claude: z.number().optional().default(0), bison: z.number().optional().default(0), + "aws-claude": z.number().optional().default(0), }); export const UserSchema = z diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts index 25ba282..3af078b 100644 --- a/src/shared/users/user-store.ts +++ b/src/shared/users/user-store.ts @@ -11,7 +11,7 @@ import admin from "firebase-admin"; import schedule from "node-schedule"; import { v4 as uuid } from "uuid"; import { config, getFirebaseApp } from "../../config"; -import { ModelFamily } from "../models"; +import { MODEL_FAMILIES, ModelFamily } from "../models"; import { logger } from "../../logger"; import { User, UserTokenCounts, UserUpdate } from "./schema"; @@ -23,6 +23,7 @@ const INITIAL_TOKENS: Required = { "gpt4-32k": 0, claude: 0, bison: 0, + "aws-claude": 0, }; const users: Map = new Map(); @@ -131,12 +132,14 @@ export function upsertUser(user: UserUpdate) { // TODO: Write firebase migration to backfill new fields if (updates.tokenCounts) { - updates.tokenCounts["gpt4-32k"] ??= 0; - updates.tokenCounts["bison"] ??= 0; + for (const family of MODEL_FAMILIES) { + updates.tokenCounts[family] ??= 0; + } } if (updates.tokenLimits) { - updates.tokenLimits["gpt4-32k"] ??= 0; - updates.tokenLimits["bison"] ??= 0; + for (const family of MODEL_FAMILIES) { + updates.tokenLimits[family] ??= 0; + } } users.set(user.token, Object.assign(existing, updates)); @@ -360,9 +363,12 @@ function getModelFamilyForQuotaUsage(model: string): ModelFamily { if (model.includes("bison")) { return "bison"; } - if (model.includes("claude")) { + if (model.startsWith("claude")) { return "claude"; } + if(model.startsWith("anthropic.claude")) { + return "aws-claude"; + } throw new Error(`Unknown quota model family for model ${model}`); } diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts index 2c9f4ac..6326e9f 100644 --- a/src/types/custom.d.ts +++ b/src/types/custom.d.ts @@ -1,11 
+1,13 @@
+import type { HttpRequest } from "@smithy/types";
 import { Express } from "express-serve-static-core";
-import { APIFormat, Key } from "../shared/key-management/index";
+import { APIFormat, Key, LLMService } from "../shared/key-management";
 import { User } from "../shared/users/user-store";
 
 declare global {
   namespace Express {
     interface Request {
       key?: Key;
+      service?: LLMService;
       /** Denotes the format of the user's submitted request. */
       inboundApi: APIFormat;
       /** Denotes the format of the request being proxied to the API. */
@@ -24,6 +26,7 @@ declare global {
       outputTokens?: number;
       // TODO: remove later
       debug: Record<string, any>;
+      signedRequest: HttpRequest;
     }
   }
 }
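
Note on the new streaming path: the sketch below reduces the way handle-streamed-response.ts consumes the new adapter to a standalone helper. collectSseEvents is not part of this patch; it only assumes the ServerSentEventStreamAdapter added in src/proxy/middleware/response/sse-stream-adapter.ts.

import * as http from "http";
import { ServerSentEventStreamAdapter } from "./sse-stream-adapter";

// Illustrative only: pipe an upstream response through the adapter and collect
// each reassembled SSE event as a string. AWS binary event streams are
// detected from the content-type header, exactly as the response handler does.
function collectSseEvents(proxyRes: http.IncomingMessage): Promise<string[]> {
  const adapter = new ServerSentEventStreamAdapter({
    isAwsStream:
      proxyRes.headers["content-type"] === "application/vnd.amazon.eventstream",
  });
  const events: string[] = [];
  return new Promise((resolve, reject) => {
    proxyRes.pipe(adapter);
    adapter.on("data", (chunk: Buffer) => events.push(chunk.toString("utf8")));
    adapter.on("end", () => resolve(events));
    adapter.on("error", reject);
  });
}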
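
For operators wiring up AWS credentials, the hash that identifies each AWS key in the logs is derived as sketched below; awsKeyHash is a hypothetical helper that mirrors the AwsBedrockKeyProvider constructor rather than an export from this patch.

import crypto from "crypto";

// Hypothetical helper mirroring AwsBedrockKeyProvider: each configured
// credential is identified in logs by "aws-" plus the first 8 hex characters
// of its SHA-256 digest, never by the credential itself.
function awsKeyHash(credential: string): string {
  return `aws-${crypto
    .createHash("sha256")
    .update(credential)
    .digest("hex")
    .slice(0, 8)}`;
}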
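
Because Bedrock reports errors differently from the other upstreams, here is a rough sketch of the normalization handleUpstreamErrors now applies to AWS responses. normalizeAwsError is illustrative and not part of the patch; the shape of the x-amzn-errortype value is inferred from getAwsErrorType.

type ProxiedErrorPayload = {
  error?: Record<string, any>;
  message?: string;
  proxy_note?: string;
};

// Illustrative only. AWS puts the error text in a top-level "message" field
// and the exception name in the x-amzn-errortype header, which may carry a
// trailing ":<uri>" suffix. The middleware folds both into the same
// errorPayload.error shape used for the other services.
function normalizeAwsError(
  payload: ProxiedErrorPayload,
  errorTypeHeader: string | string[] | undefined
): ProxiedErrorPayload {
  const type =
    String(errorTypeHeader).match(/^(\w+):?/)?.[1] || String(errorTypeHeader);
  const normalized: ProxiedErrorPayload = {
    ...payload,
    error: { message: payload.message, type },
  };
  delete normalized.message;
  return normalized;
}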