fix(server): do not warp prefill logits (#116)
This commit is contained in:
parent
1a2d68250a
commit
c0795de2f2
|
@@ -75,7 +75,11 @@ class NextTokenChooser:
|
||||||
|
|
||||||
def __call__(self, input_ids, scores):
|
def __call__(self, input_ids, scores):
|
||||||
# Warp logits
|
# Warp logits
|
||||||
scores = self.warpers(input_ids, scores)
|
if scores.shape[0] > 1:
|
||||||
|
# only warp the last token logits
|
||||||
|
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
|
||||||
|
else:
|
||||||
|
scores = self.warpers(input_ids, scores)
|
||||||
|
|
||||||
# Compute logprobs
|
# Compute logprobs
|
||||||
logprobs = torch.log_softmax(scores, -1)
|
logprobs = torch.log_softmax(scores, -1)
|
||||||
|
|
Loading…
Reference in New Issue