feat(server): flash attention past key value optimizations (#213)
This commit is contained in:
parent
274513e6a3
commit
ac8c0f6fe4
|
@ -38,7 +38,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# cumulative sequence lengths
|
||||
cu_seqlens: List[int]
|
||||
max_seqlen: int
|
||||
past_key_values: Optional[List[torch.Tensor]]
|
||||
past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[List[int]]
|
||||
|
@ -53,6 +53,9 @@ class FlashCausalLMBatch(Batch):
|
|||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
|
||||
# Constant shared tensor, ref here just so that it's accessible in concatentate()
|
||||
past_pad: Optional[torch.Tensor]
|
||||
|
||||
def to_pb(self) -> generate_pb2.Batch:
|
||||
return generate_pb2.Batch(
|
||||
id=self.batch_id, requests=self.requests, size=len(self)
|
||||
|
@ -149,6 +152,8 @@ class FlashCausalLMBatch(Batch):
|
|||
if len(requests) == len(self):
|
||||
return self
|
||||
|
||||
single_request = len(requests) == 1
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
|
||||
|
@ -182,7 +187,9 @@ class FlashCausalLMBatch(Batch):
|
|||
position_ids.append(self.position_ids[idx])
|
||||
cu_seqlens.append(cumulative_length + request_input_length)
|
||||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
past_key_values.append(self.past_key_values[idx])
|
||||
if not single_request:
|
||||
past_key_values.append(self.past_key_values[2 * idx])
|
||||
past_key_values.append(self.past_key_values[1])
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
|
||||
|
@ -196,6 +203,13 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
cumulative_length += request_input_length
|
||||
|
||||
if single_request:
|
||||
# Preallocate tensor for bs = 1 case
|
||||
past_key_values = torch.nn.functional.pad(
|
||||
self.past_key_values[0],
|
||||
(0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
|
||||
)
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
|
@ -256,7 +270,11 @@ class FlashCausalLMBatch(Batch):
|
|||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
|
||||
max_seqlen = max(max_seqlen, batch.max_seqlen)
|
||||
past_key_values.extend(batch.past_key_values)
|
||||
if len(batch) != 1:
|
||||
past_key_values.extend(batch.past_key_values)
|
||||
else:
|
||||
past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
|
||||
past_key_values.append(batch.past_pad)
|
||||
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
|
||||
|
@ -303,6 +321,7 @@ class FlashCausalLM(Model):
|
|||
quantize: bool = False,
|
||||
decode_buffer: int = 3,
|
||||
):
|
||||
self.past_pad = None
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
@ -359,10 +378,8 @@ class FlashCausalLM(Model):
|
|||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
||||
# Shortcut when batch_size == 1
|
||||
if len(batch) == 1:
|
||||
input_ids = batch.input_ids[0].view(-1)
|
||||
past_key_values = (
|
||||
batch.past_key_values[0] if batch.past_key_values is not None else None
|
||||
)
|
||||
# No need to slice this down
|
||||
past_key_values = batch.past_key_values
|
||||
else:
|
||||
# Concatenate tensors
|
||||
input_ids = torch.cat(batch.input_ids).view(-1)
|
||||
|
@ -392,7 +409,18 @@ class FlashCausalLM(Model):
|
|||
|
||||
# Initialize past_key_values in prefill
|
||||
if batch.past_key_values is None:
|
||||
batch.past_key_values = [None] * len(batch)
|
||||
# Initialize past padding tensor
|
||||
if self.past_pad is None:
|
||||
self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
|
||||
# Set in batch in case it needs to be used later in concatenate()
|
||||
batch.past_pad = self.past_pad
|
||||
if len(batch) == 1:
|
||||
# Preallocate tensor for bs = 1 case
|
||||
batch.past_key_values = torch.nn.functional.pad(
|
||||
present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
|
||||
)
|
||||
else:
|
||||
batch.past_key_values = [None, self.past_pad] * len(batch)
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
|
@ -477,21 +505,10 @@ class FlashCausalLM(Model):
|
|||
generated_text = GeneratedText(
|
||||
output_text, stopping_criteria.current_tokens, reason, seed
|
||||
)
|
||||
|
||||
# CAUTION: generation will be stopped so no need to pad
|
||||
# This will make the next forward crash if the request does not get filtered
|
||||
new_input_length = input_length
|
||||
past = present[:, start_index:end_index]
|
||||
else:
|
||||
stopped = False
|
||||
generated_text = None
|
||||
|
||||
# Pad present for next iter attention
|
||||
new_input_length = input_length + 1
|
||||
past = torch.nn.functional.pad(
|
||||
present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if prefill:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
|
@ -522,6 +539,7 @@ class FlashCausalLM(Model):
|
|||
|
||||
generations.append(generation)
|
||||
cumulative_length += input_length
|
||||
new_input_length = input_length + 1
|
||||
|
||||
# Update values
|
||||
batch.input_ids[i] = next_token_id
|
||||
|
@ -532,7 +550,8 @@ class FlashCausalLM(Model):
|
|||
batch.all_input_ids[i] = all_input_ids
|
||||
batch.all_input_ids_tensor[i] = all_input_ids_tensor
|
||||
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
|
||||
batch.past_key_values[i] = past
|
||||
if len(batch) != 1:
|
||||
batch.past_key_values[i * 2] = present[:, start_index:end_index]
|
||||
# Cumulative sum
|
||||
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
|
||||
|
||||
|
|
Loading…
Reference in New Issue