parent
343437c7b5
commit
db4cb5e4ed
|
@ -25,6 +25,7 @@ from torch.nn import functional as F
|
|||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import rotary_emb
|
||||
|
@ -554,7 +555,8 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
|
@ -564,7 +566,9 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
past_key_values = hidden_states.new_empty(
|
||||
(
|
||||
len(self.layers),
|
||||
len(hidden_states),
|
||||
len(hidden_states)
|
||||
if pre_allocate_past_size is None
|
||||
else pre_allocate_past_size,
|
||||
2,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
|
@ -572,6 +576,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
)
|
||||
layer_past_present_indices = None
|
||||
cu_seqlens_q = None
|
||||
slice_past_index = len(hidden_states)
|
||||
# Decode
|
||||
else:
|
||||
# Create indices from cumulative sequence lengths
|
||||
|
@ -579,6 +584,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
cu_seqlens_q = torch.arange(
|
||||
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
slice_past_index = None
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
|
@ -588,6 +594,13 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
# We added padding that now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
else past_key_values[i, :slice_past_index]
|
||||
)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
|
@ -595,7 +608,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values[i],
|
||||
layer_past_key_values,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
@ -638,10 +651,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states, present = self.model(
|
||||
input_ids, position_ids, cu_seqlens, max_s, past_key_values
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ from torch import nn
|
|||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.gpt_neox import GPTNeoXConfig
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import rotary_emb
|
||||
|
@ -618,6 +619,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states = self.embed_in(input_ids)
|
||||
|
||||
|
@ -627,7 +629,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
past_key_values = hidden_states.new_empty(
|
||||
(
|
||||
len(self.layers),
|
||||
len(hidden_states),
|
||||
len(hidden_states)
|
||||
if pre_allocate_past_size is None
|
||||
else pre_allocate_past_size,
|
||||
2,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
|
@ -635,6 +639,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
)
|
||||
layer_past_present_indices = None
|
||||
cu_seqlens_q = None
|
||||
slice_past_index = len(hidden_states)
|
||||
# Decode
|
||||
else:
|
||||
# Create indices from cumulative sequence lengths
|
||||
|
@ -642,6 +647,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
cu_seqlens_q = torch.arange(
|
||||
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
slice_past_index = None
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
|
@ -651,6 +657,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
# We added padding that now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
else past_key_values[i, :slice_past_index]
|
||||
)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
|
@ -658,7 +671,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values[i],
|
||||
layer_past_key_values,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
@ -714,10 +727,16 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states, present = self.gpt_neox(
|
||||
input_ids, position_ids, cu_seqlens, max_s, past_key_values
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
logits = self.embed_out(hidden_states)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch.nn.functional as F
|
|||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
|
@ -484,7 +485,8 @@ class FlashSantacoderModel(nn.Module):
|
|||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
if self.tp_embeddings:
|
||||
|
@ -496,7 +498,9 @@ class FlashSantacoderModel(nn.Module):
|
|||
past_key_values = hidden_states.new_empty(
|
||||
(
|
||||
len(self.h),
|
||||
len(hidden_states),
|
||||
len(hidden_states)
|
||||
if pre_allocate_past_size is None
|
||||
else pre_allocate_past_size,
|
||||
2,
|
||||
1,
|
||||
self.head_size,
|
||||
|
@ -504,6 +508,7 @@ class FlashSantacoderModel(nn.Module):
|
|||
)
|
||||
layer_past_present_indices = None
|
||||
cu_seqlens_q = None
|
||||
slice_past_index = len(hidden_states)
|
||||
# Decode
|
||||
else:
|
||||
# Create indices from cumulative sequence lengths
|
||||
|
@ -511,15 +516,23 @@ class FlashSantacoderModel(nn.Module):
|
|||
cu_seqlens_q = torch.arange(
|
||||
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
slice_past_index = None
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.h):
|
||||
# We added padding that now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
else past_key_values[i, :slice_past_index]
|
||||
)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values[i],
|
||||
layer_past_key_values,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
@ -554,10 +567,16 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states, present = self.transformer(
|
||||
input_ids, position_ids, cu_seqlens, max_s, past_key_values
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
|
|
|
@ -142,6 +142,7 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
past_pad=None,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
|
@ -188,8 +189,10 @@ class FlashCausalLMBatch(Batch):
|
|||
cu_seqlens.append(cumulative_length + request_input_length)
|
||||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
if not single_request:
|
||||
# True index for past
|
||||
past_key_values.append(self.past_key_values[2 * idx])
|
||||
past_key_values.append(self.past_key_values[1])
|
||||
# Add one padding
|
||||
past_key_values.append(self.past_pad)
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
|
||||
|
@ -207,7 +210,17 @@ class FlashCausalLMBatch(Batch):
|
|||
# Preallocate tensor for bs = 1 case
|
||||
past_key_values = torch.nn.functional.pad(
|
||||
self.past_key_values[0],
|
||||
(0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
|
||||
(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
stopping_criterias[0].max_new_tokens
|
||||
- stopping_criterias[0].current_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
|
@ -270,10 +283,16 @@ class FlashCausalLMBatch(Batch):
|
|||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
|
||||
max_seqlen = max(max_seqlen, batch.max_seqlen)
|
||||
|
||||
if len(batch) != 1:
|
||||
past_key_values.extend(batch.past_key_values)
|
||||
else:
|
||||
past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
|
||||
# past was pre-allocated for this batch
|
||||
# We need to slice to remove the padding
|
||||
past_key_values.append(
|
||||
batch.past_key_values[:, : batch.input_lengths[0]]
|
||||
)
|
||||
# Add one padding
|
||||
past_key_values.append(batch.past_pad)
|
||||
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
|
@ -366,6 +385,7 @@ class FlashCausalLM(Model):
|
|||
cu_seqlens: torch.Tensor,
|
||||
max_s: int,
|
||||
past_key_values: Optional = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Model Forward
|
||||
return self.model.forward(
|
||||
|
@ -374,6 +394,7 @@ class FlashCausalLM(Model):
|
|||
cu_seqlens=cu_seqlens,
|
||||
max_s=max_s,
|
||||
past_key_values=past_key_values,
|
||||
pre_allocate_past_size=pre_allocate_past_size,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
|
@ -382,7 +403,9 @@ class FlashCausalLM(Model):
|
|||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
||||
# Shortcut when batch_size == 1
|
||||
if len(batch) == 1:
|
||||
# No need to slice this down
|
||||
input_ids = batch.input_ids[0].view(-1)
|
||||
# Slice to remove extra padding
|
||||
# past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
|
||||
past_key_values = batch.past_key_values
|
||||
else:
|
||||
# Concatenate tensors
|
||||
|
@ -393,6 +416,16 @@ class FlashCausalLM(Model):
|
|||
else None
|
||||
)
|
||||
|
||||
# if prefill and bs == 1
|
||||
if past_key_values is None and len(batch) == 1:
|
||||
# Ask to pre-allocate kv to its max size
|
||||
# == number of tokens + max_new_tokens
|
||||
pre_allocate_past_size = (
|
||||
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
else:
|
||||
pre_allocate_past_size = None
|
||||
|
||||
# Concatenate when prefill, torch.tensor when decode
|
||||
position_ids = (
|
||||
torch.tensor(batch.position_ids, device=self.device)
|
||||
|
@ -409,21 +442,28 @@ class FlashCausalLM(Model):
|
|||
cu_seqlens,
|
||||
batch.max_seqlen,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
|
||||
# Initialize past_key_values in prefill
|
||||
if batch.past_key_values is None:
|
||||
# Initialize past padding tensor
|
||||
if self.past_pad is None:
|
||||
self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
|
||||
self.past_pad = present.new_zeros(
|
||||
present.shape[0], 1, *present.shape[2:]
|
||||
)
|
||||
# Set in batch in case it needs to be used later in concatenate()
|
||||
batch.past_pad = self.past_pad
|
||||
if len(batch) == 1:
|
||||
# Preallocate tensor for bs = 1 case
|
||||
batch.past_key_values = torch.nn.functional.pad(
|
||||
present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
|
||||
present,
|
||||
(0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
|
||||
)
|
||||
else:
|
||||
# Add padding after each sequence
|
||||
# This will have the correct shape after the final past_key_values concatenation before the model
|
||||
# forward
|
||||
batch.past_key_values = [None, self.past_pad] * len(batch)
|
||||
|
||||
# Cumulative length
|
||||
|
@ -555,6 +595,7 @@ class FlashCausalLM(Model):
|
|||
batch.all_input_ids_tensor[i] = all_input_ids_tensor
|
||||
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
|
||||
if len(batch) != 1:
|
||||
# Add each sequence before its padding
|
||||
batch.past_key_values[i * 2] = present[:, start_index:end_index]
|
||||
# Cumulative sum
|
||||
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
|
||||
|
|
|
@ -29,6 +29,7 @@ tracer = trace.get_tracer(__name__)
|
|||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
self.past_pad = None
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
@ -146,6 +147,7 @@ class FlashLlamaSharded(FlashLlama):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.past_pad = None
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
|
|
|
@ -33,6 +33,7 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.past_pad = None
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
|
|
|
@ -28,6 +28,7 @@ tracer = trace.get_tracer(__name__)
|
|||
|
||||
class FlashSantacoder(FlashCausalLM):
|
||||
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
|
||||
self.past_pad = None
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||
|
@ -172,6 +173,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
):
|
||||
self.past_pad = None
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
|
|
Loading…
Reference in New Issue