fix(server): avoid errors for very small top_p values (#544)
See https://github.com/huggingface/transformers/pull/24111. I didn't add validation to the `__init__` method since it's not done for other values/warpers.
This commit is contained in:
parent
2a101207d4
commit
e4b26aa10b
|
@@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||||
|
|
||||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||||
sorted_indices_to_remove = probs <= self.top_p_opposite
|
sorted_indices_to_remove = probs <= self.top_p_opposite
|
||||||
if self.min_tokens_to_keep > 1:
|
# Keep at least min_tokens_to_keep
|
||||||
# Keep at least min_tokens_to_keep
|
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
|
||||||
|
|
||||||
# scatter sorted tensors to original indexing
|
# scatter sorted tensors to original indexing
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
|
|
Loading…
Reference in New Issue