fix: enforce default max request tokens in generate_internal

This commit is contained in:
David Holtz 2024-10-15 15:08:23 +00:00
parent ce7e356561
commit 595640e35c
2 changed files with 6 additions and 3 deletions

View File

@ -893,7 +893,6 @@ impl ChatRequest {
} = self; } = self;
let repetition_penalty = presence_penalty.map(|x| x + 2.0); let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt let tool_prompt = tool_prompt
.filter(|s| !s.is_empty()) .filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt); .unwrap_or_else(default_tool_prompt);
@ -926,7 +925,7 @@ impl ChatRequest {
top_p, top_p,
typical_p: None, typical_p: None,
do_sample, do_sample,
max_new_tokens, max_new_tokens: max_tokens,
return_full_text: None, return_full_text: None,
stop, stop,
truncate: None, truncate: None,

View File

@ -260,7 +260,7 @@ async fn generate(
pub(crate) async fn generate_internal( pub(crate) async fn generate_internal(
infer: Extension<Infer>, infer: Extension<Infer>,
ComputeType(compute_type): ComputeType, ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>, Json(mut req): Json<GenerateRequest>,
span: tracing::Span, span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now(); let start_time = Instant::now();
@ -278,6 +278,10 @@ pub(crate) async fn generate_internal(
add_prompt = Some(req.inputs.clone()); add_prompt = Some(req.inputs.clone());
} }
if req.parameters.max_new_tokens.is_none() {
req.parameters.max_new_tokens = Some(100);
}
let details: bool = req.parameters.details || req.parameters.decoder_input_details; let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Inference // Inference