fix: enforce default max request tokens in generate_internal instead of ChatRequest

This commit is contained in:
David Holtz 2024-10-15 15:08:23 +00:00
parent ce7e356561
commit 595640e35c
2 changed files with 6 additions and 3 deletions

View File

@ -893,7 +893,6 @@ impl ChatRequest {
} = self;
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
@ -926,7 +925,7 @@ impl ChatRequest {
top_p,
typical_p: None,
do_sample,
max_new_tokens,
max_new_tokens: max_tokens,
return_full_text: None,
stop,
truncate: None,

View File

@ -260,7 +260,7 @@ async fn generate(
pub(crate) async fn generate_internal(
infer: Extension<Infer>,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
Json(mut req): Json<GenerateRequest>,
span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
@ -278,6 +278,10 @@ pub(crate) async fn generate_internal(
add_prompt = Some(req.inputs.clone());
}
if req.parameters.max_new_tokens.is_none() {
req.parameters.max_new_tokens = Some(100);
}
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Inference