fix: update formatting via pre-commit
This commit is contained in:
parent
27773963cd
commit
f2d760ff1a
|
@ -146,7 +146,6 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|||
# set score to 0 where input_ids is a padding token
|
||||
score *= input_ids.ne(0)
|
||||
|
||||
|
||||
return scores.scatter_add_(1, input_ids, score)
|
||||
|
||||
|
||||
|
@ -172,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
|||
|
||||
# Calculate the frequency for each token so far
|
||||
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
|
||||
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
|
||||
token_freq.scatter_add_(
|
||||
1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
|
||||
)
|
||||
token_freq /= input_size
|
||||
|
||||
# Apply the frequency penalty to logits
|
||||
|
|
Loading…
Reference in New Issue