fix(server): fix flash batch filtering (#220)
This commit is contained in:
parent
1ffea36ec2
commit
4b460e72fb
|
@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
|
|||
position_ids.append(self.position_ids[idx])
|
||||
cu_seqlens.append(cumulative_length + request_input_length)
|
||||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
# True index for past
|
||||
past_key_values.append(self.past_key_values[2 * idx])
|
||||
|
||||
if not single_request:
|
||||
# True index for past
|
||||
past_key_values.append(self.past_key_values[2 * idx])
|
||||
# Add one padding
|
||||
past_key_values.append(self.past_pad)
|
||||
|
||||
|
@@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
|
|||
if single_request:
|
||||
# Preallocate tensor for bs = 1 case
|
||||
past_key_values = torch.nn.functional.pad(
|
||||
self.past_key_values[0],
|
||||
past_key_values[0],
|
||||
(
|
||||
0,
|
||||
0,
|
||||
|
|
Loading…
Reference in New Issue