From 949b889c3353c52068284f4b1de4c32ac40afbae Mon Sep 17 00:00:00 2001
From: martini
Date: Thu, 25 Apr 2024 23:47:20 +0200
Subject: [PATCH] fix: take into account token frequency so far in a
 generation stream when applying frequency penalty

---
 .../utils/logits_process.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 2decee53..8cf85c9f 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -146,12 +146,14 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
         )
         # set score to 0 where input_ids is a padding token
         score *= input_ids.ne(0)
+
         return scores.scatter_add_(1, input_ids, score)
 
 
 class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    Frequency penalty as defined by OpenAI
+    Frequency penalty as defined by OpenAI in
+    https://platform.openai.com/docs/guides/text-generation/parameter-details
 
     Args:
         frequency_penalty (`List[float]`):
@@ -165,15 +167,17 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
         ).unsqueeze(1)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        score = torch.gather(scores, 1, input_ids)
-        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
-        )
-        # set score to 0 where input_ids is a padding token
-        score *= input_ids.ne(0)
+        batch_size, input_size = input_ids.size()
+        vocab_size = scores.size(1)
 
-        return scores.scatter_add_(1, input_ids, score)
+        # Calculate the frequency for each token so far
+        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+        token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
+        token_freq /= input_size
+
+        # Apply the frequency penalty to logits
+        scores -= token_freq * self.penalty_tensor
+        return scores
 
     def filter(self, indices):
         self.penalty = [self.penalty[i] for i in indices]
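
Below is a minimal, self-contained sketch of the penalty math the patched
__call__ implements, for checking it by hand outside the server. It is not
part of the patch; the batch contents, penalty values, and variable names
are illustrative assumptions, with the padding token taken to be id 0 as in
the original code.

    import torch

    # Toy setup: 2 sequences of 4 generated ids over a 6-token vocabulary.
    # Sequence 1 ends in a padding id (0), which the new code counts like
    # any other token (the old code masked it out via input_ids.ne(0)).
    batch_size, input_size, vocab_size = 2, 4, 6
    input_ids = torch.tensor([[1, 2, 2, 3], [4, 4, 4, 0]])
    scores = torch.zeros(batch_size, vocab_size)  # stand-in logits
    penalty_tensor = torch.tensor([[0.5], [1.0]])  # one penalty per sequence

    # Count how often each vocab id appears in each sequence so far...
    token_freq = torch.zeros(batch_size, vocab_size)
    token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
    # ...and turn counts into frequencies over the tokens generated so far.
    token_freq /= input_size

    # Subtract a penalty proportional to frequency, as in the patched __call__.
    scores -= token_freq * penalty_tensor
    print(scores)
    # tensor([[ 0.0000, -0.1250, -0.2500, -0.1250,  0.0000,  0.0000],
    #         [-0.2500,  0.0000,  0.0000,  0.0000, -0.7500,  0.0000]])

Note the design change visible in the diff: the old code rescaled the
gathered scores (multiplying or dividing by the penalty depending on the
score's sign), so the effect depended on the sign of the original logit,
while the new code subtracts penalty * frequency, so repeated tokens are
penalized in proportion to how often they have appeared in the stream.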