feat: support flash attention 2 in qwen2 vl vision blocks (#2721)
* feat: support flash attention 2 in qwen2 vl vision blocks * fix: calc max_seqlen once and small refactors
This commit is contained in:
parent
3c9df21ff8
commit
38cff84a3e
|
@ -22,9 +22,11 @@ from torch import nn
|
|||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
pass
|
||||
import intel_extension_for_pytorch as ipex
|
||||
else:
|
||||
pass
|
||||
import flash_attn_2_cuda
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
import torch.nn.functional as F
|
||||
|
@ -66,7 +68,7 @@ def apply_rotary_pos_emb_vision(
|
|||
return output
|
||||
|
||||
|
||||
class Qwen2VLSdpaAttention(nn.Module):
|
||||
class Qwen2VLAttention(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.embed_dim // weights.process_group.size()
|
||||
|
@ -88,13 +90,14 @@ class Qwen2VLSdpaAttention(nn.Module):
|
|||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
max_seqlen: int,
|
||||
) -> torch.Tensor:
|
||||
# apply the qkv linear layer to the hidden state
|
||||
qkv = self.qkv(hidden_state)
|
||||
|
@ -117,37 +120,59 @@ class Qwen2VLSdpaAttention(nn.Module):
|
|||
0
|
||||
)
|
||||
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
# TODO: make use of existing RotatoryPositionEmbedding class
|
||||
|
||||
# create the attention mask
|
||||
attention_mask = torch.zeros(
|
||||
[1, hidden_state.shape[0], hidden_state.shape[0]],
|
||||
device=hidden_state.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
# TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it
|
||||
# calc maximum sequence length for any batch
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
causal = False
|
||||
|
||||
# apply the cu_seqlens to the attention mask
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[
|
||||
...,
|
||||
cu_seqlens[i - 1] : cu_seqlens[i],
|
||||
cu_seqlens[i - 1] : cu_seqlens[i],
|
||||
] = True
|
||||
# execute flash attention
|
||||
if SYSTEM == "ipex":
|
||||
attn_output = torch.empty_like(query)
|
||||
ipex.llm.functional.varlen_attention(
|
||||
(query.contiguous() if query.device.type == "xpu" else query),
|
||||
(key.contiguous() if key.device.type == "xpu" else key),
|
||||
(value.contiguous() if value.device.type == "xpu" else value),
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
max_seqlen,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
attn_output = flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None, # tmp buffer (auto-allocated)
|
||||
cu_seqlens, # cu_seqlens_q
|
||||
cu_seqlens, # cu_seqlens_k
|
||||
None, # max_seqlen_q (auto-computed)
|
||||
None, # max_seqlen_k (auto-computed)
|
||||
None, # block_tables
|
||||
None, # broadcast_mask
|
||||
max_seqlen, # max_seqlen
|
||||
max_seqlen, # max_seqlen
|
||||
0.0, # dropout_p
|
||||
self.softmax_scale,
|
||||
False, # zero_tensors
|
||||
causal, # causal attention within each sequence
|
||||
-1, # window_size_left
|
||||
-1, # window_size_right
|
||||
0.0, # softmax_cap
|
||||
False, # deterministic
|
||||
None, # rng_state
|
||||
)[0]
|
||||
|
||||
# transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim)
|
||||
query = query.transpose(0, 1)
|
||||
key = key.transpose(0, 1)
|
||||
value = value.transpose(0, 1)
|
||||
|
||||
# apply attention
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query, key, value, attention_mask, dropout_p=0.0
|
||||
)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
# reshape output to original dimensions
|
||||
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
||||
# TODO: prefer flash attention
|
||||
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
@ -173,7 +198,7 @@ class Qwen2VLVisionMLP(nn.Module):
|
|||
class Qwen2VLVisionBlock(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.attn = Qwen2VLSdpaAttention(
|
||||
self.attn = Qwen2VLAttention(
|
||||
prefix=f"{prefix}.attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
|
@ -194,10 +219,12 @@ class Qwen2VLVisionBlock(nn.Module):
|
|||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
||||
def forward(
|
||||
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||
) -> torch.Tensor:
|
||||
hidden_states_post_norm1, res = self.norm1(hidden_states)
|
||||
hidden_states = hidden_states + self.attn(
|
||||
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
|
||||
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||
)
|
||||
hidden_states_post_norm2, res = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
|
||||
|
@ -220,7 +247,7 @@ class Qwen2VLPatchMerger(nn.Module):
|
|||
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, grid_thw) -> torch.Tensor:
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
|
@ -281,7 +308,6 @@ class Qwen2VisionModel(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||
grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# reshape the input tensor for processing
|
||||
|
@ -336,13 +362,13 @@ class Qwen2VisionModel(nn.Module):
|
|||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
|
||||
# iterately apply the blocks to the hidden states
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
|
||||
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||
|
||||
# apply the final patch merger to the hidden states
|
||||
hidden_states = self.merger(hidden_states, grid_thw)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue