fix(server): avoid errors for very small top_p values (#544)
See https://github.com/huggingface/transformers/pull/24111 for the upstream fix. I didn't add validation to the `__init__` method, since that isn't done for the other values/warpers either.
This commit is contained in:
parent
2a101207d4
commit
e4b26aa10b
|
@ -189,7 +189,6 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||||
|
|
||||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||||
sorted_indices_to_remove = probs <= self.top_p_opposite
|
sorted_indices_to_remove = probs <= self.top_p_opposite
|
||||||
if self.min_tokens_to_keep > 1:
|
|
||||||
# Keep at least min_tokens_to_keep
|
# Keep at least min_tokens_to_keep
|
||||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue