re-signs AWS requests on every dequeue attempt to fix broken queueing

This commit is contained in:
nai-degen 2023-10-24 13:06:10 -05:00
parent 26dc79c8f1
commit 89e1ed46d5
12 changed files with 94 additions and 47 deletions

View File

@ -13,7 +13,8 @@ import {
createPreprocessorMiddleware,
finalizeBody,
languageFilter,
stripHeaders, createOnProxyReqHandler
stripHeaders,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
@ -129,8 +130,8 @@ function transformAnthropicResponse(
};
}
const anthropicProxy = createQueueMiddleware(
createProxyMiddleware({
const anthropicProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.anthropic.com",
changeOrigin: true,
selfHandleResponse: true,
@ -154,8 +155,8 @@ const anthropicProxy = createQueueMiddleware(
// Send OpenAI-compat requests to the real Anthropic endpoint.
"^/v1/chat/completions": "/v1/complete",
},
})
);
}),
});
const anthropicRouter = Router();
anthropicRouter.get("/v1/models", handleModelRequest);

View File

@ -3,7 +3,6 @@ import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { keyPool } from "../shared/key-management";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
@ -120,13 +119,12 @@ function transformAwsResponse(
};
}
const awsProxy = createQueueMiddleware(
createProxyMiddleware({
const awsProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) {
throw new Error("AWS requests must go through signAwsRequest first");
}
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
@ -135,9 +133,7 @@ const awsProxy = createQueueMiddleware(
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [
(_, req) => keyPool.throttle(req.key!),
applyQuotaLimits,
// Credentials are added by signAwsRequest preprocessor
languageFilter,
blockZoomerOrigins,
stripHeaders,
@ -147,8 +143,8 @@ const awsProxy = createQueueMiddleware(
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
})
);
}),
});
const awsRouter = Router();
awsRouter.get("/v1/models", handleModelRequest);
@ -158,7 +154,7 @@ awsRouter.post(
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic", outApi: "anthropic", service: "aws" },
{ afterTransform: [maybeReassignModel, signAwsRequest] }
{ afterTransform: [maybeReassignModel] }
),
awsProxy
);
@ -168,7 +164,7 @@ awsRouter.post(
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic", service: "aws" },
{ afterTransform: [maybeReassignModel, signAwsRequest] }
{ afterTransform: [maybeReassignModel] }
),
awsProxy
);

View File

@ -59,7 +59,6 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
}
}
keyPool.throttle(assignedKey);
req.key = assignedKey;
req.log.info(
{
@ -117,7 +116,7 @@ export const addKeyForEmbeddingsRequest: ProxyRequestMiddleware = (
throw new Error("Embeddings requests must be from OpenAI");
}
req.body = { input: req.body.input, model: "text-embedding-ada-002" }
req.body = { input: req.body.input, model: "text-embedding-ada-002" };
const key = keyPool.get("text-embedding-ada-002") as OpenAIKey;

View File

@ -25,6 +25,7 @@ import {
limitCompletions,
stripHeaders,
createOnProxyReqHandler,
signAwsRequest,
} from "./middleware/request";
import {
createOnProxyResHandler,
@ -163,8 +164,8 @@ function transformTurboInstructResponse(
return transformed;
}
const openaiProxy = createQueueMiddleware(
createProxyMiddleware({
const openaiProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.openai.com",
changeOrigin: true,
selfHandleResponse: true,
@ -184,8 +185,8 @@ const openaiProxy = createQueueMiddleware(
proxyRes: createOnProxyResHandler([openaiResponseHandler]),
error: handleProxyError,
},
})
);
}),
});
const openaiEmbeddingsProxy = createProxyMiddleware({
target: "https://api.openai.com",

View File

@ -143,8 +143,8 @@ function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
);
}
const googlePalmProxy = createQueueMiddleware(
createProxyMiddleware({
const googlePalmProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://generativelanguage.googleapis.com",
changeOrigin: true,
selfHandleResponse: true,
@ -164,8 +164,8 @@ const googlePalmProxy = createQueueMiddleware(
proxyRes: createOnProxyResHandler([palmResponseHandler]),
error: handleProxyError,
},
})
);
}),
});
const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);

View File

@ -23,6 +23,7 @@ import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import { assertNever } from "../shared/utils";
import { logger } from "../logger";
import { SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@ -52,7 +53,7 @@ function getIdentifier(req: Request) {
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
getIdentifier(queued) === getIdentifier(incoming);
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip)
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
export function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
@ -325,9 +326,23 @@ export function getQueueLength(partition: ModelFamily | "all" = "all") {
return modelQueue.length;
}
export function createQueueMiddleware(proxyMiddleware: Handler): Handler {
export function createQueueMiddleware({
beforeProxy,
proxyMiddleware,
}: {
beforeProxy?: RequestPreprocessor;
proxyMiddleware: Handler;
}): Handler {
return (req, res, next) => {
req.proceed = () => {
req.proceed = async () => {
if (beforeProxy) {
// Hack to let us run asynchronous middleware before the
// http-proxy-middleware handler. This is used to sign AWS requests
// before they are proxied, as the signing is asynchronous.
// Unlike RequestPreprocessors, this runs every time the request is
// dequeued, not just the first time.
await beforeProxy(req);
}
proxyMiddleware(req, res, next);
};
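
For reference, the new option only changes when the preprocessor runs, not how one is written; any async RequestPreprocessor can be passed as beforeProxy and it will execute on every dequeue. A minimal sketch of the pattern (the refreshUpstreamAuth hook and its target URL are illustrative assumptions, not code from this commit):

import { createProxyMiddleware } from "http-proxy-middleware";
import { createQueueMiddleware } from "./queue";
import type { RequestPreprocessor } from "./middleware/request";

// Hypothetical hook: recompute anything time-sensitive (signatures, tokens)
// each time the request is dequeued, so a retried request never reuses a
// stale value from an earlier attempt.
const refreshUpstreamAuth: RequestPreprocessor = async (req) => {
  // Illustrative only: pretend we mint a short-lived credential per attempt.
  const token = await Promise.resolve(`token-${Date.now()}`);
  req.headers["authorization"] = `Bearer ${token}`;
};

const exampleProxy = createQueueMiddleware({
  beforeProxy: refreshUpstreamAuth,
  proxyMiddleware: createProxyMiddleware({
    target: "https://upstream.example",
    changeOrigin: true,
  }),
});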

View File

@ -153,6 +153,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@ -222,10 +223,19 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
this.checker?.scheduleNextCheck();
}
public throttle(hash: string) {
const key = this.keys.find((k) => k.hash === hash)!;
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = now + KEY_REUSE_DELAY;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}
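
The Math.max guard means the reuse delay can only extend a key's lockout, never shorten one that is already in effect. A tiny standalone sketch of the same arithmetic (the KEY_REUSE_DELAY value is assumed for illustration):

const KEY_REUSE_DELAY = 1000; // assumed value, for illustration only

// If the key is already rate limited until later than now + KEY_REUSE_DELAY,
// the existing window wins; otherwise the short reuse delay applies.
function nextRateLimitedUntil(currentUntil: number, now: number): number {
  return Math.max(currentUntil, now + KEY_REUSE_DELAY);
}

// e.g. a key locked out for another 5s stays locked out for 5s, not 1s:
// nextRateLimitedUntil(now + 5000, now) === now + 5000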

View File

@ -131,6 +131,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@ -195,10 +196,19 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
);
}
public throttle(hash: string) {
const key = this.keys.find((k) => k.hash === hash)!;
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = now + KEY_REUSE_DELAY;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}

View File

@ -63,7 +63,6 @@ export interface KeyProvider<T extends Key = Key> {
getLockoutPeriod(model: Model): number;
markRateLimited(hash: string): void;
recheck(): void;
throttle(hash: string): void;
}
export const keyPool = new KeyPool();

View File

@ -72,11 +72,6 @@ export class KeyPool {
}, 0);
}
public throttle(key: Key) {
const provider = this.getKeyProvider(key.service);
provider.throttle(key.hash);
}
public incrementUsage(key: Key, model: string, tokens: number): void {
const provider = this.getKeyProvider(key.service);
provider.incrementUsage(key.hash, model, tokens);

View File

@ -221,6 +221,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@ -228,7 +229,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
public update(keyHash: string, update: OpenAIKeyUpdate) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
// this.writeKeyStatus();
}
/** Called by the key checker to create clones of keys for the given orgs. */
@ -379,8 +379,19 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
* avoid spamming the API with requests while we wait to learn whether this
* key is already rate limited.
*/
public throttle(hash: string) {
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = Math.max(
key.rateLimitRequestsReset,
key.rateLimitTokensReset
) + key.rateLimitedAt;
const nextRateLimit = now + KEY_REUSE_DELAY;
// Don't throttle if the key is already naturally rate limited.
if (currentRateLimit > nextRateLimit) return;
key.rateLimitedAt = Date.now();
key.rateLimitRequestsReset = KEY_REUSE_DELAY;
}
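
The OpenAI variant differs from the others: it skips the artificial delay whenever the key's own rate-limit headers already imply a later reset. A hedged sketch of that decision in isolation (field names mirror the diff; the numbers and the assumed KEY_REUSE_DELAY are made up for illustration):

const KEY_REUSE_DELAY = 1000; // assumed value, for illustration only

// Returns true when the key's natural rate limit already extends past the
// reuse delay, in which case no extra throttle needs to be applied.
function alreadyRateLimited(
  rateLimitedAt: number,
  rateLimitRequestsReset: number,
  rateLimitTokensReset: number,
  now: number
): boolean {
  const naturalExpiry =
    rateLimitedAt + Math.max(rateLimitRequestsReset, rateLimitTokensReset);
  return naturalExpiry > now + KEY_REUSE_DELAY;
}

// e.g. a key that hit its limit 2s ago with a 10s reset still has ~8s left,
// so the 1s reuse delay would be a no-op and is skipped.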

View File

@ -122,6 +122,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@ -182,10 +183,19 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
public recheck() {}
public throttle(hash: string) {
const key = this.keys.find((k) => k.hash === hash)!;
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = now + KEY_REUSE_DELAY;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}