Nicolas Patry 2024-09-25 20:41:40 +02:00
parent 44cdb00bbb
commit 31a4c24f74
2 changed files with 183 additions and 101 deletions


@@ -758,8 +758,8 @@ class MllamaTextCrossAttention(nn.Module):
         elif cache_position[0] != 0:
             key_states, value_states = (
-                past_key_value.key_cache[self.layer_idx],
-                past_key_value.value_cache[self.layer_idx],
+                past_key_value[self.layer_idx][0],
+                past_key_value[self.layer_idx][1],
             )
         else:
             raise ValueError(
@@ -850,6 +850,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         self.cross_attn_mlp_gate = torch.nn.Parameter(
             weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
         )
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -862,24 +863,75 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
-
-        hidden_states, attn_weights, past_key_value = self.cross_attn(
-            hidden_states=hidden_states,
-            attention_mask=cross_attention_mask,
-            cross_attention_states=cross_attention_states,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-        )
-        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
-
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        if full_text_row_masked_out_mask is not None:
-            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
-        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+        if past_key_value is not None:
+            is_mixed = False
+            if cross_attention_states is None:
+                out_hidden_states = hidden_states[:]
+                indices = []
+                for i, k in enumerate(past_key_value[self.layer_idx][0]):
+                    if isinstance(k, torch.Tensor):
+                        indices.append(i)
+                from loguru import logger
+
+                logger.info(f"Indices {indices}")
+                if len(indices) == 0:
+                    return hidden_states
+                is_mixed = True
+                if len(indices) == hidden_states.shape[0]:
+                    is_mixed = False
+
+                if is_mixed:
+                    hidden_states = hidden_states[indices]
+                    # Dirty hack
+                    _past_key_value = [None] * len(past_key_value)
+                    _past_key_value[self.layer_idx] = (
+                        torch.stack(
+                            [
+                                k
+                                for i, k in enumerate(past_key_value[self.layer_idx][0])
+                                if i in indices
+                            ],
+                            dim=0,
+                        ),
+                        torch.stack(
+                            [
+                                k
+                                for i, k in enumerate(past_key_value[self.layer_idx][1])
+                                if i in indices
+                            ],
+                            dim=0,
+                        ),
+                    )
+                    logger.info(f"Hidden states {hidden_states.shape}")
+                    logger.info(f"k {_past_key_value[self.layer_idx][0].shape}")
+                    logger.info(f"v {_past_key_value[self.layer_idx][1].shape}")
+                    past_key_value = _past_key_value
+
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, attn_weights, past_key_value = self.cross_attn(
+            hidden_states=hidden_states,
+            attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
+        )
+        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if full_text_row_masked_out_mask is not None:
+            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
+        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+        if is_mixed:
+            out_hidden_states[indices] = hidden_states
+            hidden_states = out_hidden_states
+            from loguru import logger
+
+            logger.info(f"After Hidden states {hidden_states.shape}")
 
         return hidden_states
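
The new forward path above only runs cross-attention on the batch rows whose per-layer cache entry is an actual tensor, then scatters the attended rows back into the full batch. Below is a minimal standalone sketch of that gather/scatter-by-indices pattern; the `toy_cross_attn` helper and the example shapes are made up for illustration and are not the TGI implementation.

import torch


def toy_cross_attn(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the real cross-attention call; just scales the input.
    return x * 2.0


def forward_mixed(hidden_states: torch.Tensor, layer_keys: list) -> torch.Tensor:
    # layer_keys[i] is a tensor for rows with cached cross-attention state,
    # and None for text-only rows, mirroring the cache layout assumed above.
    indices = [i for i, k in enumerate(layer_keys) if isinstance(k, torch.Tensor)]
    if len(indices) == 0:
        # No row has cross-attention state: the layer is a no-op.
        return hidden_states
    if len(indices) == hidden_states.shape[0]:
        # Every row has state: no gather/scatter needed.
        return toy_cross_attn(hidden_states)
    # Mixed batch: gather the rows with state, attend, scatter back.
    out = hidden_states.clone()
    out[indices] = toy_cross_attn(hidden_states[indices])
    return out


hs = torch.ones(3, 4)
keys = [torch.zeros(8, 2, 64), None, torch.zeros(8, 2, 64)]  # rows 0 and 2 saw an image
print(forward_mixed(hs, keys))  # rows 0 and 2 doubled, row 1 untouched
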
@@ -1243,18 +1295,18 @@ class MllamaTextModel(nn.Module):
 
         # decoder layers
         for idx, decoder_layer in enumerate(self.layers):
-            if (
-                idx in self.cross_attention_layers
-                and cross_attention_states is None
-                and (
-                    past_key_values is None
-                    or (
-                        past_key_values is not None
-                        and past_key_values.get_seq_length(idx) == 0
-                    )
-                )
-            ):
-                continue
+            # if (
+            #     idx in self.cross_attention_layers
+            #     and cross_attention_states is None
+            #     and (
+            #         past_key_values is None
+            #         or (
+            #             past_key_values is not None
+            #             and any(past_key_values.get_seq_length(idx) == 0
+            #         )
+            #     )
+            # ):
+            #     continue
             layer_outputs = decoder_layer(
                 hidden_states,


@@ -360,7 +360,15 @@ class IdeficsCausalLMBatch(Batch):
         past_kv_length = max_input_length - 1
         for layer in self.past_key_values:
             past_keys, past_values = layer
-            if len(past_keys.shape) == 3:
+            if not isinstance(past_keys, torch.Tensor):
+                past_keys = [k for i, k in enumerate(past_keys) if i in keep_indices]
+                past_values = [
+                    k for i, k in enumerate(past_values) if i in keep_indices
+                ]
+                layer[0] = past_keys
+                layer[1] = past_values
+                continue
+            elif len(past_keys.shape) == 3:
                 # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                 past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                 past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
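
The filter change above handles two cache layouts: self-attention tensors indexed along the batch dimension, and the per-request Python lists used for the mllama cross-attention cache, where dropped requests simply fall out of the list. A small sketch of that dual handling follows, using a hypothetical `filter_layer_kv` helper rather than the actual `IdeficsCausalLMBatch.filter` method.

import torch


def filter_layer_kv(layer: list, keep_indices: list) -> list:
    # One layer is a mutable [keys, values] pair.
    past_keys, past_values = layer
    if not isinstance(past_keys, torch.Tensor):
        # Cross-attention layout: a Python list with one entry per request
        # (entries may be None for text-only requests). Keep only the survivors.
        layer[0] = [k for i, k in enumerate(past_keys) if i in keep_indices]
        layer[1] = [v for i, v in enumerate(past_values) if i in keep_indices]
    else:
        # Self-attention layout: tensors indexed along the batch dimension.
        layer[0] = past_keys[keep_indices]
        layer[1] = past_values[keep_indices]
    return layer


layer = [
    [torch.zeros(8, 4, 64), None, torch.zeros(8, 4, 64)],
    [torch.zeros(8, 4, 64), None, torch.zeros(8, 4, 64)],
]
print(len(filter_layer_kv(layer, keep_indices=[0, 2])[0]))  # 2 entries kept
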
@@ -530,7 +538,14 @@ class IdeficsCausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if isinstance(batch.past_key_values[0], tuple):
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    [
+                        (
+                            t.view(len(batch), -1, *t.shape[-2:])
+                            if isinstance(t, torch.Tensor)
+                            else t
+                        )
+                        for t in layer
+                    ]
                     for layer in batch.past_key_values
                 ]
             elif len(batch.past_key_values[0][0].shape) == 3:
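
The concatenate change above converts `past_key_values` into mutable nested lists for in-place updates, but now reshapes only real tensors and passes list-based cross-attention entries through untouched. A sketch of that conversion; the helper name `to_mutable_kv` and the example shapes are made up for illustration.

import torch


def to_mutable_kv(past_key_values, batch_size: int):
    # Tensors are viewed as [batch, heads, seq, head_dim] so they can be
    # updated in place later; list entries (the per-request cross-attention
    # cache) are left as-is.
    return [
        [
            t.view(batch_size, -1, *t.shape[-2:]) if isinstance(t, torch.Tensor) else t
            for t in layer
        ]
        for layer in past_key_values
    ]


self_attn = (torch.zeros(2, 8, 5, 64), torch.zeros(2, 8, 5, 64))
cross_attn = ([torch.zeros(8, 10, 64), None], [torch.zeros(8, 10, 64), None])
kv = to_mutable_kv([self_attn, cross_attn], batch_size=2)
print(kv[0][0].shape, type(kv[1][0]))  # torch.Size([2, 8, 5, 64]) <class 'list'>
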
@@ -569,83 +584,98 @@ class IdeficsCausalLMBatch(Batch):
         # Iterate over attention layers
         # Concatenate past key values layer by layer to allow incremental garbage collection
         for j in range(len(first_past_kvs)):
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_keys_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_keys_shape = padded_past_keys_shape
-
-            padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape)
-            start_index = 0
-            for batch in batches:
-                past_keys = batch.past_key_values[j][0]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][0] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the keys to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_keys.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_keys.shape[2]
-                if batch.keys_head_dim_last:
-                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
-                        past_keys[:, :, -past_seq_len:, :]
-                    )
-                else:
-                    # BLOOM case
-                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
-                        past_keys[:, :, :, -past_seq_len:]
-                    )
-                del past_keys
-
-                start_index = end_index
-
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_values_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_values_shape = padded_past_values_shape
-
-            padded_past_values = first_past_kvs[j][1].new_zeros(
-                _padded_past_values_shape
-            )
-            start_index = 0
-            for batch in batches:
-                past_values = batch.past_key_values[j][1]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][1] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the past values to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_values.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_values.shape[2]
-                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
-                    past_values[:, :, -past_seq_len:, :]
-                )
-                del past_values
-
-                # Update values
-                start_index = end_index
-
-            past_key_values.append([padded_past_keys, padded_past_values])
+            if any(
+                not isinstance(batch.past_key_values[j][0], torch.Tensor)
+                for batch in batches
+            ):
+                # XXX: Special handling for cross attention for mllama
+                padded_past_keys = [
+                    k for batch in batches for k in batch.past_key_values[j][0]
+                ]
+                padded_past_values = [
+                    k for batch in batches for k in batch.past_key_values[j][1]
+                ]
+                past_key_values.append([padded_past_keys, padded_past_values])
+            else:
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_keys_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_keys_shape = padded_past_keys_shape
+
+                padded_past_keys = first_past_kvs[j][0].new_zeros(
+                    _padded_past_keys_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_keys = batch.past_key_values[j][0]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][0] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the keys to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_keys.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_keys.shape[2]
+                    if batch.keys_head_dim_last:
+                        padded_past_keys[
+                            start_index:end_index, :, -past_seq_len:, :
+                        ] = past_keys[:, :, -past_seq_len:, :]
+                    else:
+                        # BLOOM case
+                        padded_past_keys[
+                            start_index:end_index, :, :, -past_seq_len:
+                        ] = past_keys[:, :, :, -past_seq_len:]
+                    del past_keys
+
+                    start_index = end_index
+
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_values_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_values_shape = padded_past_values_shape
+
+                padded_past_values = first_past_kvs[j][1].new_zeros(
+                    _padded_past_values_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_values = batch.past_key_values[j][1]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][1] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past values to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_values.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_values.shape[2]
+                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_values[:, :, -past_seq_len:, :]
+                    )
+                    del past_values
+
+                    # Update values
+                    start_index = end_index
+
+                past_key_values.append([padded_past_keys, padded_past_values])
 
         return cls(
             batch_id=batches[0].batch_id,