fix: update formatting via pre-commit

This commit is contained in:
drbh 2024-04-29 16:40:28 +00:00
parent 27773963cd
commit f2d760ff1a
1 changed files with 3 additions and 2 deletions

View File

@ -146,7 +146,6 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
# set score to 0 where input_ids is a padding token # set score to 0 where input_ids is a padding token
score *= input_ids.ne(0) score *= input_ids.ne(0)
return scores.scatter_add_(1, input_ids, score) return scores.scatter_add_(1, input_ids, score)
@ -172,7 +171,9 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
# Calculate the frequency for each token so far # Calculate the frequency for each token so far
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device) token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float)) token_freq.scatter_add_(
1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
)
token_freq /= input_size token_freq /= input_size
# Apply the frequency penalty to logits # Apply the frequency penalty to logits