feat: support flash attention 2 in qwen2 vl vision blocks (#2721)

* feat: support flash attention 2 in qwen2 vl vision blocks

* fix: calc max_seqlen once and small refactors
This commit is contained in:
drbh 2024-11-18 12:46:40 -05:00 committed by GitHub
parent 3c9df21ff8
commit 38cff84a3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 65 additions and 39 deletions

View File

@ -22,9 +22,11 @@ from torch import nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex": if SYSTEM == "ipex":
pass import intel_extension_for_pytorch as ipex
else: else:
pass import flash_attn_2_cuda
import numpy as np
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
import torch.nn.functional as F import torch.nn.functional as F
@ -66,7 +68,7 @@ def apply_rotary_pos_emb_vision(
return output return output
class Qwen2VLSdpaAttention(nn.Module): class Qwen2VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights): def __init__(self, *, prefix, config, weights):
super().__init__() super().__init__()
self.embed_dim = config.embed_dim // weights.process_group.size() self.embed_dim = config.embed_dim // weights.process_group.size()
@ -88,13 +90,14 @@ class Qwen2VLSdpaAttention(nn.Module):
weights=weights, weights=weights,
bias=True, bias=True,
) )
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
def forward( def forward(
self, self,
hidden_state: torch.Tensor, hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor, rotary_pos_emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, max_seqlen: int,
) -> torch.Tensor: ) -> torch.Tensor:
# apply the qkv linear layer to the hidden state # apply the qkv linear layer to the hidden state
qkv = self.qkv(hidden_state) qkv = self.qkv(hidden_state)
@ -117,37 +120,59 @@ class Qwen2VLSdpaAttention(nn.Module):
0 0
) )
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
# TODO: make use of existing RotatoryPositionEmbedding class
# create the attention mask # calc maximum sequence length for any batch
attention_mask = torch.zeros( query = query.contiguous()
[1, hidden_state.shape[0], hidden_state.shape[0]], key = key.contiguous()
device=hidden_state.device, value = value.contiguous()
dtype=torch.bool, causal = False
# execute flash attention
if SYSTEM == "ipex":
attn_output = torch.empty_like(query)
ipex.llm.functional.varlen_attention(
(query.contiguous() if query.device.type == "xpu" else query),
(key.contiguous() if key.device.type == "xpu" else key),
(value.contiguous() if value.device.type == "xpu" else value),
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
) )
# TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it else:
attn_output = flash_attn_2_cuda.varlen_fwd(
query,
key,
value,
None, # tmp buffer (auto-allocated)
cu_seqlens, # cu_seqlens_q
cu_seqlens, # cu_seqlens_k
None, # max_seqlen_q (auto-computed)
None, # max_seqlen_k (auto-computed)
None, # block_tables
None, # broadcast_mask
max_seqlen, # max_seqlen
max_seqlen, # max_seqlen
0.0, # dropout_p
self.softmax_scale,
False, # zero_tensors
causal, # causal attention within each sequence
-1, # window_size_left
-1, # window_size_right
0.0, # softmax_cap
False, # deterministic
None, # rng_state
)[0]
# apply the cu_seqlens to the attention mask # reshape output to original dimensions
for i in range(1, len(cu_seqlens)):
attention_mask[
...,
cu_seqlens[i - 1] : cu_seqlens[i],
cu_seqlens[i - 1] : cu_seqlens[i],
] = True
# transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
# apply attention
attn_output = F.scaled_dot_product_attention(
query, key, value, attention_mask, dropout_p=0.0
)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(hidden_state.shape[0], -1) attn_output = attn_output.reshape(hidden_state.shape[0], -1)
# TODO: prefer flash attention
attn_output = self.proj(attn_output) attn_output = self.proj(attn_output)
return attn_output return attn_output
@ -173,7 +198,7 @@ class Qwen2VLVisionMLP(nn.Module):
class Qwen2VLVisionBlock(nn.Module): class Qwen2VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
self.attn = Qwen2VLSdpaAttention( self.attn = Qwen2VLAttention(
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
config=config, config=config,
weights=weights, weights=weights,
@ -194,10 +219,12 @@ class Qwen2VLVisionBlock(nn.Module):
weights=weights, weights=weights,
) )
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
hidden_states_post_norm1, res = self.norm1(hidden_states) hidden_states_post_norm1, res = self.norm1(hidden_states)
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen
) )
hidden_states_post_norm2, res = self.norm2(hidden_states) hidden_states_post_norm2, res = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2) hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
@ -220,7 +247,7 @@ class Qwen2VLPatchMerger(nn.Module):
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
) )
def forward(self, hidden_states, grid_thw) -> torch.Tensor: def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.patch_merger_ln_q(hidden_states) hidden_states, _ = self.patch_merger_ln_q(hidden_states)
hidden_states = hidden_states.view(-1, self.hidden_size) hidden_states = hidden_states.view(-1, self.hidden_size)
hidden_states = self.fc1(hidden_states) hidden_states = self.fc1(hidden_states)
@ -281,7 +308,6 @@ class Qwen2VisionModel(nn.Module):
def forward( def forward(
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
aspect_ratio_ids: Optional[torch.Tensor] = None,
grid_thw: Optional[torch.LongTensor] = None, grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# reshape the input tensor for processing # reshape the input tensor for processing
@ -336,13 +362,13 @@ class Qwen2VisionModel(nn.Module):
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32) ).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# iterately apply the blocks to the hidden states # iterately apply the blocks to the hidden states
for block in self.blocks: for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb) hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
# apply the final patch merger to the hidden states # apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states, grid_thw) hidden_states = self.merger(hidden_states)
return hidden_states return hidden_states