diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 0cbbf8b0..f424eae4 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = probs <= self.top_p_opposite - if self.min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(