fix(server): fix flash batch filtering (#220)

This commit is contained in:
OlivierDehaene 2023-04-21 20:26:01 +02:00 committed by GitHub
parent 1ffea36ec2
commit 4b460e72fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 3 deletions

View File

@@ -188,9 +188,10 @@ class FlashCausalLMBatch(Batch):
            position_ids.append(self.position_ids[idx])
            cu_seqlens.append(cumulative_length + request_input_length)
            max_seqlen = max(max_seqlen, request_input_length)
-            # True index for past
-            past_key_values.append(self.past_key_values[2 * idx])
            if not single_request:
+                # True index for past
+                past_key_values.append(self.past_key_values[2 * idx])
                # Add one padding
                past_key_values.append(self.past_pad)
@ -209,7 +210,7 @@ class FlashCausalLMBatch(Batch):
if single_request: if single_request:
# Preallocate tensor for bs = 1 case # Preallocate tensor for bs = 1 case
past_key_values = torch.nn.functional.pad( past_key_values = torch.nn.functional.pad(
self.past_key_values[0], past_key_values[0],
( (
0, 0,
0, 0,