diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 73325c88..ddb4e36d 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -22,9 +22,11 @@ from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "ipex":
-    pass
+    import intel_extension_for_pytorch as ipex
 else:
-    pass
+    import flash_attn_2_cuda
+
+import numpy as np
 
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
@@ -66,7 +68,7 @@ def apply_rotary_pos_emb_vision(
     return output
 
 
-class Qwen2VLSdpaAttention(nn.Module):
+class Qwen2VLAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
         self.embed_dim = config.embed_dim // weights.process_group.size()
@@ -88,13 +90,14 @@ class Qwen2VLSdpaAttention(nn.Module):
             weights=weights,
             bias=True,
         )
+        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
 
     def forward(
         self,
         hidden_state: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        max_seqlen: int,
     ) -> torch.Tensor:
         # apply the qkv linear layer to the hidden state
         qkv = self.qkv(hidden_state)
@@ -117,37 +120,59 @@ class Qwen2VLSdpaAttention(nn.Module):
             0
         )
         key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        # TODO: make use of existing RotatoryPositionEmbedding class
 
-        # create the attention mask
-        attention_mask = torch.zeros(
-            [1, hidden_state.shape[0], hidden_state.shape[0]],
-            device=hidden_state.device,
-            dtype=torch.bool,
-        )
-        # TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it
+        # the flash attention kernels expect contiguous tensors
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+        causal = False
 
-        # apply the cu_seqlens to the attention mask
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = True
+        # execute flash attention
+        if SYSTEM == "ipex":
+            attn_output = torch.empty_like(query)
+            ipex.llm.functional.varlen_attention(
+                (query.contiguous() if query.device.type == "xpu" else query),
+                (key.contiguous() if key.device.type == "xpu" else key),
+                (value.contiguous() if value.device.type == "xpu" else value),
+                attn_output,
+                cu_seqlens,
+                cu_seqlens,
+                max_seqlen,
+                max_seqlen,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+            )
+        else:
+            attn_output = flash_attn_2_cuda.varlen_fwd(
+                query,
+                key,
+                value,
+                None,  # out (allocated by the kernel when None)
+                cu_seqlens,  # cu_seqlens_q
+                cu_seqlens,  # cu_seqlens_k
+                None,  # seqused_k
+                None,  # leftpad_k
+                None,  # block_table
+                None,  # alibi_slopes
+                max_seqlen,  # max_seqlen_q
+                max_seqlen,  # max_seqlen_k
+                0.0,  # dropout_p
+                self.softmax_scale,
+                False,  # zero_tensors
+                causal,  # causal attention within each sequence
+                -1,  # window_size_left
+                -1,  # window_size_right
+                0.0,  # softcap
+                False,  # return_softmax
+                None,  # rng generator
+            )[0]
 
-        # transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim)
-        query = query.transpose(0, 1)
-        key = key.transpose(0, 1)
-        value = value.transpose(0, 1)
-
-        # apply attention
-        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attention_mask, dropout_p=0.0
-        )
-        attn_output = attn_output.transpose(0, 1)
+        # reshape output to original dimensions
         attn_output = attn_output.reshape(hidden_state.shape[0], -1)
-        # TODO: prefer flash attention
-
         attn_output = self.proj(attn_output)
         return attn_output
@@ -173,7 +198,7 @@ class Qwen2VLVisionMLP(nn.Module):
 class Qwen2VLVisionBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.attn = Qwen2VLSdpaAttention(
+        self.attn = Qwen2VLAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -194,10 +219,12 @@ class Qwen2VLVisionBlock(nn.Module):
             weights=weights,
         )
 
-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(
+        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
+    ) -> torch.Tensor:
         hidden_states_post_norm1, res = self.norm1(hidden_states)
         hidden_states = hidden_states + self.attn(
-            hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
+            hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen
         )
         hidden_states_post_norm2, res = self.norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
@@ -220,7 +247,7 @@ class Qwen2VLPatchMerger(nn.Module):
             prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
         )
 
-    def forward(self, hidden_states, grid_thw) -> torch.Tensor:
+    def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.patch_merger_ln_q(hidden_states)
         hidden_states = hidden_states.view(-1, self.hidden_size)
         hidden_states = self.fc1(hidden_states)
@@ -281,7 +308,6 @@ class Qwen2VisionModel(nn.Module):
     def forward(
         self,
         pixel_values: torch.Tensor,
-        aspect_ratio_ids: Optional[torch.Tensor] = None,
         grid_thw: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         # reshape the input tensor for processing
@@ -336,13 +362,13 @@
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-
+        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         # iteratively apply the blocks to the hidden states
         for block in self.blocks:
-            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
+            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
 
         # apply the final patch merger to the hidden states
-        hidden_states = self.merger(hidden_states, grid_thw)
+        hidden_states = self.merger(hidden_states)
         return hidden_states
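
Note for reviewers: `max_seqlen` is computed once in Qwen2VisionModel.forward and threaded through every vision block so the varlen kernels do not recompute it. A minimal sketch of what `cu_seqlens` and `max_seqlen` encode, using a hypothetical two-image `grid_thw` (the values are made up; the tensor math mirrors the diff above):

    import torch
    import torch.nn.functional as F

    # hypothetical (t, h, w) patch grids for two images
    grid_thw = torch.tensor([[1, 4, 6], [1, 8, 8]])

    # per-image patch counts (h * w), repeated t times each,
    # as computed in Qwen2VisionModel.forward
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
    max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

    print(cu_seqlens)  # tensor([ 0, 24, 88], dtype=torch.int32)
    print(max_seqlen)  # tensor(64, dtype=torch.int32)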
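The change is behavior-preserving because non-causal varlen attention over `cu_seqlens` boundaries computes the same result as the old single SDPA call under a block-diagonal mask. A self-contained equivalence check, with plain PyTorch standing in for the fused kernels (shapes and sequence lengths are hypothetical):

    import torch
    import torch.nn.functional as F

    def sdpa_block_diagonal(q, k, v, cu_seqlens):
        # old path: one SDPA call over all packed tokens, with a boolean
        # mask restricting attention to each sequence's own block
        n = q.shape[0]
        mask = torch.zeros(1, n, n, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            s, e = cu_seqlens[i - 1], cu_seqlens[i]
            mask[..., s:e, s:e] = True
        out = F.scaled_dot_product_attention(
            q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), mask
        )
        return out.transpose(0, 1)

    def varlen_reference(q, k, v, cu_seqlens):
        # new path, emulated: full (causal=False) attention run
        # independently on each packed sequence
        outs = []
        for i in range(1, len(cu_seqlens)):
            s, e = cu_seqlens[i - 1], cu_seqlens[i]
            outs.append(
                F.scaled_dot_product_attention(
                    q[s:e].transpose(0, 1),
                    k[s:e].transpose(0, 1),
                    v[s:e].transpose(0, 1),
                ).transpose(0, 1)
            )
        return torch.cat(outs)

    q, k, v = (torch.randn(12, 2, 8) for _ in range(3))  # (tokens, heads, head_dim)
    cu_seqlens = [0, 5, 12]
    torch.testing.assert_close(
        sdpa_block_diagonal(q, k, v, cu_seqlens),
        varlen_reference(q, k, v, cu_seqlens),
    )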