From e4b26aa10bd43c93cd236a9e3388692eb1e8a321 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 4 Jul 2023 11:11:33 -0700 Subject: [PATCH] fix(server): avoid errors for very small top_p values (#544) See https://github.com/huggingface/transformers/pull/24111 I didn't add validation to the `__init__` method since it's not done for other values/warpers. --- server/text_generation_server/utils/logits_process.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 0cbbf8b0..f424eae4 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = probs <= self.top_p_opposite - if self.min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(