feat(server): flash attention past key value optimizations (#213)
parent 274513e6a3
commit ac8c0f6fe4
@@ -38,7 +38,7 @@ class FlashCausalLMBatch(Batch):
     # cumulative sequence lengths
     cu_seqlens: List[int]
     max_seqlen: int
-    past_key_values: Optional[List[torch.Tensor]]
+    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]
 
     # All tokens
     all_input_ids: List[List[int]]
@@ -53,6 +53,9 @@ class FlashCausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
 
+    # Constant shared tensor, ref here just so that it's accessible in concatentate()
+    past_pad: Optional[torch.Tensor]
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id, requests=self.requests, size=len(self)
@@ -149,6 +152,8 @@ class FlashCausalLMBatch(Batch):
         if len(requests) == len(self):
             return self
 
+        single_request = len(requests) == 1
+
         # Cumulative length
         cumulative_length = 0
 
@@ -182,7 +187,9 @@ class FlashCausalLMBatch(Batch):
             position_ids.append(self.position_ids[idx])
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
-            past_key_values.append(self.past_key_values[idx])
+            if not single_request:
+                past_key_values.append(self.past_key_values[2 * idx])
+                past_key_values.append(self.past_key_values[1])
 
             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -196,6 +203,13 @@ class FlashCausalLMBatch(Batch):
 
             cumulative_length += request_input_length
 
+        if single_request:
+            # Preallocate tensor for bs = 1 case
+            past_key_values = torch.nn.functional.pad(
+                self.past_key_values[0],
+                (0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
+            )
+
         return FlashCausalLMBatch(
             batch_id=self.batch_id,
             requests=requests,
@@ -256,7 +270,11 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
-            past_key_values.extend(batch.past_key_values)
+            if len(batch) != 1:
+                past_key_values.extend(batch.past_key_values)
+            else:
+                past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
+                past_key_values.append(batch.past_pad)
 
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
@@ -303,6 +321,7 @@ class FlashCausalLM(Model):
         quantize: bool = False,
         decode_buffer: int = 3,
     ):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -359,10 +378,8 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
         if len(batch) == 1:
-            input_ids = batch.input_ids[0].view(-1)
-            past_key_values = (
-                batch.past_key_values[0] if batch.past_key_values is not None else None
-            )
+            # No need to slice this down
+            past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
             input_ids = torch.cat(batch.input_ids).view(-1)
@@ -392,7 +409,18 @@ class FlashCausalLM(Model):
 
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
-            batch.past_key_values = [None] * len(batch)
+            # Initialize past padding tensor
+            if self.past_pad is None:
+                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+            # Set in batch in case it needs to be used later in concatenate()
+            batch.past_pad = self.past_pad
+            if len(batch) == 1:
+                # Preallocate tensor for bs = 1 case
+                batch.past_key_values = torch.nn.functional.pad(
+                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
+                )
+            else:
+                batch.past_key_values = [None, self.past_pad] * len(batch)
 
         # Cumulative length
         cumulative_length = 0
@@ -477,21 +505,10 @@ class FlashCausalLM(Model):
                 generated_text = GeneratedText(
                     output_text, stopping_criteria.current_tokens, reason, seed
                 )
-
-                # CAUTION: generation will be stopped so no need to pad
-                # This will make the next forward crash if the request does not get filtered
-                new_input_length = input_length
-                past = present[:, start_index:end_index]
             else:
                 stopped = False
                 generated_text = None
-
-                # Pad present for next iter attention
-                new_input_length = input_length + 1
-                past = torch.nn.functional.pad(
-                    present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
-                )
 
             # Prefill
             if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
@@ -522,6 +539,7 @@ class FlashCausalLM(Model):
 
             generations.append(generation)
             cumulative_length += input_length
+            new_input_length = input_length + 1
 
             # Update values
             batch.input_ids[i] = next_token_id
@@ -532,7 +550,8 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids
             batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
-            batch.past_key_values[i] = past
+            if len(batch) != 1:
+                batch.past_key_values[i * 2] = present[:, start_index:end_index]
             # Cumulative sum
             batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
 
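The prefill hunk (`@@ -392,7 +409,18 @@`) preallocates the whole past key/value tensor when the batch holds a single request: `present` is padded once with `max_new_tokens` extra slots, so the decode loop no longer re-pads by one token on every step (the per-step `torch.nn.functional.pad` removed in `@@ -477,21 +505,10 @@`). A minimal sketch of that preallocation idea; the shapes and the helper name below are illustrative and not taken from the commit:

import torch

def prefill_preallocate(present: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    # present: [kv, prompt_len, num_heads, head_dim] (illustrative layout).
    # Pad the sequence dimension once with room for every future token,
    # instead of growing the tensor by one padded slot per decode step.
    return torch.nn.functional.pad(present, (0, 0, 0, 0, 0, max_new_tokens))

prompt_len, max_new_tokens = 5, 8
present = torch.randn(2, prompt_len, 16, 64)         # toy prefill key/value tensor
past = prefill_preallocate(present, max_new_tokens)
assert past.shape[1] == prompt_len + max_new_tokens  # room for all future tokens
# Each decode step can then write its new key/value into the next free slot and
# attend over the valid prefix (tracked via cu_seqlens), with no re-padding.

The trade-off is paying the full `max_new_tokens` tail in memory up front in exchange for avoiding a reallocation and copy on every generated token.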
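For batches with more than one request, the diff keeps `past_key_values` as a flat list that interleaves each request's past with a shared padding tensor: `[None, self.past_pad] * len(batch)` in prefill, `past_key_values[i * 2]` in the update loop, and `self.past_key_values[2 * idx]` in `filter()`. A rough sketch of that layout, again with illustrative names and shapes:

from typing import List
import torch

def build_interleaved_past(pasts: List[torch.Tensor], past_pad: torch.Tensor) -> List[torch.Tensor]:
    # Even slots hold each request's past; odd slots all reference the same
    # shared one-token padding tensor, so no per-request copy is made.
    layout: List[torch.Tensor] = []
    for past in pasts:
        layout.append(past)      # slot 2 * idx
        layout.append(past_pad)  # slot 2 * idx + 1 (shared)
    return layout

pad = torch.zeros(2, 1, 16, 64)                      # shared one-token padding
pasts = [torch.randn(2, n, 16, 64) for n in (3, 7)]  # toy per-request pasts
layout = build_interleaved_past(pasts, pad)
assert layout[2 * 1] is pasts[1] and layout[1] is pad

Reusing one `past_pad` tensor for every odd slot matches the "Constant shared tensor" comment in the batch dataclass: the padding needed for the next attention step is shared across requests rather than allocated per request.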