Fixing frequency penalty (#1811)

Thank you so much for the work you are doing, this is my little
contribution to this great thing you have built. I hope it is useful and
helpful, please don't hesitate to discuss any matters that are not
clear!

I am basing my implementation of frequency penalty on OpenAI's
implementation:
https://platform.openai.com/docs/guides/text-generation/parameter-details

The problem I see with TGI's current implementation is that is not
taking into account the frequency of tokens which have already been
sampled in the current generation stream. Also, the scaling is of the
adjusted token logits is done differently for positive and negative
logits. While in OpenAI's implementation token frequency is taking into
account and the scaling is always done with a subtraction (if penalty is
positive) or add operation (if penalty is negative).

This leads to corrupt generations as I mentioned in issue #1810 .
Moreover, after my tests, other issues are also gone like the one about
some request's with ``penalty_frequency = 1.0`` overruling other
requests (with ``frequency_penalty = 0.0``) in the same batch and
therefore corrupting all generations in the batch. Basically, padding
does not affect this implementation so I believe this ``score *=
input_ids.ne(0)`` is not needed anymore.



Frequency penalty | -1.0 | 0.0 | 1.0
-- | -- | -- | --
Before my change | https://paste.mozilla.org/JxqGJkWY |
https://paste.mozilla.org/hrztJ56h | https://paste.mozilla.org/pBSEH2zw
After my change | https://paste.mozilla.org/7gXCi7zo |
https://paste.mozilla.org/ZR9rJ92g | https://paste.mozilla.org/gHaD2YnC

---------

Co-authored-by: martini <martin.iglesiasgoyanes@adyen.com>
This commit is contained in:
Martin Iglesias Goyanes 2024-04-30 12:13:23 +02:00 committed by GitHub
parent f6615080b9
commit 9192de57cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 9 deletions

2
.gitignore vendored
View File

@ -11,3 +11,5 @@ server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
data/

View File

@ -151,7 +151,8 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Frequency penalty as defined by OpenAI in
https://platform.openai.com/docs/guides/text-generation/parameter-details
Args:
frequency_penalty (`List[float]`):
@ -165,15 +166,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
# set score to 0 where input_ids is a padding token
score *= input_ids.ne(0)
batch_size, input_size = input_ids.size()
vocab_size = scores.size(1)
return scores.scatter_add_(1, input_ids, score)
# Calculate the frequency for each token so far
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
token_freq.scatter_add_(
1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
)
token_freq /= input_size
# Apply the frequency penalty to logits
scores -= token_freq * self.penalty_tensor
return scores
def filter(self, indices):
self.penalty = [self.penalty[i] for i in indices]