fix(server): avoid errors for very small top_p values (#544)
See https://github.com/huggingface/transformers/pull/24111 I didn't add validation to the `__init__` method since it's not done for other values/warpers.
This commit is contained in:
parent
2a101207d4
commit
e4b26aa10b
|
@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
|||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = probs <= self.top_p_opposite
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
|
|
Loading…
Reference in New Issue