Mllama
This commit is contained in:
parent
44cdb00bbb
commit
31a4c24f74
@@ -758,8 +758,8 @@ class MllamaTextCrossAttention(nn.Module):
         elif cache_position[0] != 0:
             key_states, value_states = (
-                past_key_value.key_cache[self.layer_idx],
-                past_key_value.value_cache[self.layer_idx],
+                past_key_value[self.layer_idx][0],
+                past_key_value[self.layer_idx][1],
             )
         else:
             raise ValueError(
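The change above drops the transformers-style Cache accessors (past_key_value.key_cache[layer_idx]) in favour of plain per-layer tuple indexing. A minimal sketch of the assumed layout, with illustrative names and shapes rather than the repository's actual structures:

import torch

num_layers, bsz, num_heads, kv_len, head_dim = 4, 2, 8, 6, 64
# Assumed layout: one (key_states, value_states) tuple per decoder layer,
# instead of a Cache object with .key_cache / .value_cache attributes.
past_key_value = [
    (
        torch.zeros(bsz, num_heads, kv_len, head_dim),  # cached keys for layer i
        torch.zeros(bsz, num_heads, kv_len, head_dim),  # cached values for layer i
    )
    for _ in range(num_layers)
]

layer_idx = 2
# Equivalent of the new access pattern in the hunk above.
key_states, value_states = (
    past_key_value[layer_idx][0],
    past_key_value[layer_idx][1],
)
print(key_states.shape, value_states.shape)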
@@ -850,6 +850,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         self.cross_attn_mlp_gate = torch.nn.Parameter(
             weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
         )
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -862,24 +863,75 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if past_key_value is not None:
+            is_mixed = False
+            if cross_attention_states is None:
+                out_hidden_states = hidden_states[:]
+                indices = []
+                for i, k in enumerate(past_key_value[self.layer_idx][0]):
+                    if isinstance(k, torch.Tensor):
+                        indices.append(i)
+                from loguru import logger
 
-        hidden_states, attn_weights, past_key_value = self.cross_attn(
-            hidden_states=hidden_states,
-            attention_mask=cross_attention_mask,
-            cross_attention_states=cross_attention_states,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-        )
-        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+                logger.info(f"Indices {indices}")
+                if len(indices) == 0:
+                    return hidden_states
+                is_mixed = True
+                if len(indices) == hidden_states.shape[0]:
+                    is_mixed = False
 
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        if full_text_row_masked_out_mask is not None:
-            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
-        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+                if is_mixed:
+                    hidden_states = hidden_states[indices]
+                    # Dirty hack
+                    _past_key_value = [None] * len(past_key_value)
+                    _past_key_value[self.layer_idx] = (
+                        torch.stack(
+                            [
+                                k
+                                for i, k in enumerate(past_key_value[self.layer_idx][0])
+                                if i in indices
+                            ],
+                            dim=0,
+                        ),
+                        torch.stack(
+                            [
+                                k
+                                for i, k in enumerate(past_key_value[self.layer_idx][1])
+                                if i in indices
+                            ],
+                            dim=0,
+                        ),
+                    )
+                    logger.info(f"Hidden states {hidden_states.shape}")
+                    logger.info(f"k {_past_key_value[self.layer_idx][0].shape}")
+                    logger.info(f"v {_past_key_value[self.layer_idx][1].shape}")
+                    past_key_value = _past_key_value
+
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, attn_weights, past_key_value = self.cross_attn(
+            hidden_states=hidden_states,
+            attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
+        )
+        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if full_text_row_masked_out_mask is not None:
+            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
+        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+        if is_mixed:
+            out_hidden_states[indices] = hidden_states
+            hidden_states = out_hidden_states
+            from loguru import logger
+
+            logger.info(f"After Hidden states {hidden_states.shape}")
 
         return hidden_states
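The rewritten forward handles a batch that mixes requests with and without cached cross-attention KV: it collects the indices whose cached entry is a real tensor, runs the layer only on those rows, and scatters the result back into the full batch. A hedged, self-contained sketch of that pattern; the names mixed_cross_attention and layer_fn are illustrative and not from the repository:

import torch

def mixed_cross_attention(hidden_states, layer_kv, layer_fn):
    # layer_kv: one entry per request, a torch.Tensor when the request has
    # cross-attention KV cached, otherwise None (text-only request).
    indices = [i for i, k in enumerate(layer_kv) if isinstance(k, torch.Tensor)]
    if len(indices) == 0:
        # No request needs cross attention: skip the layer entirely.
        return hidden_states
    if len(indices) == len(layer_kv):
        # Every request has cached KV: run the layer on the full batch.
        return layer_fn(hidden_states)
    # Mixed case: run only on the rows that have KV, scatter results back.
    out = hidden_states.clone()
    out[indices] = layer_fn(hidden_states[indices])
    return out

hidden = torch.randn(3, 5, 16)
kv = [torch.zeros(1), None, torch.zeros(1)]  # requests 0 and 2 carry image KV
print(mixed_cross_attention(hidden, kv, lambda h: h * 2).shape)  # torch.Size([3, 5, 16])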
@@ -1243,18 +1295,18 @@ class MllamaTextModel(nn.Module):
         # decoder layers
 
         for idx, decoder_layer in enumerate(self.layers):
-            if (
-                idx in self.cross_attention_layers
-                and cross_attention_states is None
-                and (
-                    past_key_values is None
-                    or (
-                        past_key_values is not None
-                        and past_key_values.get_seq_length(idx) == 0
-                    )
-                )
-            ):
-                continue
+            # if (
+            #     idx in self.cross_attention_layers
+            #     and cross_attention_states is None
+            #     and (
+            #         past_key_values is None
+            #         or (
+            #             past_key_values is not None
+            #             and any(past_key_values.get_seq_length(idx) == 0
+            #         )
+            #     )
+            # ):
+            #     continue
 
             layer_outputs = decoder_layer(
                 hidden_states,
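The hunk above comments out the model-level skip of cross-attention layers (the decoder layer now early-returns on its own when nothing is cached). A small sketch of the condition that used to gate the skip; should_skip_cross_attention_layer and cached_seq_length are illustrative stand-ins, not repository code:

from typing import Optional

def should_skip_cross_attention_layer(
    idx: int,
    cross_attention_layers: set,
    has_cross_attention_states: bool,
    cached_seq_length: Optional[int],  # None when there is no past_key_values
) -> bool:
    if idx not in cross_attention_layers:
        return False
    if has_cross_attention_states:
        return False
    # No fresh image states: skip unless a previous pass left KV in the cache.
    return cached_seq_length is None or cached_seq_length == 0

print(should_skip_cross_attention_layer(3, {3, 8}, False, 0))   # True
print(should_skip_cross_attention_layer(3, {3, 8}, False, 17))  # False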
@@ -360,7 +360,15 @@ class IdeficsCausalLMBatch(Batch):
         past_kv_length = max_input_length - 1
         for layer in self.past_key_values:
             past_keys, past_values = layer
-            if len(past_keys.shape) == 3:
+            if not isinstance(past_keys, torch.Tensor):
+                past_keys = [k for i, k in enumerate(past_keys) if i in keep_indices]
+                past_values = [
+                    k for i, k in enumerate(past_values) if i in keep_indices
+                ]
+                layer[0] = past_keys
+                layer[1] = past_values
+                continue
+            elif len(past_keys.shape) == 3:
                 # Force past to be of dim [self_size, num_heads, ...] for easy indexing
                 past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
                 past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
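The new filter branch above covers layers whose past keys and values are stored as Python lists (one entry per request, possibly None for text-only requests) rather than tensors, keeping only the entries whose request index survives filtering. A minimal sketch under those assumptions, with made-up shapes:

import torch

keep_indices = [0, 2]
layer = [
    [torch.zeros(4, 10, 64), None, torch.zeros(4, 12, 64)],  # per-request keys
    [torch.zeros(4, 10, 64), None, torch.zeros(4, 12, 64)],  # per-request values
]

layer[0] = [k for i, k in enumerate(layer[0]) if i in keep_indices]
layer[1] = [v for i, v in enumerate(layer[1]) if i in keep_indices]
print(len(layer[0]), len(layer[1]))  # 2 2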
@@ -530,7 +538,14 @@ class IdeficsCausalLMBatch(Batch):
             # And ensure that we can update tensors in-place
             if isinstance(batch.past_key_values[0], tuple):
                 batch.past_key_values = [
-                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
+                    [
+                        (
+                            t.view(len(batch), -1, *t.shape[-2:])
+                            if isinstance(t, torch.Tensor)
+                            else t
+                        )
+                        for t in layer
+                    ]
                     for layer in batch.past_key_values
                 ]
             elif len(batch.past_key_values[0][0].shape) == 3:
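The guard above reshapes only real tensors and leaves list entries (the mllama cross-attention layers) untouched. A tiny illustrative sketch, with made-up shapes:

import torch

batch_len = 2
layer = [torch.zeros(2, 4, 3, 8), [None, torch.zeros(4, 5, 8)]]
layer = [
    t.view(batch_len, -1, *t.shape[-2:]) if isinstance(t, torch.Tensor) else t
    for t in layer
]
print(layer[0].shape, type(layer[1]))  # torch.Size([2, 4, 3, 8]) <class 'list'>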
@@ -569,83 +584,98 @@ class IdeficsCausalLMBatch(Batch):
         # Iterate over attention layers
         # Concatenate past key values layer by layer to allow incremental garbage collection
         for j in range(len(first_past_kvs)):
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_keys_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_keys_shape = padded_past_keys_shape
-
-            padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape)
-            start_index = 0
-            for batch in batches:
-                past_keys = batch.past_key_values[j][0]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][0] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the keys to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_keys.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_keys.shape[2]
-                if batch.keys_head_dim_last:
-                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
-                        past_keys[:, :, -past_seq_len:, :]
-                    )
-                else:
-                    # BLOOM case
-                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
-                        past_keys[:, :, :, -past_seq_len:]
-                    )
-                del past_keys
-
-                start_index = end_index
-
-            _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
-            if seqlen > max_input_length:
-                # XXX: This is probably a cross attention key value
-                # If not this is ok
-                _padded_past_values_shape = (
-                    total_batch_size,
-                    _num_heads,
-                    seqlen,
-                    _head_dim,
-                )
-            else:
-                _padded_past_values_shape = padded_past_values_shape
-            padded_past_values = first_past_kvs[j][1].new_zeros(
-                _padded_past_values_shape
-            )
-            start_index = 0
-            for batch in batches:
-                past_values = batch.past_key_values[j][1]
-                # Clear reference to the original tensor
-                batch.past_key_values[j][1] = None
-
-                # Slicing end index for this batch
-                end_index = start_index + len(batch)
-                # We slice the past values to remove the padding from previous batches
-                past_seq_len = batch.max_input_length - 1
-                if past_values.shape[2] > past_seq_len:
-                    # XXX: This is a cross attention kv in mllama
-                    past_seq_len = past_values.shape[2]
-                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
-                    past_values[:, :, -past_seq_len:, :]
-                )
-                del past_values
-
-                # Update values
-                start_index = end_index
-
-            past_key_values.append([padded_past_keys, padded_past_values])
+            if any(
+                not isinstance(batch.past_key_values[j][0], torch.Tensor)
+                for batch in batches
+            ):
+                # XXX: Special handling for cross attention for mllama
+                padded_past_keys = [
+                    k for batch in batches for k in batch.past_key_values[j][0]
+                ]
+                padded_past_values = [
+                    k for batch in batches for k in batch.past_key_values[j][1]
+                ]
+                past_key_values.append([padded_past_keys, padded_past_values])
+            else:
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_keys_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_keys_shape = padded_past_keys_shape
+
+                padded_past_keys = first_past_kvs[j][0].new_zeros(
+                    _padded_past_keys_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_keys = batch.past_key_values[j][0]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][0] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the keys to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_keys.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_keys.shape[2]
+                    if batch.keys_head_dim_last:
+                        padded_past_keys[
+                            start_index:end_index, :, -past_seq_len:, :
+                        ] = past_keys[:, :, -past_seq_len:, :]
+                    else:
+                        # BLOOM case
+                        padded_past_keys[
+                            start_index:end_index, :, :, -past_seq_len:
+                        ] = past_keys[:, :, :, -past_seq_len:]
+                    del past_keys
+
+                    start_index = end_index
+
+                _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
+                if seqlen > max_input_length:
+                    # XXX: This is probably a cross attention key value
+                    # If not this is ok
+                    _padded_past_values_shape = (
+                        total_batch_size,
+                        _num_heads,
+                        seqlen,
+                        _head_dim,
+                    )
+                else:
+                    _padded_past_values_shape = padded_past_values_shape
+                padded_past_values = first_past_kvs[j][1].new_zeros(
+                    _padded_past_values_shape
+                )
+                start_index = 0
+                for batch in batches:
+                    past_values = batch.past_key_values[j][1]
+                    # Clear reference to the original tensor
+                    batch.past_key_values[j][1] = None
+
+                    # Slicing end index for this batch
+                    end_index = start_index + len(batch)
+                    # We slice the past values to remove the padding from previous batches
+                    past_seq_len = batch.max_input_length - 1
+                    if past_values.shape[2] > past_seq_len:
+                        # XXX: This is a cross attention kv in mllama
+                        past_seq_len = past_values.shape[2]
+                    padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_values[:, :, -past_seq_len:, :]
+                    )
+                    del past_values
+
+                    # Update values
+                    start_index = end_index
+
+                past_key_values.append([padded_past_keys, padded_past_values])
 
         return cls(
             batch_id=batches[0].batch_id,
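The concatenate hunk above splits into two paths: if any batch stores a layer's KV as a Python list (the mllama cross-attention case), the per-request entries are simply chained; otherwise each batch's tensors are right-aligned into a zero-padded buffer. A hedged sketch of both paths; concat_layer_kv and the shapes are illustrative, not the repository's API:

import torch

def concat_layer_kv(batches_kv, total_batch_size, num_heads, max_seq_len, head_dim):
    # batches_kv: one (keys, values) pair per batch for a single layer.
    if any(not isinstance(keys, torch.Tensor) for keys, _ in batches_kv):
        # List case: keep one entry per request, no padding needed.
        keys = [k for batch_keys, _ in batches_kv for k in batch_keys]
        values = [v for _, batch_values in batches_kv for v in batch_values]
        return [keys, values]
    # Tensor case: right-align each batch's KV inside a zero-padded buffer.
    padded_keys = batches_kv[0][0].new_zeros(
        (total_batch_size, num_heads, max_seq_len, head_dim)
    )
    padded_values = torch.zeros_like(padded_keys)
    start = 0
    for keys, values in batches_kv:
        end = start + keys.shape[0]
        seq_len = keys.shape[2]
        padded_keys[start:end, :, -seq_len:, :] = keys
        padded_values[start:end, :, -seq_len:, :] = values
        start = end
    return [padded_keys, padded_values]

a = (torch.ones(2, 4, 3, 8), torch.ones(2, 4, 3, 8))
b = (torch.ones(1, 4, 5, 8), torch.ones(1, 4, 5, 8))
out = concat_layer_kv([a, b], total_batch_size=3, num_heads=4, max_seq_len=5, head_dim=8)
print(out[0].shape, out[1].shape)  # torch.Size([3, 4, 5, 8]) torch.Size([3, 4, 5, 8])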