fix(server): fix flash batch filtering (#220)

This commit is contained in:
OlivierDehaene 2023-04-21 20:26:01 +02:00 committed by GitHub
parent 1ffea36ec2
commit 4b460e72fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 3 deletions

View File

@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
position_ids.append(self.position_ids[idx])
cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length)
if not single_request:
# True index for past
past_key_values.append(self.past_key_values[2 * idx])
if not single_request:
# Add one padding
past_key_values.append(self.past_pad)
@@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
if single_request:
# Preallocate tensor for bs = 1 case
past_key_values = torch.nn.functional.pad(
self.past_key_values[0],
past_key_values[0],
(
0,
0,