feat: support flash attention 2 in qwen2 vl vision blocks (#2721)
* feat: support flash attention 2 in qwen2 vl vision blocks * fix: calc max_seqlen once and small refactors
This commit is contained in:
parent
3c9df21ff8
commit
38cff84a3e
|
@ -22,9 +22,11 @@ from torch import nn
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
if SYSTEM == "ipex":
|
if SYSTEM == "ipex":
|
||||||
pass
|
import intel_extension_for_pytorch as ipex
|
||||||
else:
|
else:
|
||||||
pass
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -66,7 +68,7 @@ def apply_rotary_pos_emb_vision(
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLSdpaAttention(nn.Module):
|
class Qwen2VLAttention(nn.Module):
|
||||||
def __init__(self, *, prefix, config, weights):
|
def __init__(self, *, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.embed_dim // weights.process_group.size()
|
self.embed_dim = config.embed_dim // weights.process_group.size()
|
||||||
|
@ -88,13 +90,14 @@ class Qwen2VLSdpaAttention(nn.Module):
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
|
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_state: torch.Tensor,
|
hidden_state: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor,
|
rotary_pos_emb: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
max_seqlen: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# apply the qkv linear layer to the hidden state
|
# apply the qkv linear layer to the hidden state
|
||||||
qkv = self.qkv(hidden_state)
|
qkv = self.qkv(hidden_state)
|
||||||
|
@ -117,37 +120,59 @@ class Qwen2VLSdpaAttention(nn.Module):
|
||||||
0
|
0
|
||||||
)
|
)
|
||||||
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||||
# TODO: make use of existing RotatoryPositionEmbedding class
|
|
||||||
|
|
||||||
# create the attention mask
|
# calc maximum sequence length for any batch
|
||||||
attention_mask = torch.zeros(
|
query = query.contiguous()
|
||||||
[1, hidden_state.shape[0], hidden_state.shape[0]],
|
key = key.contiguous()
|
||||||
device=hidden_state.device,
|
value = value.contiguous()
|
||||||
dtype=torch.bool,
|
causal = False
|
||||||
)
|
|
||||||
# TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it
|
|
||||||
|
|
||||||
# apply the cu_seqlens to the attention mask
|
# execute flash attention
|
||||||
for i in range(1, len(cu_seqlens)):
|
if SYSTEM == "ipex":
|
||||||
attention_mask[
|
attn_output = torch.empty_like(query)
|
||||||
...,
|
ipex.llm.functional.varlen_attention(
|
||||||
cu_seqlens[i - 1] : cu_seqlens[i],
|
(query.contiguous() if query.device.type == "xpu" else query),
|
||||||
cu_seqlens[i - 1] : cu_seqlens[i],
|
(key.contiguous() if key.device.type == "xpu" else key),
|
||||||
] = True
|
(value.contiguous() if value.device.type == "xpu" else value),
|
||||||
|
attn_output,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
max_seqlen,
|
||||||
|
0.0,
|
||||||
|
self.softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_2_cuda.varlen_fwd(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
None, # tmp buffer (auto-allocated)
|
||||||
|
cu_seqlens, # cu_seqlens_q
|
||||||
|
cu_seqlens, # cu_seqlens_k
|
||||||
|
None, # max_seqlen_q (auto-computed)
|
||||||
|
None, # max_seqlen_k (auto-computed)
|
||||||
|
None, # block_tables
|
||||||
|
None, # broadcast_mask
|
||||||
|
max_seqlen, # max_seqlen
|
||||||
|
max_seqlen, # max_seqlen
|
||||||
|
0.0, # dropout_p
|
||||||
|
self.softmax_scale,
|
||||||
|
False, # zero_tensors
|
||||||
|
causal, # causal attention within each sequence
|
||||||
|
-1, # window_size_left
|
||||||
|
-1, # window_size_right
|
||||||
|
0.0, # softmax_cap
|
||||||
|
False, # deterministic
|
||||||
|
None, # rng_state
|
||||||
|
)[0]
|
||||||
|
|
||||||
# transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim)
|
# reshape output to original dimensions
|
||||||
query = query.transpose(0, 1)
|
|
||||||
key = key.transpose(0, 1)
|
|
||||||
value = value.transpose(0, 1)
|
|
||||||
|
|
||||||
# apply attention
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attention_mask, dropout_p=0.0
|
|
||||||
)
|
|
||||||
attn_output = attn_output.transpose(0, 1)
|
|
||||||
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
||||||
# TODO: prefer flash attention
|
|
||||||
|
|
||||||
attn_output = self.proj(attn_output)
|
attn_output = self.proj(attn_output)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
@ -173,7 +198,7 @@ class Qwen2VLVisionMLP(nn.Module):
|
||||||
class Qwen2VLVisionBlock(nn.Module):
|
class Qwen2VLVisionBlock(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attn = Qwen2VLSdpaAttention(
|
self.attn = Qwen2VLAttention(
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
@ -194,10 +219,12 @@ class Qwen2VLVisionBlock(nn.Module):
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
def forward(
|
||||||
|
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||||
|
) -> torch.Tensor:
|
||||||
hidden_states_post_norm1, res = self.norm1(hidden_states)
|
hidden_states_post_norm1, res = self.norm1(hidden_states)
|
||||||
hidden_states = hidden_states + self.attn(
|
hidden_states = hidden_states + self.attn(
|
||||||
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
|
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||||
)
|
)
|
||||||
hidden_states_post_norm2, res = self.norm2(hidden_states)
|
hidden_states_post_norm2, res = self.norm2(hidden_states)
|
||||||
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
|
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
|
||||||
|
@ -220,7 +247,7 @@ class Qwen2VLPatchMerger(nn.Module):
|
||||||
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
|
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states, grid_thw) -> torch.Tensor:
|
def forward(self, hidden_states) -> torch.Tensor:
|
||||||
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
|
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
|
||||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
hidden_states = self.fc1(hidden_states)
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
@ -281,7 +308,6 @@ class Qwen2VisionModel(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.Tensor,
|
pixel_values: torch.Tensor,
|
||||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
|
||||||
grid_thw: Optional[torch.LongTensor] = None,
|
grid_thw: Optional[torch.LongTensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# reshape the input tensor for processing
|
# reshape the input tensor for processing
|
||||||
|
@ -336,13 +362,13 @@ class Qwen2VisionModel(nn.Module):
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
).cumsum(dim=0, dtype=torch.int32)
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
|
||||||
# iterately apply the blocks to the hidden states
|
# iterately apply the blocks to the hidden states
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
|
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||||
|
|
||||||
# apply the final patch merger to the hidden states
|
# apply the final patch merger to the hidden states
|
||||||
hidden_states = self.merger(hidden_states, grid_thw)
|
hidden_states = self.merger(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue