From 595640e35cb239a505792a0e9586559750f89633 Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Tue, 15 Oct 2024 15:08:23 +0000
Subject: [PATCH] fix: enforce default max request tokens in generate_internal

---
 router/src/lib.rs    | 3 +--
 router/src/server.rs | 6 +++++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index b29c9395..c25ab10b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -893,7 +893,6 @@ impl ChatRequest {
         } = self;
 
         let repetition_penalty = presence_penalty.map(|x| x + 2.0);
-        let max_new_tokens = max_tokens.or(Some(100));
         let tool_prompt = tool_prompt
             .filter(|s| !s.is_empty())
             .unwrap_or_else(default_tool_prompt);
@@ -926,7 +925,7 @@ impl ChatRequest {
                 top_p,
                 typical_p: None,
                 do_sample,
-                max_new_tokens,
+                max_new_tokens: max_tokens,
                 return_full_text: None,
                 stop,
                 truncate: None,
diff --git a/router/src/server.rs b/router/src/server.rs
index 5e6e6960..e82716fc 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -260,7 +260,7 @@ async fn generate(
 pub(crate) async fn generate_internal(
     infer: Extension<Infer>,
     ComputeType(compute_type): ComputeType,
-    Json(req): Json<GenerateRequest>,
+    Json(mut req): Json<GenerateRequest>,
     span: tracing::Span,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
@@ -278,6 +278,10 @@ pub(crate) async fn generate_internal(
         add_prompt = Some(req.inputs.clone());
     }
 
+    if req.parameters.max_new_tokens.is_none() {
+        req.parameters.max_new_tokens = Some(100);
+    }
+
     let details: bool = req.parameters.details || req.parameters.decoder_input_details;
 
     // Inference