WIP video support

Miquel Farre 2024-11-11 14:52:52 +00:00, committed by drbh
parent 38cff84a3e
commit de6c68443e
4 changed files with 63 additions and 32 deletions

View File

@@ -36,6 +36,7 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
+qwen_vl_utils==0.0.8 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
 rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -36,6 +36,7 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
qwen_vl_utils==0.0.8 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13" rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -36,6 +36,7 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
qwen_vl_utils==0.0.8 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13" regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13" rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@@ -14,12 +14,14 @@
 # limitations under the License.
 """PyTorch Qwen2 VL model."""

-from typing import Optional, Tuple, List
+from typing import Dict, Optional, Tuple, List

 import torch
 import torch.utils.checkpoint
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
+from qwen_vl_utils import process_vision_info

 if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -411,6 +413,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         batch_input_ids: torch.Tensor,
         image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw is not implemented yet as we do not accept video inputs at the moment
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if batch_input_ids.dim() == 1:
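
For orientation: each row of image_grid_thw / video_grid_thw is a (t, h, w) triple counting temporal and spatial vision patches, and after spatial merging an input occupies t * (h / merge) * (w / merge) placeholder tokens, which is exactly how far the position-id loop below advances current_pos. A toy check of that arithmetic, with invented grid values and Qwen2-VL's spatial merge size of 2:

import torch

# Hypothetical grids: one still image (t=1) and one video (t=4).
image_grid_thw = torch.tensor([[1, 36, 52]])
video_grid_thw = torch.tensor([[4, 36, 52]])

merge = 2  # spatial_merge_size from the Qwen2-VL config
for t, h, w in torch.cat([image_grid_thw, video_grid_thw]).tolist():
    # Placeholder tokens consumed in the input sequence: 468, then 1872.
    print(t * (h // merge) * (w // merge))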
@@ -424,8 +427,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             device=batch_input_ids.device,
         )
         d = batch_input_ids.device
-        if image_grid_thw is not None:
-            image_index = 0
+        # Handle both image and video tokens
+        if image_grid_thw is not None or video_grid_thw is not None:
+            vision_index = 0
             llm_pos_ids_list = []

             for i, input_ids in enumerate(batch_input_ids):
@@ -433,34 +438,39 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     input_ids == self.vision_start_token_id
                 ).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
-                # only copy the sum of the image tokens GPU<->CPU
+                # Count both image and video tokens
                 image_count = (vision_tokens == self.image_token_id).sum().item()
+                video_count = (vision_tokens == self.video_token_id).sum().item()

                 current_pos = 0
-                for _ in range(image_count):
-                    # copy the value position of the next image token from GPU<->CPU
-                    next_image_pos = (
-                        (input_ids[current_pos:] == self.image_token_id)
+                for _ in range(image_count + video_count):
+                    # Find next vision token position (either image or video)
+                    next_vision_pos = (
+                        ((input_ids[current_pos:] == self.image_token_id) |
+                         (input_ids[current_pos:] == self.video_token_id))
                         .nonzero()[0]
                         .item()
                     )
-                    # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
-                    time_steps, height, width = image_grid_thw[image_index].clone()
+
+                    # Determine if current token is video or image
+                    is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id
+                    grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index]
+                    time_steps, height, width = grid_thw.clone()
                     height //= self.spatial_merge_size
                     width //= self.spatial_merge_size

-                    # calculate the length of the text and image tokens
-                    text_length = next_image_pos - current_pos
-                    start_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    )
+                    # Calculate lengths and indices same as before
+                    text_length = next_vision_pos - current_pos
+                    start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0

-                    # text position ids
+                    # Text position ids
                     text_pos_ids = torch.arange(text_length, device=d)
                     text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
                     llm_pos_ids_list.append(text_pos_ids)

-                    # image position ids
+                    # Vision position ids
                     t_indices = torch.arange(time_steps, device=d).repeat_interleave(
                         height * width
                     )
@@ -469,24 +479,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                         .repeat_interleave(width)
                         .repeat(time_steps)
                     )
-                    w_indices = torch.arange(width, device=d).repeat(
-                        height * time_steps
-                    )
+                    w_indices = torch.arange(width, device=d).repeat(height * time_steps)

-                    image_pos_ids = (
+                    vision_pos_ids = (
                         torch.stack([t_indices, h_indices, w_indices])
                         + text_length
                         + start_idx
                     )
-                    llm_pos_ids_list.append(image_pos_ids)
+                    llm_pos_ids_list.append(vision_pos_ids)

-                    current_pos = next_image_pos + time_steps * height * width
-                    image_index += 1
+                    current_pos = next_vision_pos + time_steps * height * width
+                    vision_index += 1

+                # Handle remaining text if any
                 if current_pos < batch_input_ids.size(1):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    )
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                     text_len = batch_input_ids.size(1) - current_pos
                     llm_pos_ids_list.append(
                         torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
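
The loop above mixes two position-id regimes: text tokens carry one scalar position replicated across the three rotary axes, while vision tokens carry an explicit (t, h, w) grid offset past the preceding text. A minimal sketch of the layout it produces, with invented sizes but the same tensor ops as the diff:

import torch

time_steps, height, width = 2, 2, 2  # merged grid: 8 vision tokens
text_length = 3                      # 3 text tokens before the vision block
start_idx = 0

# Text tokens: one scalar position per token, copied to all 3 axes.
text_pos_ids = torch.arange(text_length).view(1, -1).expand(3, -1) + start_idx
# tensor([[0, 1, 2],
#         [0, 1, 2],
#         [0, 1, 2]])

# Vision tokens: an explicit (t, h, w) grid, shifted past the text.
t_idx = torch.arange(time_steps).repeat_interleave(height * width)
h_idx = torch.arange(height).repeat_interleave(width).repeat(time_steps)
w_idx = torch.arange(width).repeat(height * time_steps)
vision_pos_ids = torch.stack([t_idx, h_idx, w_idx]) + text_length + start_idx
# tensor([[3, 3, 3, 3, 4, 4, 4, 4],
#         [3, 3, 4, 4, 3, 3, 4, 4],
#         [3, 4, 3, 4, 3, 4, 3, 4]])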
@@ -527,11 +534,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
-            if pixel_values is not None:
-                image_embeds = self.visual(
-                    pixel_values, grid_thw=image_grid_thw
-                ).squeeze(0)
-                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+            vision_embeds = self.visual(
+                pixel_values,
+                grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
+            ).squeeze(0)
+
+            # Apply embeddings to both image and video tokens
+            vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
+            inputs_embeds[vision_token_mask] = vision_embeds

         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
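
Two details of this hunk are worth calling out: the torch.cat assumes image grid rows precede video grid rows (matching the order of pixel_values), and the write-back is a plain boolean-mask scatter, so vision_embeds rows land in placeholder slots in sequence order. A small self-contained sketch of that scatter; the ids match Qwen2-VL's configured image_token_id / video_token_id, though any distinct ids would illustrate the same point:

import torch

IMAGE_TOKEN, VIDEO_TOKEN = 151655, 151656  # Qwen2-VL placeholder ids
input_ids = torch.tensor([10, IMAGE_TOKEN, IMAGE_TOKEN, 11, VIDEO_TOKEN, 12])
inputs_embeds = torch.zeros(6, 4)
vision_embeds = torch.ones(3, 4)  # one row per placeholder, in order

mask = (input_ids == IMAGE_TOKEN) | (input_ids == VIDEO_TOKEN)
inputs_embeds[mask] = vision_embeds  # rows 1, 2 and 4 are overwritten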
@@ -550,3 +560,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
+
+
+class QwenVideoProcessor:
+    """Utility class to handle video processing specifically for Qwen models"""
+
+    @staticmethod
+    def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]:
+        """
+        Process messages containing video inputs for Qwen models
+        Returns a tuple of (processed_messages, video_pixels)
+        """
+        # Use Qwen's built-in video processing
+        vision_info = process_vision_info(messages)
+
+        if vision_info is not None:
+            _, video_inputs = vision_info
+            return video_inputs[0] if video_inputs else None
+
+        return None
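
A hypothetical call site for the new helper; the message layout follows the qwen_vl_utils convention and the video path is illustrative. Note that despite the Tuple[Dict, Optional[torch.Tensor]] annotation, the method as written returns only the first decoded video input, or None:

messages = [{
    "role": "user",
    "content": [
        {"type": "video", "video": "file:///data/sample.mp4"},
        {"type": "text", "text": "Describe this clip."},
    ],
}]

video = QwenVideoProcessor.prepare_video_inputs(messages)
if video is not None:
    # hand the decoded frames to the processor / model forward pass
    ...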