fix(server): avoid errors for very small top_p values (#544)

See https://github.com/huggingface/transformers/pull/24111

I didn't add validation to the `__init__` method since it's not done for
other values/warpers.
This commit is contained in:
Nick Hill 2023-07-04 11:11:33 -07:00 committed by GitHub
parent 2a101207d4
commit e4b26aa10b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 3 deletions

View File

@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(