feat: prefer stop over eos_token to align with openai finish_reason (#2344)

This commit is contained in:
drbh 2024-08-06 13:09:50 -04:00 committed by GitHub
parent e11f5f1c38
commit f8a5b381fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 3 deletions

View File

@ -619,7 +619,7 @@ impl ChatCompletion {
message, message,
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.format(true),
}], }],
usage: Usage { usage: Usage {
prompt_tokens: details.prefill.len() as u32, prompt_tokens: details.prefill.len() as u32,
@ -1117,6 +1117,15 @@ impl std::fmt::Display for FinishReason {
} }
} }
impl FinishReason {
    /// Renders this finish reason as an owned string.
    ///
    /// When `use_stop` is `true`, [`FinishReason::EndOfSequenceToken`] is reported as
    /// `"stop"`. In every other case (any other variant, or `use_stop == false`) the
    /// result is whatever the `Display` implementation produces via `to_string()`.
    pub fn format(&self, use_stop: bool) -> String {
        // Guard clause: only the end-of-sequence variant is ever remapped, and only
        // when the caller explicitly opts in.
        if use_stop && matches!(self, FinishReason::EndOfSequenceToken) {
            return "stop".to_string();
        }

        self.to_string()
    }
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence { pub(crate) struct BestOfSequence {
#[schema(example = "test")] #[schema(example = "test")]

View File

@ -1021,7 +1021,7 @@ async fn completions(
total_tokens += details.prefill.len() as u32 + details.generated_tokens; total_tokens += details.prefill.len() as u32 + details.generated_tokens;
Ok(CompletionComplete { Ok(CompletionComplete {
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.format(true),
index: index as u32, index: index as u32,
logprobs: None, logprobs: None,
text: generation.generated_text, text: generation.generated_text,
@ -1212,7 +1212,7 @@ async fn chat_completions(
tool_calls, tool_calls,
current_time, current_time,
logprobs, logprobs,
stream_token.details.map(|d| d.finish_reason.to_string()), stream_token.details.map(|d| d.finish_reason.format(true)),
), ),
)) ))
.unwrap_or_else(|e| { .unwrap_or_else(|e| {