diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index f7f823fc..1085075e 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -28,11 +28,17 @@ class ToolCall(BaseModel): function: dict +class Chunk(BaseModel): + type: str + text: Optional[str] = None + image_url: Any = None + + class Message(BaseModel): # Role of the message sender role: str # Content of the message - content: Optional[str] = None + content: Optional[Union[str, List[Chunk]]] = None # Optional name of the message sender name: Optional[str] = None # Tool calls associated with the chat completion diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json new file mode 100644 index 00000000..eb8b82b2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json @@ -0,0 +1,106 @@ +[ + { + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1727097740, + "id": "", + "model": "s0409/model-3", + "object": "chat.completion", + "system_fingerprint": "2.2.1-dev0-native", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 24, + "total_tokens": 44 + } + }, + { + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1727097740, + "id": "", + "model": "s0409/model-3", + "object": "chat.completion", + "system_fingerprint": "2.2.1-dev0-native", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 24, + "total_tokens": 44 + } + }, + { + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1727097740, + "id": "", + "model": "s0409/model-3", + "object": "chat.completion", + "system_fingerprint": "2.2.1-dev0-native", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 24, + "total_tokens": 44 + } + }, + { + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1727097740, + "id": "", + "model": "s0409/model-3", + "object": "chat.completion", + "system_fingerprint": "2.2.1-dev0-native", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 24, + "total_tokens": 44 + } + } +] diff --git a/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json new file mode 100644 index 00000000..4000691f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "In a small village, a rooster named Cluck 
Norris ruled the coop with an iron beak",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1727090615,
+  "id": "",
+  "model": "s0409/model-3",
+  "object": "chat.completion",
+  "system_fingerprint": "2.2.1-dev0-native",
+  "usage": {
+    "completion_tokens": 20,
+    "prompt_tokens": 24,
+    "total_tokens": 44
+  }
+}
diff --git a/integration-tests/models/test_mllama.py b/integration-tests/models/test_mllama.py
new file mode 100644
index 00000000..b77f09e0
--- /dev/null
+++ b/integration-tests/models/test_mllama.py
@@ -0,0 +1,108 @@
+import pytest
+import base64
+import asyncio
+
+
+@pytest.fixture(scope="module")
+def mllama_handle(launcher):
+    with launcher("s0409/model-3", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def mllama(mllama_handle):
+    await mllama_handle.health(300)
+    return mllama_handle.client
+
+
+# TODO fix the server parser to count inline image tokens correctly
+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
+@pytest.mark.asyncio
+async def test_mllama_simpl(mllama, response_snapshot):
+    # chicken = get_chicken()
+    response = await mllama.chat(
+        max_tokens=20,
+        temperature=0.0,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Can you tell me a very short story based on the image?",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
+                        },
+                    },
+                ],
+            },
+        ],
+    )
+
+    assert response.usage == {
+        "completion_tokens": 20,
+        "prompt_tokens": 24,
+        "total_tokens": 44,
+    }
+    assert (
+        response.choices[0].message.content
+        == "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.release
+@pytest.mark.asyncio
+async def test_mllama_load(mllama, generate_load, response_snapshot):
+    futures = [
+        mllama.chat(
+            max_tokens=20,
+            temperature=0.0,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Can you tell me a very short story based on the image?",
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
+                            },
+                        },
+                    ],
+                },
+            ],
+        )
+        for i in range(4)
+    ]
+    responses = await asyncio.gather(*futures)
+
+    generated_texts = [response.choices[0].message.content for response in responses]
+
+    assert (
+        generated_texts[0]
+        == "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
+    )
+    assert len(generated_texts) == 4
+    assert all(
+        [text == generated_texts[0] for text in generated_texts]
+    )
+
+    assert responses == response_snapshot
diff --git a/router/src/config.rs b/router/src/config.rs
index 5d0be9c8..f78c3772 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -146,6 +146,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
+    Mllama,
     Idefics2(Idefics2),
     Ssm,
     GptBigcode,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 92491d88..85b4220b 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -567,6 +567,7 @@ fn image_tokens(
     use HubPreprocessorConfig::*;
     match config {
         Idefics => "<image>".to_string(),
+        Mllama => "<|image|>".to_string(),
         Idefics2(config) => {
             const FAKE: &str = "<fake_token_around_image>";
             const IMAGE: &str = "<image>";
@@ -618,7 +619,7 @@ fn prepare_input(
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index ee06b5e7..8e642829 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch Mllama model."""
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch
 import torch.utils.checkpoint
@@ -22,12 +22,18 @@ from torch import nn
 import math
 
 from transformers.activations import ACT2FN
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_outputs import (
+    CausalLMOutputWithPast,
+    BaseModelOutputWithPast,
+)
+from transformers.cache_utils import (
+    StaticCache,
+    DynamicCache,
+)
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 import torch.nn.functional as F
 
-from text_generation_server.layers.layernorm import (
-    FastRMSNorm,
-)
-from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -37,6 +43,185 @@ from text_generation_server.layers import (
 )
 
 
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape( + batch_size, max_num_tiles * target_length, 1 + ) + attention_mask = ( + attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave( + num_vision_tokens, dim=3 + ) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value) + .any(dim=-1) + .type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision class MllamaVisionMLP(nn.Module): def __init__(self, *, prefix, config, weights): @@ -62,8 +247,8 @@ class MllamaVisionSdpaAttention(nn.Module): super().__init__() self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads + self.num_heads = config.attention_heads // weights.process_group.size() self.qkv_proj = TensorParallelColumnLinear.load_multi( config, @@ -87,11 +272,11 @@ class MllamaVisionSdpaAttention(nn.Module): qkv = self.qkv_proj(hidden_state) query, key, value = qkv.split( [ - self.head_size * self.num_heads, - self.head_size * self.num_heads, - self.head_size * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, ], - dim=1, + dim=2, ) batch_size, q_seq_len, _ = query.shape @@ -145,7 +330,7 @@ class MllamaVisionEncoderLayer(nn.Module): weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False ) self.gate_ffn = nn.Parameter( - weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False + weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False ) def forward( @@ -156,9 +341,7 @@ class 
MllamaVisionEncoderLayer(nn.Module): # Self Attention residual = hidden_state hidden_state = self.input_layernorm(hidden_state) - hidden_state, attn_weights = self.self_attn( - hidden_state, attention_mask=attention_mask - ) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() hidden_state = residual + gate_attn * hidden_state @@ -190,15 +373,17 @@ class MllamaVisionEncoder(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ): + encoder_states = [hidden_states] for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, attention_mask, ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs + encoder_states.append(hidden_states) - return hidden_states + return hidden_states, encoder_states class MllamaPrecomputedAspectRatioEmbedding(nn.Module): @@ -237,12 +422,15 @@ class MllamaPrecomputedPositionEmbedding(nn.Module): self.hidden_size = config.hidden_size self.scale = config.hidden_size**-0.5 - self.gate = nn.Parameter(torch.zeros(1)) + self.gate = nn.Parameter( + weights.get_tensor(f"{prefix}.gate"), requires_grad=False + ) # position embedding - self.embedding = nn.Parameter( + embedding = nn.Parameter( weights.get_tensor(f"{prefix}.embedding"), requires_grad=False ) + self.gated_position_embedding = (1 - self.gate.tanh()) * embedding self.tile_embedding = TensorParallelEmbedding( prefix=f"{prefix}.tile_embedding", weights=weights ) @@ -251,8 +439,7 @@ class MllamaPrecomputedPositionEmbedding(nn.Module): self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor ) -> torch.Tensor: # position embeddings - gated_position_embedding = (1 - self.gate.tanh()) * self.embedding - hidden_state = hidden_state + gated_position_embedding.view( + hidden_state = hidden_state + self.gated_position_embedding.view( 1, 1, self.num_patches, self.hidden_size ) @@ -280,6 +467,7 @@ class MllamaVisionModel(nn.Module): self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 self.scale = config.hidden_size**-0.5 + self.dtype = weights.dtype self.patch_embedding = nn.Conv2d( in_channels=config.in_channels, @@ -368,7 +556,7 @@ class MllamaVisionModel(nn.Module): ) # patch embedding - patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + patch_embeds = self.patch_embedding(pixel_values) hidden_state = patch_embeds.flatten(2).transpose(1, 2) # tile embeddings @@ -421,12 +609,10 @@ class MllamaVisionModel(nn.Module): ) hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) - output = self.transformer( + hidden_state, all_intermediate_hidden_states = self.transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=True, ) - hidden_state, all_intermediate_hidden_states = output[0], output[1] intermediate_hidden_states = [ hidden_state for idx, hidden_state in enumerate(all_intermediate_hidden_states) @@ -450,9 +636,9 @@ class MllamaVisionModel(nn.Module): num_tiles * (num_patches + num_padding_patches), dim, ) - hidden_state = self.global_transformer( + hidden_state, _ = self.global_transformer( hidden_state, attention_mask=attention_mask - )[0] + ) hidden_state = hidden_state.reshape( batch_size * num_concurrent_media, num_tiles, @@ -482,7 +668,7 @@ class MllamaVisionModel(nn.Module): class MllamaTextCrossAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, *, prefix, config, weights): + def __init__(self, *, prefix, 
config, weights, layer_idx): super().__init__() self.config = config self.num_heads = self.config.num_attention_heads @@ -491,11 +677,28 @@ class MllamaTextCrossAttention(nn.Module): self.hidden_size = config.hidden_size self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.layer_idx = layer_idx - self.qkv_proj = TensorParallelColumnLinear.load_multi( + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + + self.q_proj = TensorParallelColumnLinear.load( config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=False, + ) + self.k_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=False, + ) + self.v_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.v_proj", weights=weights, bias=False, ) @@ -506,10 +709,10 @@ class MllamaTextCrossAttention(nn.Module): bias=False, ) - self.q_norm = FastRMSNorm.load( + self.q_norm = MllamaTextRMSNorm.load( prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps ) - self.k_norm = FastRMSNorm.load( + self.k_norm = MllamaTextRMSNorm.load( prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps ) @@ -519,8 +722,6 @@ class MllamaTextCrossAttention(nn.Module): cross_attention_states: Optional[torch.Tensor] = None, past_key_value=None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = None, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -544,6 +745,7 @@ class MllamaTextCrossAttention(nn.Module): value_states = repeat_kv(value_states, self.num_key_value_groups) key_states = self.k_norm(key_states) + if past_key_value is not None: # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! 
@@ -553,6 +755,7 @@ class MllamaTextCrossAttention(nn.Module): self.layer_idx, {"cache_position": cache_position}, ) + elif cache_position[0] != 0: key_states, value_states = ( past_key_value.key_cache[self.layer_idx], @@ -582,9 +785,6 @@ class MllamaTextCrossAttention(nn.Module): attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value @@ -594,7 +794,9 @@ class MllamaTextMLP(nn.Module): super().__init__() self.config = config self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) self.gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], @@ -611,21 +813,28 @@ class MllamaTextMLP(nn.Module): self.act_fn = ACT2FN[config.hidden_activation] def forward(self, x): + shape = x.shape gate_up_states = self.gate_up_proj(x) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) + result = self.down_proj( + self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1] + ) + return result class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" - def __init__(self, *, prefix, config, weights) -> None: + def __init__(self, *, prefix, config, weights, layer_idx) -> None: super().__init__() self.cross_attn = MllamaTextCrossAttention( - prefix=f"{prefix}.cross_attn", config=config, weights=weights + prefix=f"{prefix}.cross_attn", + config=config, + weights=weights, + layer_idx=layer_idx, ) - self.input_layernorm = FastRMSNorm.load( + self.input_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps ) self.cross_attn_attn_gate = torch.nn.Parameter( @@ -633,7 +842,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ) self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.post_attention_layernorm = FastRMSNorm.load( + self.post_attention_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps, @@ -650,8 +859,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): attention_mask: torch.Tensor, full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], past_key_value=None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> torch.Tensor: @@ -663,7 +870,6 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, - output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states @@ -675,19 +881,11 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - if use_cache: - outputs += (past_key_value,) - - return outputs + return 
hidden_states class MllamaTextSelfAttention(nn.Module): - def __init__(self, *, prefix, config, weights): + def __init__(self, *, prefix, config, weights, layer_idx): super().__init__() self.config = config self.num_heads = config.num_attention_heads @@ -697,6 +895,12 @@ class MllamaTextSelfAttention(nn.Module): self.head_dim = config.hidden_size // self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + self.num_key_value_heads // weights.process_group.size() + ) + self.layer_idx = layer_idx + self.qkv_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], @@ -716,17 +920,21 @@ class MllamaTextSelfAttention(nn.Module): hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, past_key_value=None, cache_position=None, **kwargs, ): bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + qkv = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_key_value_heads, + self.head_dim * self.num_key_value_heads, + ], + dim=2, + ) query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim @@ -772,7 +980,8 @@ class MllamaTextSelfAttention(nn.Module): query_states, key_states, value_states, - attn_mask=causal_mask, + # TODO + # attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) @@ -784,21 +993,49 @@ class MllamaTextSelfAttention(nn.Module): return attn_output, None, past_key_value +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class MllamaTextRMSNorm(nn.Module): + def __init__(self, weight, eps): + super().__init__() + self.weight = weight + self.variance_epsilon = eps + + @classmethod + def load(cls, *, prefix, weights, eps): + weight = nn.Parameter( + weights.get_tensor(f"{prefix}.weight"), requires_grad=False + ) + return cls(weight=weight, eps=eps) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT class MllamaSelfAttentionDecoderLayer(nn.Module): - def __init__(self, *, prefix, config, weights): + def __init__(self, *, prefix, config, weights, layer_idx): super().__init__() self.hidden_size = config.hidden_size self.self_attn = MllamaTextSelfAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_idx=layer_idx, ) self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.input_layernorm = FastRMSNorm.load( + self.input_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps ) - self.post_attention_layernorm = 
FastRMSNorm.load( + self.post_attention_layernorm = MllamaTextRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps, @@ -810,8 +1047,6 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value=None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] @@ -820,28 +1055,6 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -852,8 +1065,6 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, @@ -866,15 +1077,82 @@ class MllamaSelfAttentionDecoderLayer(nn.Module): hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) + return hidden_states - if output_attentions: - outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) +class MllamaRotaryEmbedding(nn.Module): + def __init__( + self, + *, + config, + weights, + ): + super().__init__() + device = weights.device + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - return outputs + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + inv_freq.to(device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len + and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( + 1, 2 + ) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class MllamaTextModel(nn.Module): @@ -882,6 +1160,7 @@ class MllamaTextModel(nn.Module): super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.config = config self.embed_tokens = TensorParallelEmbedding( prefix=f"{prefix}.embed_tokens", weights=weights ) @@ -895,6 +1174,7 @@ class MllamaTextModel(nn.Module): prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, + layer_idx=layer_idx, ) ) else: @@ -903,24 +1183,20 @@ class MllamaTextModel(nn.Module): prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, + layer_idx=layer_idx, ) ) # TODO Should we use this slow norm ? # self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.norm = FastRMSNorm.load( + self.norm = MllamaTextRMSNorm.load( prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps, ) # TODO Anything specific ? head_size = config.hidden_size // config.num_attention_heads - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=head_size, - base=config.rope_theta, - device=weights.device, - ) + self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights) def forward( self, @@ -958,12 +1234,12 @@ class MllamaTextModel(nn.Module): if position_ids is None: position_ids = cache_position.unsqueeze(0) - # causal_mask = self._update_causal_mask( - # attention_mask, - # inputs_embeds, - # cache_position, - # past_key_values, - # ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + ) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -1000,82 +1276,81 @@ class MllamaTextModel(nn.Module): hidden_states = self.norm(hidden_states) - return hidden_states + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) - # def _update_causal_mask( - # self, - # attention_mask: torch.Tensor, - # input_tensor: torch.Tensor, - # cache_position: torch.Tensor, - # past_key_values, - # ): - # if self.config._attn_implementation == "flash_attention_2": - # if attention_mask is not None and 0.0 in attention_mask: - # return attention_mask - # return None + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None - # # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # # to infer the attention mask. - # past_seen_tokens = ( - # past_key_values.get_seq_length() if past_key_values is not None else 0 - # ) - # using_static_cache = isinstance(past_key_values, StaticCache) + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + using_static_cache = isinstance(past_key_values, StaticCache) - # # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - # # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line - # # self.config._attn_implementation == "sdpa" and - # if ( - # self.config._attn_implementation == "sdpa" - # and not using_static_cache - # and not output_attentions - # ): - # if AttentionMaskConverter._ignore_causal_mask_sdpa( - # attention_mask, - # inputs_embeds=input_tensor, - # past_key_values_length=past_seen_tokens, - # is_training=self.training, - # ): - # return None + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line + # self.config._attn_implementation == "sdpa" and + # if self.config._attn_implementation == "sdpa" and not using_static_cache: + if self.config._attn_implementation == "sdpa" and not using_static_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None - # dtype, device = input_tensor.dtype, input_tensor.device - # min_dtype = torch.finfo(dtype).min - # sequence_length = input_tensor.shape[1] - # if using_static_cache: - # target_length = past_key_values.get_max_length() - # else: - # target_length = ( - # attention_mask.shape[-1] - # if isinstance(attention_mask, torch.Tensor) - # else past_seen_tokens + sequence_length + 1 - # ) + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) - # # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - # causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( - # attention_mask, - # sequence_length=sequence_length, - # target_length=target_length, - # dtype=dtype, - # device=device, - # min_dtype=min_dtype, - # cache_position=cache_position, - # batch_size=input_tensor.shape[0], - # ) + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) - # if ( - # self.config._attn_implementation == "sdpa" - # and attention_mask is not None - # and attention_mask.device.type == "cuda" - # and not output_attentions - # ): - # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # # Details: https://github.com/pytorch/pytorch/issues/110213 - # causal_mask = AttentionMaskConverter._unmask_unattended( - # causal_mask, min_dtype - # ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) - # return causal_mask + return causal_mask class MllamaForCausalLM(nn.Module): @@ -1104,11 +1379,11 @@ class MllamaForCausalLM(nn.Module): past_key_values=None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ): # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # TODO outputs = self.model( input_ids=input_ids, cross_attention_states=cross_attention_states, @@ -1118,15 +1393,20 @@ class MllamaForCausalLM(nn.Module): full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - use_cache=use_cache, cache_position=cache_position, ) - hidden_states = outputs + hidden_states = outputs.last_hidden_state # if lm_head_indices is not None: # hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) - return logits + return ( + CausalLMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + ), + speculative_logits, + ) def prepare_inputs_for_generation( self, @@ -1136,7 +1416,6 @@ class MllamaForCausalLM(nn.Module): inputs_embeds=None, cache_position=None, position_ids=None, - use_cache=True, num_logits_to_keep=None, **kwargs, ): @@ -1201,7 +1480,6 @@ class MllamaForCausalLM(nn.Module): "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": use_cache, "attention_mask": attention_mask, } ) @@ -1215,6 +1493,12 @@ class MllamaForConditionalGeneration(nn.Module): config.vision_config.speculator = config.speculator config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator + # TODO check how this is determined + config.text_config._attn_implementation = "sdpa" + # self.hidden_size = ( + # config.text_config.hidden_size // weights.process_group.size() + # ) + self.hidden_size = config.text_config.hidden_size self.vision_model = MllamaVisionModel( prefix="vision_model", config=config.vision_config, weights=weights ) @@ -1225,3 +1509,92 @@ class MllamaForConditionalGeneration(nn.Module): prefix="multi_modal_projector", config=config, weights=weights, bias=True ) self.config = config + self.dtype = weights.dtype + self.device = weights.device + + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[List[List[int]]] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[List[List[List[int]]]] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + 
inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + image_hidden_states=None, + image_attention_mask=None, + ): + if past_key_values is None: + past_key_values = DynamicCache( + num_hidden_layers=self.config.text_config.num_hidden_layers + ) + elif isinstance(past_key_values, list): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError( + "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" + ) + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # get vision tokens from vision model + + vision_states = self.vision_model( + pixel_values, aspect_ratio_ids, aspect_ratio_mask + ) + cross_attention_states = self.multi_modal_projector(vision_states).reshape( + -1, vision_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = ( + _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[ + :, :, cache_position + ] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 06cd501e..e7fccb74 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -98,6 +98,8 @@ class IDEFICSSharded(IdeficsCausalLM): else: raise RuntimeError(f"Unsupported model type {config.model_type}") + self.config = config + torch.distributed.barrier(group=self.process_group) super(IdeficsCausalLM, self).__init__( model_id=model_id, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index e5c862cc..a82570db 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -6,8 +6,6 @@ import time from dataclasses import dataclass from opentelemetry import trace from transformers import ( - AutoProcessor, - AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin, ) @@ -38,6 +36,9 @@ class IdeficsCausalLMBatch(Batch): attention_mask: torch.Tensor position_ids: torch.Tensor 
pixel_values: Optional[torch.Tensor] + aspect_ratio_ids: Optional[torch.Tensor] + aspect_ratio_mask: Optional[torch.Tensor] + cross_attention_mask: Optional[torch.Tensor] image_hidden_states: Optional[torch.Tensor] image_attention_mask: Optional[torch.Tensor] past_key_values: Optional[List[Tuple]] @@ -164,7 +165,7 @@ class IdeficsCausalLMBatch(Batch): image = Image.open(BytesIO(chunk.image.data)) curr_images.append(image) # TODO unsure about BOS - curr_text += "<|image|><|begin_of_text|>" + curr_text += "<|image|>" else: raise RuntimeError(f"Invalid chunk type {chunk_type}") images.append(curr_images) @@ -173,6 +174,8 @@ class IdeficsCausalLMBatch(Batch): # The processor replaces the call to tokenizer, and # a/ takes care of fetching images from the URL # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model + if all(len(im) == 0 for im in images): + images = None tokenized_inputs = processor( images=images, text=texts, @@ -205,7 +208,10 @@ class IdeficsCausalLMBatch(Batch): # Do the same for image_attention_mask if pixel_values is None: image_attention_mask = None - else: + aspect_ratio_ids = None + aspect_ratio_mask = None + cross_attention_mask = None + elif "image_attention_mask" in tokenized_inputs: image_attention_mask = input_ids.new_zeros( ( pb.size, @@ -216,6 +222,19 @@ class IdeficsCausalLMBatch(Batch): image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ "image_attention_mask" ] + aspect_ratio_ids = None + aspect_ratio_mask = None + cross_attention_mask = None + else: + image_attention_mask = None + aspect_ratio_ids = tokenized_inputs["aspect_ratio_ids"] + aspect_ratio_mask = tokenized_inputs["aspect_ratio_mask"] + cross_attention_mask = tokenized_inputs["cross_attention_mask"] + pixel_values = pixel_values.to(dtype=dtype) + # XXX: <|image|> token is actually out of bounds and bugs out the logit processors. 
+ tokenized_inputs["input_ids"] = tokenized_inputs["input_ids"].clamp( + max=processor.tokenizer.vocab_size - 1 + ) position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) @@ -245,6 +264,9 @@ class IdeficsCausalLMBatch(Batch): max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, max_tokens=max_tokens, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + cross_attention_mask=cross_attention_mask, ) @tracer.start_as_current_span("filter") @@ -308,15 +330,21 @@ class IdeficsCausalLMBatch(Batch): + new_padding_right_offset, ] # Do the same for pixel_values and image_attention_mask - pixel_values = self.pixel_values[keep_indices] - self.image_attention_mask = self.image_attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.image_attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - :, - ] + if self.pixel_values is not None: + pixel_values = self.pixel_values[keep_indices] + else: + pixel_values = None + + if self.image_attention_mask is not None: + self.image_attention_mask = self.image_attention_mask[ + keep_indices, + -(self.padding_right_offset + max_input_length) : ( + self.image_attention_mask.shape[1] - self.padding_right_offset + ) + + new_padding_right_offset, + :, + ] + if self.image_hidden_states is None: image_hidden_states = None else: @@ -359,6 +387,9 @@ class IdeficsCausalLMBatch(Batch): self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens + self.aspect_ratio_ids = None + self.aspect_ratio_mask = None + self.cross_attention_mask = None return self @@ -376,7 +407,8 @@ class IdeficsCausalLMBatch(Batch): for batch in batches: total_batch_size += len(batch) max_input_length = max(max_input_length, batch.max_input_length) - max_num_images = max(max_num_images, batch.pixel_values.size(1)) + if batch.pixel_values is not None: + max_num_images = max(max_num_images, batch.pixel_values.size(1)) padding_right_offset = max(padding_right_offset, batch.padding_right_offset) # Batch attributes @@ -439,16 +471,19 @@ class IdeficsCausalLMBatch(Batch): (total_batch_size, max_input_length + padding_right_offset), ) - curr_batch_max_num_images = batch.pixel_values.size(1) - if pixel_values is None: - pixel_values = batch.pixel_values.new_zeros( - (total_batch_size, max_num_images, 3, 224, 224) + if batch.pixel_values is not None: + curr_batch_max_num_images = batch.pixel_values.size(1) + if pixel_values is None: + pixel_values = batch.pixel_values.new_zeros( + (total_batch_size, max_num_images, 3, 224, 224) + ) + pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( + batch.pixel_values ) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( - batch.pixel_values - ) + else: + pixel_values = None - if image_attention_mask is None: + if image_attention_mask is None and batch.image_attention_mask is not None: image_attention_mask = batch.image_attention_mask.new_zeros( ( total_batch_size, @@ -472,13 +507,14 @@ class IdeficsCausalLMBatch(Batch): :, batch_left_offset : -batch.padding_right_offset, ] - image_attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, - :curr_batch_max_num_images, - ] = batch.image_attention_mask[ - :, batch_left_offset : -batch.padding_right_offset, : - ] + if batch.image_attention_mask is not None: + image_attention_mask[ + 
start_index:end_index, + left_offset:-padding_right_offset, + :curr_batch_max_num_images, + ] = batch.image_attention_mask[ + :, batch_left_offset : -batch.padding_right_offset, : + ] # Create empty tensor # position_ids is always of shape [batch_size, 1] @@ -531,7 +567,20 @@ class IdeficsCausalLMBatch(Batch): # Iterate over attention layers # Concatenate past key values layer by layer to allow incremental garbage collection for j in range(len(first_past_kvs)): - padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) + _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape + if seqlen > max_input_length: + # XXX: This is probably a cross attention key value + # If not this is ok + _padded_past_keys_shape = ( + total_batch_size, + _num_heads, + seqlen, + _head_dim, + ) + else: + _padded_past_keys_shape = padded_past_keys_shape + + padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape) start_index = 0 for batch in batches: past_keys = batch.past_key_values[j][0] @@ -542,6 +591,9 @@ class IdeficsCausalLMBatch(Batch): end_index = start_index + len(batch) # We slice the keys to remove the padding from previous batches past_seq_len = batch.max_input_length - 1 + if past_keys.shape[2] > past_seq_len: + # XXX: This is a cross attention kv in mllama + past_seq_len = past_keys.shape[2] if batch.keys_head_dim_last: padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( past_keys[:, :, -past_seq_len:, :] @@ -555,8 +607,20 @@ class IdeficsCausalLMBatch(Batch): start_index = end_index + _, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape + if seqlen > max_input_length: + # XXX: This is probably a cross attention key value + # If not this is ok + _padded_past_values_shape = ( + total_batch_size, + _num_heads, + seqlen, + _head_dim, + ) + else: + _padded_past_values_shape = padded_past_values_shape padded_past_values = first_past_kvs[j][1].new_zeros( - padded_past_values_shape + _padded_past_values_shape ) start_index = 0 for batch in batches: @@ -568,6 +632,9 @@ class IdeficsCausalLMBatch(Batch): end_index = start_index + len(batch) # We slice the past values to remove the padding from previous batches past_seq_len = batch.max_input_length - 1 + if past_values.shape[2] > past_seq_len: + # XXX: This is a cross attention kv in mllama + past_seq_len = past_values.shape[2] padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( past_values[:, :, -past_seq_len:, :] ) @@ -599,6 +666,10 @@ class IdeficsCausalLMBatch(Batch): padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, max_tokens=max_tokens, + # No need to keep this around. 
for Mllamma + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, ) def __len__(self): @@ -606,77 +677,6 @@ class IdeficsCausalLMBatch(Batch): class IdeficsCausalLM(Model): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.quantize = quantize - from text_generation_server.models.custom_modeling.idefics_modeling import ( - IdeficsForVisionText2Text, - ) - - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.bfloat16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = IdeficsForVisionText2Text.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": ""}) - - super(IdeficsCausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - @property def batch_type(self) -> Type[IdeficsCausalLMBatch]: return IdeficsCausalLMBatch @@ -690,6 +690,9 @@ class IdeficsCausalLM(Model): image_hidden_states, image_attention_mask, past_key_values: Optional = None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { @@ -699,18 +702,23 @@ class IdeficsCausalLM(Model): "image_hidden_states": image_hidden_states, "image_attention_mask": image_attention_mask, "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, } if self.has_position_ids: kwargs["position_ids"] = position_ids + if aspect_ratio_ids is not None: + kwargs["aspect_ratio_ids"] = aspect_ratio_ids + if aspect_ratio_mask is not None: + kwargs["aspect_ratio_mask"] = aspect_ratio_mask + if cross_attention_mask is not None: + kwargs["cross_attention_mask"] = cross_attention_mask outputs, speculative_logits = self.model.forward(**kwargs) + assert outputs.past_key_values is not None return ( outputs.logits, speculative_logits, outputs.past_key_values, - outputs.image_hidden_states, + getattr(outputs, "image_hidden_states", None), ) @tracer.start_as_current_span("generate_token") @@ -745,9 +753,13 @@ class IdeficsCausalLM(Model): image_hidden_states=batch.image_hidden_states, image_attention_mask=image_attention_mask, past_key_values=batch.past_key_values, + 
aspect_ratio_ids=batch.aspect_ratio_ids, + aspect_ratio_mask=batch.aspect_ratio_mask, + cross_attention_mask=batch.cross_attention_mask, ) # Hardcoded remove image tokens - logits[:, 32000:32001] = torch.finfo(logits.dtype).min + if self.config.model_type == "idefics": + logits[:, 32000:32001] = torch.finfo(logits.dtype).min start_decode = time.time_ns() @@ -890,10 +902,13 @@ class IdeficsCausalLM(Model): batch.input_ids = batch.input_ids[:, :1] # Update attention_mask as we added a new token to input_ids - batch.attention_mask[:, -batch.padding_right_offset] = 1 - batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( - batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] - ) + # batch.attention_mask[:, -batch.padding_right_offset] = 1 + if batch.image_attention_mask is not None: + batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( + batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] + ) + if batch.cross_attention_mask is not None: + batch.cross_attention_mask = batch.cross_attention_mask[:, -1:] # Decrease right offset batch.padding_right_offset -= 1 @@ -903,7 +918,8 @@ class IdeficsCausalLM(Model): # Update past key values batch.past_key_values = past batch.image_hidden_states = image_hidden_states - + if self.model.config.model_type == "mllama": + batch.pixel_values = None forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns)
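Usage note (not part of the diff): a minimal sketch of how a client could exercise the chunked `Message.content` added to `clients/python/text_generation/types.py` above, mirroring the `test_mllama_simpl` integration test and its `.chat(...)` call. The endpoint URL is a placeholder assumption and is not defined anywhere in this patch.

```python
# Minimal sketch; assumes a TGI server with an Mllama model is already
# listening on http://localhost:3000 (placeholder URL, not part of this patch).
import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://localhost:3000")
    # `Message.content` may now be either a plain string or a list of chunks
    # (see the `Chunk` model added to types.py above).
    response = await client.chat(
        max_tokens=20,
        temperature=0.0,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Can you tell me a very short story based on the image?",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                        },
                    },
                ],
            },
        ],
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```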