diff --git a/src/proxy/middleware/request/onproxyreq/add-key.ts b/src/proxy/middleware/request/onproxyreq/add-key.ts index 03c2385..4a84aed 100644 --- a/src/proxy/middleware/request/onproxyreq/add-key.ts +++ b/src/proxy/middleware/request/onproxyreq/add-key.ts @@ -23,16 +23,16 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => { } if (req.inboundApi === req.outboundApi) { - assignedKey = keyPool.get(req.body.model); + assignedKey = keyPool.get(req.body.model, req.service); } else { switch (req.outboundApi) { // If we are translating between API formats we may need to select a model // for the user, because the provided model is for the inbound API. case "anthropic": - assignedKey = keyPool.get("claude-v1"); + assignedKey = keyPool.get("claude-v1", req.service); break; case "openai-text": - assignedKey = keyPool.get("gpt-3.5-turbo-instruct"); + assignedKey = keyPool.get("gpt-3.5-turbo-instruct", req.service); break; case "openai": throw new Error( @@ -43,7 +43,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => { case "mistral-ai": throw new Error("Mistral AI should never be translated"); case "openai-image": - assignedKey = keyPool.get("dall-e-3"); + assignedKey = keyPool.get("dall-e-3", req.service); break; default: assertNever(req.outboundApi); @@ -106,7 +106,7 @@ export const addKeyForEmbeddingsRequest: HPMRequestCallback = ( req.body = { input: req.body.input, model: "text-embedding-ada-002" }; - const key = keyPool.get("text-embedding-ada-002") as OpenAIKey; + const key = keyPool.get("text-embedding-ada-002", "openai") as OpenAIKey; req.key = key; req.log.info( diff --git a/src/proxy/middleware/request/preprocessors/add-azure-key.ts b/src/proxy/middleware/request/preprocessors/add-azure-key.ts index b742c5a..3656441 100644 --- a/src/proxy/middleware/request/preprocessors/add-azure-key.ts +++ b/src/proxy/middleware/request/preprocessors/add-azure-key.ts @@ -16,7 +16,7 @@ export const addAzureKey: RequestPreprocessor = (req) => { ? 
req.body.model : `azure-${req.body.model}`; - req.key = keyPool.get(model); + req.key = keyPool.get(model, "azure"); req.body.model = model; // Handles the sole Azure API deviation from the OpenAI spec (that I know of) diff --git a/src/proxy/middleware/request/preprocessors/add-google-ai-key.ts b/src/proxy/middleware/request/preprocessors/add-google-ai-key.ts index 439d55f..56807fa 100644 --- a/src/proxy/middleware/request/preprocessors/add-google-ai-key.ts +++ b/src/proxy/middleware/request/preprocessors/add-google-ai-key.ts @@ -13,7 +13,7 @@ export const addGoogleAIKey: RequestPreprocessor = (req) => { } const model = req.body.model; - req.key = keyPool.get(model); + req.key = keyPool.get(model, "google-ai"); req.log.info( { key: req.key.hash, model }, diff --git a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts index b4a04b4..781c639 100644 --- a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts +++ b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts @@ -14,7 +14,7 @@ const AMZ_HOST = * request object in place to fix the path. 
*/ export const signAwsRequest: RequestPreprocessor = async (req) => { - req.key = keyPool.get("anthropic.claude-v2"); + req.key = keyPool.get("anthropic.claude-v2", "aws"); const { model, stream } = req.body; req.isStreaming = stream === true || stream === "true"; diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index dc299c1..db0b566 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -41,9 +41,9 @@ export class KeyPool { this.scheduleRecheck(); } - public get(model: Model): Key { - const service = this.getServiceForModel(model); - return this.getKeyProvider(service).get(model); + public get(model: Model, service?: LLMService): Key { + const queryService = service || this.getServiceForModel(model); + return this.getKeyProvider(queryService).get(model); } public list(): Omit<Key, "key">[] {