From 30395b09f4eff271dd1dfdc49be4fd46f4a546dd Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 11:26:11 -0400 Subject: [PATCH] fix: improve completions to send a final chunk with usage details (#2336) * fix: improve completions to send a final chunk with usage details * fix: include finish reason string * fix: remove dev debug trait and unneeded mut * fix: update openapi schema --- docs/openapi.json | 9 ++++++++- router/src/lib.rs | 2 ++ router/src/server.rs | 42 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index ecd56e4d..df21e19d 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1824,7 +1824,8 @@ "type": "object", "required": [ "finish_reason", - "generated_tokens" + "generated_tokens", + "input_length" ], "properties": { "finish_reason": { @@ -1836,6 +1837,12 @@ "example": 1, "minimum": 0 }, + "input_length": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, "seed": { "type": "integer", "format": "int64", diff --git a/router/src/lib.rs b/router/src/lib.rs index 0a15c495..d7eb4475 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1219,6 +1219,8 @@ pub(crate) struct StreamDetails { pub generated_tokens: u32, #[schema(nullable = true, example = 42)] pub seed: Option, + #[schema(example = 1)] + pub input_length: u32, } #[derive(Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 99ec077f..ab268efa 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -533,7 +533,7 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, response_stream)) => { + Ok((_permit, input_length, response_stream)) => { let mut index = 0; let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream @@ -576,6 +576,7 @@ async fn generate_stream_internal( finish_reason: generated_text.finish_reason, generated_tokens: generated_text.generated_tokens, seed: generated_text.seed, + input_length, }), false => None, }; @@ -801,21 +802,46 @@ async fn completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - event - .json_data(Completion::Chunk(Chunk { - id: "".to_string(), - created: current_time, + let message = match stream_token.details { + Some(details) => { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Completion::Final(CompletionFinal { + id: String::new(), + created: current_time, + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + choices: vec![CompletionComplete { + finish_reason: details.finish_reason.to_string(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens, + }, + }) + } + None => Completion::Chunk(Chunk { + id: String::new(), + created: current_time, choices: vec![CompletionComplete { - finish_reason: "".to_string(), + finish_reason: String::new(), index: index as u32, logprobs: None, text: stream_token.token.text, }], - model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - })) + }), + }; + + event + .json_data(message) .unwrap_or_else(|_e| Event::default()) };