fix(server): do not warp prefill logits (#116)
This commit is contained in:
parent
1a2d68250a
commit
c0795de2f2
|
@ -75,7 +75,11 @@ class NextTokenChooser:
|
|||
|
||||
def __call__(self, input_ids, scores):
|
||||
# Warp logits
|
||||
scores = self.warpers(input_ids, scores)
|
||||
if scores.shape[0] > 1:
|
||||
# only warp the last token logits
|
||||
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
|
||||
else:
|
||||
scores = self.warpers(input_ids, scores)
|
||||
|
||||
# Compute logprobs
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
|
|
Loading…
Reference in New Issue