diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py index b59b623a..90fcb812 100644 --- a/server/text_generation_server/models/custom_modeling/mllama.py +++ b/server/text_generation_server/models/custom_modeling/mllama.py @@ -758,8 +758,8 @@ class MllamaTextCrossAttention(nn.Module): elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value[self.layer_idx][0], + past_key_value[self.layer_idx][1], ) else: raise ValueError( @@ -850,6 +850,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): self.cross_attn_mlp_gate = torch.nn.Parameter( weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False ) + self.layer_idx = layer_idx def forward( self, @@ -862,24 +863,75 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + if past_key_value is not None: + is_mixed = False + if cross_attention_states is None: + out_hidden_states = hidden_states[:] + indices = [] + for i, k in enumerate(past_key_value[self.layer_idx][0]): + if isinstance(k, torch.Tensor): + indices.append(i) + from loguru import logger - hidden_states, attn_weights, past_key_value = self.cross_attn( - hidden_states=hidden_states, - attention_mask=cross_attention_mask, - cross_attention_states=cross_attention_states, - past_key_value=past_key_value, - cache_position=cache_position, - ) - hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + logger.info(f"Indices {indices}") + if len(indices) == 0: + return hidden_states + is_mixed = True + if len(indices) == hidden_states.shape[0]: + is_mixed = False - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if full_text_row_masked_out_mask is not None: - hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore - hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + if is_mixed: + hidden_states = hidden_states[indices] + # Dirty hack + _past_key_value = [None] * len(past_key_value) + _past_key_value[self.layer_idx] = ( + torch.stack( + [ + k + for i, k in enumerate(past_key_value[self.layer_idx][0]) + if i in indices + ], + dim=0, + ), + torch.stack( + [ + k + for i, k in enumerate(past_key_value[self.layer_idx][1]) + if i in indices + ], + dim=0, + ), + ) + logger.info(f"Hidden states {hidden_states.shape}") + logger.info(f"k {_past_key_value[self.layer_idx][0].shape}") + logger.info(f"v {_past_key_value[self.layer_idx][1].shape}") + past_key_value = _past_key_value + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + past_key_value=past_key_value, + cache_position=cache_position, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if full_text_row_masked_out_mask is not None: + hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore + hidden_states = 
residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + if is_mixed: + out_hidden_states[indices] = hidden_states + hidden_states = out_hidden_states + from loguru import logger + + logger.info(f"After Hidden states {hidden_states.shape}") return hidden_states @@ -1243,18 +1295,18 @@ class MllamaTextModel(nn.Module): # decoder layers for idx, decoder_layer in enumerate(self.layers): - if ( - idx in self.cross_attention_layers - and cross_attention_states is None - and ( - past_key_values is None - or ( - past_key_values is not None - and past_key_values.get_seq_length(idx) == 0 - ) - ) - ): - continue + # if ( + # idx in self.cross_attention_layers + # and cross_attention_states is None + # and ( + # past_key_values is None + # or ( + # past_key_values is not None + # and any(past_key_values.get_seq_length(idx) == 0 + # ) + # ) + # ): + # continue layer_outputs = decoder_layer( hidden_states, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 09750af6..94e03efd 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -360,7 +360,15 @@ class IdeficsCausalLMBatch(Batch): past_kv_length = max_input_length - 1 for layer in self.past_key_values: past_keys, past_values = layer - if len(past_keys.shape) == 3: + if not isinstance(past_keys, torch.Tensor): + past_keys = [k for i, k in enumerate(past_keys) if i in keep_indices] + past_values = [ + k for i, k in enumerate(past_values) if i in keep_indices + ] + layer[0] = past_keys + layer[1] = past_values + continue + elif len(past_keys.shape) == 3: # Force past to be of dim [self_size, num_heads, ...] for easy indexing past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:]) past_values = past_values.view(len(self), -1, *past_values.shape[-2:]) @@ -530,7 +538,14 @@ class IdeficsCausalLMBatch(Batch): # And ensure that we can update tensors in-place if isinstance(batch.past_key_values[0], tuple): batch.past_key_values = [ - [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] + [ + ( + t.view(len(batch), -1, *t.shape[-2:]) + if isinstance(t, torch.Tensor) + else t + ) + for t in layer + ] for layer in batch.past_key_values ] elif len(batch.past_key_values[0][0].shape) == 3: @@ -569,83 +584,98 @@ class IdeficsCausalLMBatch(Batch): # Iterate over attention layers # Concatenate past key values layer by layer to allow incremental garbage collection for j in range(len(first_past_kvs)): - _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape - if seqlen > max_input_length: - # XXX: This is probably a cross attention key value - # If not this is ok - _padded_past_keys_shape = ( - total_batch_size, - _num_heads, - seqlen, - _head_dim, - ) + if any( + not isinstance(batch.past_key_values[j][0], torch.Tensor) + for batch in batches + ): + # XXX: Special handling for cross attention for mllama + padded_past_keys = [ + k for batch in batches for k in batch.past_key_values[j][0] + ] + padded_past_values = [ + k for batch in batches for k in batch.past_key_values[j][1] + ] + past_key_values.append([padded_past_keys, padded_past_values]) else: - _padded_past_keys_shape = padded_past_keys_shape - - padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape) - start_index = 0 - for batch in batches: - past_keys = batch.past_key_values[j][0] - # Clear reference to the original tensor - batch.past_key_values[j][0] = None - - # Slicing end index for this batch - 
end_index = start_index + len(batch) - # We slice the keys to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if past_keys.shape[2] > past_seq_len: - # XXX: This is a cross attention kv in mllama - past_seq_len = past_keys.shape[2] - if batch.keys_head_dim_last: - padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( - past_keys[:, :, -past_seq_len:, :] + _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape + if seqlen > max_input_length: + # XXX: This is probably a cross attention key value + # If not this is ok + _padded_past_keys_shape = ( + total_batch_size, + _num_heads, + seqlen, + _head_dim, ) else: - # BLOOM case - padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( - past_keys[:, :, :, -past_seq_len:] + _padded_past_keys_shape = padded_past_keys_shape + + padded_past_keys = first_past_kvs[j][0].new_zeros( + _padded_past_keys_shape + ) + start_index = 0 + for batch in batches: + past_keys = batch.past_key_values[j][0] + # Clear reference to the original tensor + batch.past_key_values[j][0] = None + + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the keys to remove the padding from previous batches + past_seq_len = batch.max_input_length - 1 + if past_keys.shape[2] > past_seq_len: + # XXX: This is a cross attention kv in mllama + past_seq_len = past_keys.shape[2] + if batch.keys_head_dim_last: + padded_past_keys[ + start_index:end_index, :, -past_seq_len:, : + ] = past_keys[:, :, -past_seq_len:, :] + else: + # BLOOM case + padded_past_keys[ + start_index:end_index, :, :, -past_seq_len: + ] = past_keys[:, :, :, -past_seq_len:] + del past_keys + + start_index = end_index + + _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape + if seqlen > max_input_length: + # XXX: This is probably a cross attention key value + # If not this is ok + _padded_past_values_shape = ( + total_batch_size, + _num_heads, + seqlen, + _head_dim, ) - del past_keys - - start_index = end_index - - _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape - if seqlen > max_input_length: - # XXX: This is probably a cross attention key value - # If not this is ok - _padded_past_values_shape = ( - total_batch_size, - _num_heads, - seqlen, - _head_dim, + else: + _padded_past_values_shape = padded_past_values_shape + padded_past_values = first_past_kvs[j][1].new_zeros( + _padded_past_values_shape ) - else: - _padded_past_values_shape = padded_past_values_shape - padded_past_values = first_past_kvs[j][1].new_zeros( - _padded_past_values_shape - ) - start_index = 0 - for batch in batches: - past_values = batch.past_key_values[j][1] - # Clear reference to the original tensor - batch.past_key_values[j][1] = None + start_index = 0 + for batch in batches: + past_values = batch.past_key_values[j][1] + # Clear reference to the original tensor + batch.past_key_values[j][1] = None - # Slicing end index for this batch - end_index = start_index + len(batch) - # We slice the past values to remove the padding from previous batches - past_seq_len = batch.max_input_length - 1 - if past_values.shape[2] > past_seq_len: - # XXX: This is a cross attention kv in mllama - past_seq_len = past_values.shape[2] - padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( - past_values[:, :, -past_seq_len:, :] - ) - del past_values + # Slicing end index for this batch + end_index = start_index + len(batch) + # We slice the past values to remove the padding from previous batches + past_seq_len = 
batch.max_input_length - 1 + if past_values.shape[2] > past_seq_len: + # XXX: This is a cross attention kv in mllama + past_seq_len = past_values.shape[2] + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( + past_values[:, :, -past_seq_len:, :] + ) + del past_values - # Update values - start_index = end_index + # Update values + start_index = end_index - past_key_values.append([padded_past_keys, padded_past_values]) + past_key_values.append([padded_past_keys, padded_past_values]) return cls( batch_id=batches[0].batch_id,
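
The MllamaCrossAttentionDecoderLayer hunk above gathers only the sequences that actually have cached cross-attention keys (the "mixed batch" case), runs cross-attention on that subset, and scatters the result back into the full batch. Below is a minimal, self-contained sketch of that select/compute/scatter pattern, assuming the same list-of-per-sequence-tensors cache layout as the hunk; `cross_attend_mixed_batch` and `run_cross_attention` are hypothetical stand-ins, not the repository's API.

# Minimal sketch of the mixed-batch pattern in the hunk above: per-sequence
# cross-attention KV entries live in a list, with None for text-only sequences.
# Only rows with a cached KV go through cross-attention; their outputs are
# scattered back into the full batch. Names here are hypothetical.
import torch
from typing import Callable, List, Optional


def cross_attend_mixed_batch(
    hidden_states: torch.Tensor,                # [batch, seq, hidden]
    per_seq_kv: List[Optional[torch.Tensor]],   # one cached KV tensor per sequence, or None
    run_cross_attention: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # Rows that actually have cached cross-attention KV (i.e. saw an image).
    indices = [i for i, k in enumerate(per_seq_kv) if isinstance(k, torch.Tensor)]
    if len(indices) == 0:
        # No sequence in the batch has image KV: skip the layer entirely.
        return hidden_states
    stacked_kv = torch.stack([per_seq_kv[i] for i in indices], dim=0)
    if len(indices) == hidden_states.shape[0]:
        # Every sequence has image KV: no gather/scatter needed.
        return run_cross_attention(hidden_states, stacked_kv)
    # Mixed batch: gather the image-bearing rows, attend, scatter back.
    out = hidden_states.clone()
    out[indices] = run_cross_attention(hidden_states[indices], stacked_kv)
    return out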
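
The idefics_causal_lm.py hunks treat a cross-attention layer's cache as a [keys, values] pair of Python lists with one entry per sequence (None for text-only sequences), so filtering reduces to list selection and cross-batch concatenation to list flattening, with no padded new_zeros copy. A rough sketch under that assumption, with hypothetical helper names:

# Rough sketch of the list-based KV handling assumed by the idefics_causal_lm.py
# hunks. Layer layout and helper names are illustrative, not the repository's API.
from typing import List, Optional, Sequence
import torch

Entry = Optional[torch.Tensor]   # per-sequence cached K or V, or None
Layer = List[List[Entry]]        # [keys, values]


def filter_layer(layer: Layer, keep_indices: Sequence[int]) -> Layer:
    # Keep only the sequences that survive the batch filter, updating in place
    # like the filter hunk does.
    keys, values = layer
    layer[0] = [k for i, k in enumerate(keys) if i in keep_indices]
    layer[1] = [v for i, v in enumerate(values) if i in keep_indices]
    return layer


def concatenate_layers(layers_per_batch: List[Layer]) -> Layer:
    # Concatenating batches is just flattening the per-sequence lists, mirroring
    # the special case added to concatenate() for this representation.
    keys = [k for layer in layers_per_batch for k in layer[0]]
    values = [v for layer in layers_per_batch for v in layer[1]]
    return [keys, values]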