import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from ".";

const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
const BISON_MAX_CONTEXT = 8100;

/**
 * Validates the context size of the request using `req.promptTokens` and
 * `req.outputTokens`, which must already be assigned, and the outbound API
 * format, which together determine the applicable limit. If the context is
 * too large, an error is thrown.
 * This preprocessor should run after any preprocessor that transforms the
 * request body or assigns token counts.
 */
export const validateContextSize: RequestPreprocessor = async (req) => {
  assertRequestHasTokenCounts(req);
  const promptTokens = req.promptTokens;
  const outputTokens = req.outputTokens;
  const contextTokens = promptTokens + outputTokens;
  const model = req.body.model;

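  // Determine the proxy-level cap for the outbound API format.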
  let proxyMax: number;
  switch (req.outboundApi) {
    case "openai":
    case "openai-text":
      proxyMax = OPENAI_MAX_CONTEXT;
      break;
    case "anthropic":
      proxyMax = CLAUDE_MAX_CONTEXT;
      break;
    case "google-palm":
      proxyMax = BISON_MAX_CONTEXT;
      break;
    default:
      assertNever(req.outboundApi);
  }
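  // A zero or unset limit means the proxy does not enforce its own cap.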
  proxyMax ||= Number.MAX_SAFE_INTEGER;

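  // Infer the model's own context window from its name.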
  let modelMax = 0;
  if (model.match(/gpt-3.5-turbo-16k/)) {
    modelMax = 16384;
  } else if (model.match(/gpt-3.5-turbo/)) {
    modelMax = 4096;
  } else if (model.match(/gpt-4-32k/)) {
    modelMax = 32768;
  } else if (model.match(/gpt-4/)) {
    modelMax = 8192;
  } else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?(?:-100k)/)) {
    modelMax = 100000;
  } else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?$/)) {
    modelMax = 9000;
  } else if (model.match(/claude-2/)) {
    modelMax = 100000;
  } else if (model.match(/^text-bison-\d{3}$/)) {
    modelMax = BISON_MAX_CONTEXT;
  } else {
    // Don't really want to throw here because I don't want to have to update
    // this ASAP every time a new model is released.
    req.log.warn({ model }, "Unknown model, using 100k token limit.");
    modelMax = 100000;
  }

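  // The stricter of the proxy and model limits applies; zod throws if the
  // combined token count exceeds it.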
  const finalMax = Math.min(proxyMax, modelMax);
  z.number()
    .int()
    .max(finalMax, {
      message: `Your request exceeds the context size limit for this model or proxy. (max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`,
    })
    .parse(contextTokens);

  req.log.debug(
    { promptTokens, outputTokens, contextTokens, modelMax, proxyMax },
    "Prompt size validated"
  );

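  // Record the validated counts and limits on the request's debug object.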
  req.debug.prompt_tokens = promptTokens;
  req.debug.completion_tokens = outputTokens;
  req.debug.max_model_tokens = modelMax;
  req.debug.max_proxy_tokens = proxyMax;
};

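// Narrows the request type once zod confirms both token counts are positive
// integers; throws if an upstream preprocessor failed to assign them.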
function assertRequestHasTokenCounts(
  req: Request
): asserts req is Request & { promptTokens: number; outputTokens: number } {
  z.object({
    promptTokens: z.number().int().min(1),
    outputTokens: z.number().int().min(1),
  })
    .nonstrict()
    .parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens });
}