diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cec9ae55..696f0fb2 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -579,7 +579,7 @@ class CausalLM(Model):
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
             batch.top_n_tokens_tensor,
-            torch.softmax(logits[:, -1], -1),
+            torch.log_softmax(logits[:, -1], -1),
         )
 
         # Zipped iterator
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 1a7911ac..34932c0b 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -642,7 +642,7 @@ class Seq2SeqLM(Model):
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
-           torch.softmax(logits[:, -1], -1),
+           torch.log_softmax(logits[:, -1], -1),
        )
 
        # Finished requests
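
The diff swaps `torch.softmax` for `torch.log_softmax` so that the values returned as `batch_top_token_logprobs` are actual log-probabilities rather than raw probabilities. A minimal sketch of the effect, assuming top-n selection reduces to a `topk` over the tensor passed in (this is an illustration, not the actual `batch_top_tokens` implementation):

```python
import torch

# Toy last-position logits for a batch of 2 sequences over a vocab of 5 tokens.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0],
                       [0.1, 3.0, -0.5, 0.2, 1.0]])

# Before the change: topk over softmax yields probabilities in [0, 1].
probs = torch.softmax(logits, dim=-1)
top_probs, top_ids = probs.topk(2, dim=-1)

# After the change: topk over log_softmax yields log-probabilities (<= 0),
# which is what a field named batch_top_token_logprobs should hold.
logprobs = torch.log_softmax(logits, dim=-1)
top_logprobs, top_ids_after = logprobs.topk(2, dim=-1)

# log is monotonic, so the same token ids are selected either way; only the
# returned values change from probabilities to log-probabilities.
assert torch.equal(top_ids, top_ids_after)
assert torch.allclose(top_logprobs, top_probs.log())
```

Note that `torch.log_softmax` is also numerically safer than composing `torch.log(torch.softmax(...))`, since it avoids underflow for very small probabilities.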