Merge commit 'refs/pull/1811/head' of github.com:huggingface/text-generation-inference into martinigoyanes-fix-frequency-penalty
This commit is contained in:
commit
27773963cd
|
@ -146,12 +146,14 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||||
# set score to 0 where input_ids is a padding token
|
# set score to 0 where input_ids is a padding token
|
||||||
score *= input_ids.ne(0)
|
score *= input_ids.ne(0)
|
||||||
|
|
||||||
|
|
||||||
return scores.scatter_add_(1, input_ids, score)
|
return scores.scatter_add_(1, input_ids, score)
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
Frequency penalty as defined by OpenAI
|
Frequency penalty as defined by OpenAI in
|
||||||
|
https://platform.openai.com/docs/guides/text-generation/parameter-details
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frequency_penalty (`List[float]`):
|
frequency_penalty (`List[float]`):
|
||||||
|
@ -165,15 +167,17 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||||
).unsqueeze(1)
|
).unsqueeze(1)
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||||
score = torch.gather(scores, 1, input_ids)
|
batch_size, input_size = input_ids.size()
|
||||||
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
vocab_size = scores.size(1)
|
||||||
score = -torch.where(
|
|
||||||
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
|
|
||||||
)
|
|
||||||
# set score to 0 where input_ids is a padding token
|
|
||||||
score *= input_ids.ne(0)
|
|
||||||
|
|
||||||
return scores.scatter_add_(1, input_ids, score)
|
# Calculate the frequency for each token so far
|
||||||
|
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
|
||||||
|
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
|
||||||
|
token_freq /= input_size
|
||||||
|
|
||||||
|
# Apply the frequency penalty to logits
|
||||||
|
scores -= token_freq * self.penalty_tensor
|
||||||
|
return scores
|
||||||
|
|
||||||
def filter(self, indices):
|
def filter(self, indices):
|
||||||
self.penalty = [self.penalty[i] for i in indices]
|
self.penalty = [self.penalty[i] for i in indices]
|
||||||
|
|
Loading…
Reference in New Issue