WIP video support
parent 38cff84a3e
commit de6c68443e
@@ -36,6 +36,7 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
+qwen_vl_utils==0.0.8 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
 rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
@@ -14,12 +14,14 @@
 # limitations under the License.
 """PyTorch Qwen2 VL model."""
 
-from typing import Optional, Tuple, List
+from typing import Dict, Optional, Tuple, List
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
+from qwen_vl_utils import process_vision_info
+
 
 if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -411,6 +413,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         batch_input_ids: torch.Tensor,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if batch_input_ids.dim() == 1:
@@ -424,8 +427,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             device=batch_input_ids.device,
         )
         d = batch_input_ids.device
-        if image_grid_thw is not None:
-            image_index = 0
+        # Handle both image and video tokens
+        if image_grid_thw is not None or video_grid_thw is not None:
+            vision_index = 0
             llm_pos_ids_list = []
 
             for i, input_ids in enumerate(batch_input_ids):
@@ -433,34 +438,39 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     input_ids == self.vision_start_token_id
                 ).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
-                # only copy the sum of the image tokens GPU<->CPU
+                # Count both image and video tokens
                 image_count = (vision_tokens == self.image_token_id).sum().item()
+                video_count = (vision_tokens == self.video_token_id).sum().item()
 
                 current_pos = 0
-                for _ in range(image_count):
-                    # copy the value position of the next image token from GPU<->CPU
-                    next_image_pos = (
-                        (input_ids[current_pos:] == self.image_token_id)
+                for _ in range(image_count + video_count):
+                    # Find next vision token position (either image or video)
+                    next_vision_pos = (
+                        ((input_ids[current_pos:] == self.image_token_id) |
+                        (input_ids[current_pos:] == self.video_token_id))
                         .nonzero()[0]
                         .item()
                     )
-                    # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-                    time_steps, height, width = image_grid_thw[image_index].clone()
+                    # Determine if current token is video or image
+                    is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id
+                    grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index]
+
+                    time_steps, height, width = grid_thw.clone()
                     height //= self.spatial_merge_size
                     width //= self.spatial_merge_size
 
-                    # calculate the length of the text and image tokens
-                    text_length = next_image_pos - current_pos
-                    start_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    )
+                    # Calculate lengths and indices same as before
+                    text_length = next_vision_pos - current_pos
+                    start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
 
-                    # text position ids
+                    # Text position ids
                     text_pos_ids = torch.arange(text_length, device=d)
                     text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
                     llm_pos_ids_list.append(text_pos_ids)
 
-                    # image position ids
+                    # Vision position ids
                     t_indices = torch.arange(time_steps, device=d).repeat_interleave(
                         height * width
                     )
@@ -469,24 +479,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                         .repeat_interleave(width)
                         .repeat(time_steps)
                     )
-                    w_indices = torch.arange(width, device=d).repeat(
-                        height * time_steps
-                    )
+                    w_indices = torch.arange(width, device=d).repeat(height * time_steps)
 
-                    image_pos_ids = (
+                    vision_pos_ids = (
                         torch.stack([t_indices, h_indices, w_indices])
                         + text_length
                         + start_idx
                     )
-                    llm_pos_ids_list.append(image_pos_ids)
+                    llm_pos_ids_list.append(vision_pos_ids)
 
-                    current_pos = next_image_pos + time_steps * height * width
-                    image_index += 1
+                    current_pos = next_vision_pos + time_steps * height * width
+                    vision_index += 1
 
+                # Handle remaining text if any
                 if current_pos < batch_input_ids.size(1):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    )
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                     text_len = batch_input_ids.size(1) - current_pos
                     llm_pos_ids_list.append(
                         torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
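For reference, a minimal standalone sketch of how the three position-id rows above are built for a single vision block; the post-merge grid (time_steps=2, height=2, width=3) is invented for illustration and is not part of the diff:

import torch

# Hypothetical post-merge grid for one image/video block: 2 frames of 2 x 3 patches.
time_steps, height, width = 2, 2, 3

t_indices = torch.arange(time_steps).repeat_interleave(height * width)
h_indices = torch.arange(height).repeat_interleave(width).repeat(time_steps)
w_indices = torch.arange(width).repeat(height * time_steps)

# Each column is the (t, h, w) position of one vision token, before the
# text_length + start_idx offset applied in the model code.
print(torch.stack([t_indices, h_indices, w_indices]))
# tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1],
#         [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]])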
@@ -527,11 +534,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
         # apply the visual model to the pixel values if they are provided
        if pixel_values is not None and len(pixel_values) > 0:
-            if pixel_values is not None:
-                image_embeds = self.visual(
-                    pixel_values, grid_thw=image_grid_thw
+            vision_embeds = self.visual(
+                pixel_values,
+                grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
             ).squeeze(0)
-                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+
+            # Apply embeddings to both image and video tokens
+            vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
+            inputs_embeds[vision_token_mask] = vision_embeds
 
         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
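A minimal sketch of the mask-based scatter used in the hunk above, with invented token ids and sizes (not part of the diff): rows of vision_embeds are written, in order, into every position whose token is an image or video placeholder.

import torch

# Illustrative placeholder ids; the real values come from the model config.
IMAGE_TOKEN_ID, VIDEO_TOKEN_ID = 151655, 151656

hidden_size = 4
input_ids = torch.tensor([11, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 22, VIDEO_TOKEN_ID, 33])
inputs_embeds = torch.zeros(len(input_ids), hidden_size)

# One embedding row per vision token (two image tokens + one video token).
vision_embeds = torch.arange(3 * hidden_size, dtype=torch.float32).view(3, hidden_size)

vision_token_mask = (input_ids == IMAGE_TOKEN_ID) | (input_ids == VIDEO_TOKEN_ID)
inputs_embeds[vision_token_mask] = vision_embeds  # positions 1, 2 and 4 are replaced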
@@ -550,3 +560,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
+
+class QwenVideoProcessor:
+    """Utility class to handle video processing specifically for Qwen models"""
+
+    @staticmethod
+    def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]:
+        """
+        Process messages containing video inputs for Qwen models
+        Returns a tuple of (processed_messages, video_pixels)
+        """
+        # Use Qwen's built-in video processing
+        vision_info = process_vision_info(messages)
+
+        if vision_info is not None:
+            _, video_inputs = vision_info
+            return video_inputs[0] if video_inputs else None
+
+        return None
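A hedged usage sketch for the new QwenVideoProcessor helper (not part of the diff). The message layout follows the qwen_vl_utils convention of content entries with a "type" field; the file path and prompt are invented for illustration:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "file:///path/to/clip.mp4"},
            {"type": "text", "text": "Describe what happens in this video."},
        ],
    }
]

# Returns the first decoded video's pixel data, or None when no video is present.
video_pixels = QwenVideoProcessor.prepare_video_inputs(messages)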