fix(server): do not warp prefill logits (#116)

This commit is contained in:
OlivierDehaene 2023-03-09 13:00:10 +01:00 committed by GitHub
parent 1a2d68250a
commit c0795de2f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 1 deletion

View File

@@ -75,7 +75,11 @@ class NextTokenChooser:
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
if scores.shape[0] > 1:
# only warp the last token logits
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
else:
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)