fix: improve completions to send a final chunk with usage details (#2336)

* fix: improve completions to send a final chunk with usage details

* fix: include finish reason string

* fix: remove dev debug trait and unneeded mut

* fix: update openapi schema
drbh 2024-08-12 11:26:11 -04:00 committed by GitHub
parent 4c3f8a70a1
commit 30395b09f4
3 changed files with 44 additions and 9 deletions
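
The change in brief: every event in a streamed completions response used to be serialized as a bare Completion::Chunk with an empty finish_reason and no token accounting. After this commit, the event built for the last token (the one that carries StreamDetails) becomes a Completion::Final with the real finish reason plus a usage block. A rough, self-contained sketch of the two shapes (assumed stand-ins, not the router's exact definitions):

// Assumed stand-ins for the router's types; `Usage` follows the
// OpenAI-style token accounting seen in the diffs below.
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

enum Completion {
    // Intermediate event: a text delta, no usage.
    Chunk { text: String },
    // Last event: real finish reason plus usage totals.
    Final { finish_reason: String, usage: Usage },
}

fn main() {
    let last = Completion::Final {
        finish_reason: "length".to_string(),
        usage: Usage { prompt_tokens: 2, completion_tokens: 5, total_tokens: 7 },
    };
    // The invariant the handler maintains: total = prompt + completion.
    if let Completion::Final { usage, .. } = last {
        assert_eq!(usage.total_tokens, usage.prompt_tokens + usage.completion_tokens);
    }
}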

@@ -1824,7 +1824,8 @@
         "type": "object",
         "required": [
           "finish_reason",
-          "generated_tokens"
+          "generated_tokens",
+          "input_length"
         ],
         "properties": {
           "finish_reason": {
@@ -1836,6 +1837,12 @@
             "example": 1,
             "minimum": 0
           },
+          "input_length": {
+            "type": "integer",
+            "format": "int32",
+            "example": 1,
+            "minimum": 0
+          },
           "seed": {
             "type": "integer",
             "format": "int64",

@@ -1219,6 +1219,8 @@ pub(crate) struct StreamDetails {
     pub generated_tokens: u32,
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
+    #[schema(example = 1)]
+    pub input_length: u32,
 }
 
 #[derive(Serialize, ToSchema)]
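
Adding input_length to StreamDetails is what exposes the prompt token count to the stream handlers below. A hedged sketch of how such a struct serializes (serde derive assumed; the real struct also derives utoipa's ToSchema, which is where the openapi.json entry above comes from):

use serde::Serialize;

// Assumed stand-in: the router's finish_reason is an enum, simplified
// to a String here; requires the serde and serde_json crates.
#[derive(Serialize)]
struct StreamDetails {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
    input_length: u32,
}

fn main() {
    let details = StreamDetails {
        finish_reason: "length".to_string(),
        generated_tokens: 5,
        seed: Some(42),
        input_length: 2,
    };
    // Prints:
    // {"finish_reason":"length","generated_tokens":5,"seed":42,"input_length":2}
    println!("{}", serde_json::to_string(&details).unwrap());
}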

@@ -533,7 +533,7 @@ async fn generate_stream_internal(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, _input_length, response_stream)) => {
+            Ok((_permit, input_length, response_stream)) => {
                 let mut index = 0;
                 let mut response_stream = Box::pin(response_stream);
                 // Server-Sent Event stream
@@ -576,6 +576,7 @@
                     finish_reason: generated_text.finish_reason,
                     generated_tokens: generated_text.generated_tokens,
                     seed: generated_text.seed,
+                    input_length,
                 }),
                 false => None,
             };
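
The underscore rename above is the entire plumbing change: generate_stream already returned the validated prompt length in its Ok tuple, but the stream path discarded it. A minimal sketch of the pattern (hypothetical stub, not the router's real signature):

// Hypothetical stub mirroring the Ok((permit, input_length, stream))
// tuple shape; the real permit and stream types belong to the router.
fn generate_stream_stub() -> ((), u32, Vec<&'static str>) {
    ((), 2, vec!["Hello", " world"])
}

fn main() {
    // Before: `_input_length` marked the value intentionally unused.
    // Dropping the underscore binds it so StreamDetails can carry it.
    let (_permit, input_length, _stream) = generate_stream_stub();
    println!("prompt tokens: {input_length}");
}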
@@ -801,21 +802,46 @@ async fn completions(
                         .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                         .as_secs();
 
-                    event
-                        .json_data(Completion::Chunk(Chunk {
-                            id: "".to_string(),
-                            created: current_time,
+                    let message = match stream_token.details {
+                        Some(details) => {
+                            let completion_tokens = details.generated_tokens;
+                            let prompt_tokens = details.input_length;
+                            let total_tokens = prompt_tokens + completion_tokens;
+
+                            Completion::Final(CompletionFinal {
+                                id: String::new(),
+                                created: current_time,
                                 model: model_id.clone(),
                                 system_fingerprint: system_fingerprint.clone(),
                                 choices: vec![CompletionComplete {
-                                    finish_reason: "".to_string(),
+                                    finish_reason: details.finish_reason.to_string(),
                                     index: index as u32,
                                     logprobs: None,
                                     text: stream_token.token.text,
                                 }],
+                                usage: Usage {
+                                    prompt_tokens,
+                                    completion_tokens,
+                                    total_tokens,
+                                },
+                            })
+                        }
+                        None => Completion::Chunk(Chunk {
+                            id: String::new(),
+                            created: current_time,
+                            choices: vec![CompletionComplete {
+                                finish_reason: String::new(),
+                                index: index as u32,
+                                logprobs: None,
+                                text: stream_token.token.text,
+                            }],
+                            model: model_id.clone(),
+                            system_fingerprint: system_fingerprint.clone(),
-                        }))
+                        }),
+                    };
+
+                    event
+                        .json_data(message)
                         .unwrap_or_else(|_e| Event::default())
                 };
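
For clients, the upshot is that the final data: event of a streamed completion now carries usage and a real finish_reason, while intermediate events are unchanged. A hypothetical consumer-side check (type names assumed; only the usage arithmetic is defined by this commit):

use serde::Deserialize;

// Assumed client-side mirror of the chunk payload; serde ignores the
// other fields (id, created, choices, ...) and leaves usage as None on
// intermediate chunks.
#[derive(Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

#[derive(Deserialize)]
struct StreamEvent {
    usage: Option<Usage>,
}

fn main() {
    let final_event =
        r#"{"id":"","created":0,"usage":{"prompt_tokens":2,"completion_tokens":5,"total_tokens":7}}"#;
    let event: StreamEvent = serde_json::from_str(final_event).unwrap();
    let usage = event.usage.expect("final chunk carries usage");
    assert_eq!(usage.total_tokens, usage.prompt_tokens + usage.completion_tokens);
}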