feat(server): flash attention past key value optimizations (#213)

This commit is contained in:
Nick Hill 2023-04-21 05:57:18 -07:00 committed by GitHub
parent 274513e6a3
commit ac8c0f6fe4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 39 additions and 20 deletions

View File

@ -38,7 +38,7 @@ class FlashCausalLMBatch(Batch):
# cumulative sequence lengths
cu_seqlens: List[int]
max_seqlen: int
past_key_values: Optional[List[torch.Tensor]]
past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]
# All tokens
all_input_ids: List[List[int]]
@ -53,6 +53,9 @@ class FlashCausalLMBatch(Batch):
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
# Constant shared tensor, ref here just so that it's accessible in concatentate()
past_pad: Optional[torch.Tensor]
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
id=self.batch_id, requests=self.requests, size=len(self)
@ -149,6 +152,8 @@ class FlashCausalLMBatch(Batch):
if len(requests) == len(self):
return self
single_request = len(requests) == 1
# Cumulative length
cumulative_length = 0
@ -182,7 +187,9 @@ class FlashCausalLMBatch(Batch):
position_ids.append(self.position_ids[idx])
cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length)
past_key_values.append(self.past_key_values[idx])
if not single_request:
past_key_values.append(self.past_key_values[2 * idx])
past_key_values.append(self.past_key_values[1])
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@ -196,6 +203,13 @@ class FlashCausalLMBatch(Batch):
cumulative_length += request_input_length
if single_request:
# Preallocate tensor for bs = 1 case
past_key_values = torch.nn.functional.pad(
self.past_key_values[0],
(0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
)
return FlashCausalLMBatch(
batch_id=self.batch_id,
requests=requests,
@ -256,7 +270,11 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
past_key_values.extend(batch.past_key_values)
if len(batch) != 1:
past_key_values.extend(batch.past_key_values)
else:
past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
past_key_values.append(batch.past_pad)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
@ -303,6 +321,7 @@ class FlashCausalLM(Model):
quantize: bool = False,
decode_buffer: int = 3,
):
self.past_pad = None
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@ -359,10 +378,8 @@ class FlashCausalLM(Model):
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
# Shortcut when batch_size == 1
if len(batch) == 1:
input_ids = batch.input_ids[0].view(-1)
past_key_values = (
batch.past_key_values[0] if batch.past_key_values is not None else None
)
# No need to slice this down
past_key_values = batch.past_key_values
else:
# Concatenate tensors
input_ids = torch.cat(batch.input_ids).view(-1)
@ -392,7 +409,18 @@ class FlashCausalLM(Model):
# Initialize past_key_values in prefill
if batch.past_key_values is None:
batch.past_key_values = [None] * len(batch)
# Initialize past padding tensor
if self.past_pad is None:
self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
# Set in batch in case it needs to be used later in concatenate()
batch.past_pad = self.past_pad
if len(batch) == 1:
# Preallocate tensor for bs = 1 case
batch.past_key_values = torch.nn.functional.pad(
present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
)
else:
batch.past_key_values = [None, self.past_pad] * len(batch)
# Cumulative length
cumulative_length = 0
@ -477,21 +505,10 @@ class FlashCausalLM(Model):
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
# CAUTION: generation will be stopped so no need to pad
# This will make the next forward crash if the request does not get filtered
new_input_length = input_length
past = present[:, start_index:end_index]
else:
stopped = False
generated_text = None
# Pad present for next iter attention
new_input_length = input_length + 1
past = torch.nn.functional.pad(
present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
)
# Prefill
if prefill:
# Remove generated token to only have prefill and add nan for first prompt token
@ -522,6 +539,7 @@ class FlashCausalLM(Model):
generations.append(generation)
cumulative_length += input_length
new_input_length = input_length + 1
# Update values
batch.input_ids[i] = next_token_id
@ -532,7 +550,8 @@ class FlashCausalLM(Model):
batch.all_input_ids[i] = all_input_ids
batch.all_input_ids_tensor[i] = all_input_ids_tensor
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
batch.past_key_values[i] = past
if len(batch) != 1:
batch.past_key_values[i * 2] = present[:, start_index:end_index]
# Cumulative sum
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length