fix: improve completions to send a final chunk with usage details (#2336)

* fix: improve completions to send a final chunk with usage details

* fix: include finish reason string

* fix: remove dev debug trait and unneeded mut

* fix: update openapi schema

This commit is contained in:
drbh 2024-08-12 11:26:11 -04:00 committed by GitHub
parent 4c3f8a70a1
commit 30395b09f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 9 deletions

View File

@ -1824,7 +1824,8 @@
"type": "object", "type": "object",
"required": [ "required": [
"finish_reason", "finish_reason",
"generated_tokens" "generated_tokens",
"input_length"
], ],
"properties": { "properties": {
"finish_reason": { "finish_reason": {
@ -1836,6 +1837,12 @@
"example": 1, "example": 1,
"minimum": 0 "minimum": 0
}, },
"input_length": {
"type": "integer",
"format": "int32",
"example": 1,
"minimum": 0
},
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",

View File

@ -1219,6 +1219,8 @@ pub(crate) struct StreamDetails {
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(nullable = true, example = 42)] #[schema(nullable = true, example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
#[schema(example = 1)]
pub input_length: u32,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]

View File

@ -533,7 +533,7 @@ async fn generate_stream_internal(
} else { } else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives // Keep permit as long as generate_stream lives
Ok((_permit, _input_length, response_stream)) => { Ok((_permit, input_length, response_stream)) => {
let mut index = 0; let mut index = 0;
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
// Server-Sent Event stream // Server-Sent Event stream
@ -576,6 +576,7 @@ async fn generate_stream_internal(
finish_reason: generated_text.finish_reason, finish_reason: generated_text.finish_reason,
generated_tokens: generated_text.generated_tokens, generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed, seed: generated_text.seed,
input_length,
}), }),
false => None, false => None,
}; };
@ -801,21 +802,46 @@ async fn completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
event let message = match stream_token.details {
.json_data(Completion::Chunk(Chunk { Some(details) => {
id: "".to_string(), let completion_tokens = details.generated_tokens;
created: current_time, let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Completion::Final(CompletionFinal {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens,
},
})
}
None => Completion::Chunk(Chunk {
id: String::new(),
created: current_time,
choices: vec![CompletionComplete { choices: vec![CompletionComplete {
finish_reason: "".to_string(), finish_reason: String::new(),
index: index as u32, index: index as u32,
logprobs: None, logprobs: None,
text: stream_token.token.text, text: stream_token.token.text,
}], }],
model: model_id.clone(), model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(), system_fingerprint: system_fingerprint.clone(),
})) }),
};
event
.json_data(message)
.unwrap_or_else(|_e| Event::default()) .unwrap_or_else(|_e| Event::default())
}; };