fix(server): fix flash batch filtering (#220)
This commit is contained in:
parent
1ffea36ec2
commit
4b460e72fb
|
@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
|
||||||
position_ids.append(self.position_ids[idx])
|
position_ids.append(self.position_ids[idx])
|
||||||
cu_seqlens.append(cumulative_length + request_input_length)
|
cu_seqlens.append(cumulative_length + request_input_length)
|
||||||
max_seqlen = max(max_seqlen, request_input_length)
|
max_seqlen = max(max_seqlen, request_input_length)
|
||||||
|
# True index for past
|
||||||
|
past_key_values.append(self.past_key_values[2 * idx])
|
||||||
|
|
||||||
if not single_request:
|
if not single_request:
|
||||||
# True index for past
|
|
||||||
past_key_values.append(self.past_key_values[2 * idx])
|
|
||||||
# Add one padding
|
# Add one padding
|
||||||
past_key_values.append(self.past_pad)
|
past_key_values.append(self.past_pad)
|
||||||
|
|
||||||
|
@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
if single_request:
|
if single_request:
|
||||||
# Preallocate tensor for bs = 1 case
|
# Preallocate tensor for bs = 1 case
|
||||||
past_key_values = torch.nn.functional.pad(
|
past_key_values = torch.nn.functional.pad(
|
||||||
self.past_key_values[0],
|
past_key_values[0],
|
||||||
(
|
(
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
|
Loading…
Reference in New Issue