diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt
index ad4ea56b..13863650 100644
--- a/server/requirements_cuda.txt
+++ b/server/requirements_cuda.txt
@@ -36,6 +36,7 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
+qwen_vl_utils==0.0.8 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
 rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt
index ad4ea56b..13863650 100644
--- a/server/requirements_intel.txt
+++ b/server/requirements_intel.txt
@@ -36,6 +36,7 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
+qwen_vl_utils==0.0.8 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
 rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt
index ad4ea56b..13863650 100644
--- a/server/requirements_rocm.txt
+++ b/server/requirements_rocm.txt
@@ -36,6 +36,7 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
+qwen_vl_utils==0.0.8 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
 rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index ddb4e36d..fc9e1575 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -14,12 +14,14 @@
 # limitations under the License.
"""PyTorch Qwen2 VL model.""" -from typing import Optional, Tuple, List +from typing import Dict, Optional, Tuple, List import torch import torch.utils.checkpoint from torch import nn from text_generation_server.utils.import_utils import SYSTEM +from qwen_vl_utils import process_vision_info + if SYSTEM == "ipex": import intel_extension_for_pytorch as ipex @@ -411,6 +413,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): self, batch_input_ids: torch.Tensor, image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, # video_grid_thw is not implemented yet as we do not accept video inputs at the moment ) -> Tuple[torch.Tensor, torch.Tensor]: if batch_input_ids.dim() == 1: @@ -424,8 +427,10 @@ class Qwen2VLForConditionalGeneration(nn.Module): device=batch_input_ids.device, ) d = batch_input_ids.device - if image_grid_thw is not None: - image_index = 0 + + # Handle both image and video tokens + if image_grid_thw is not None or video_grid_thw is not None: + vision_index = 0 llm_pos_ids_list = [] for i, input_ids in enumerate(batch_input_ids): @@ -433,34 +438,39 @@ class Qwen2VLForConditionalGeneration(nn.Module): input_ids == self.vision_start_token_id ).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] - # only copy the sum of the image tokens GPU<->CPU + + # Count both image and video tokens image_count = (vision_tokens == self.image_token_id).sum().item() + video_count = (vision_tokens == self.video_token_id).sum().item() current_pos = 0 - for _ in range(image_count): - # copy the value position of the next image token from GPU<->CPU - next_image_pos = ( - (input_ids[current_pos:] == self.image_token_id) + for _ in range(image_count + video_count): + # Find next vision token position (either image or video) + next_vision_pos = ( + ((input_ids[current_pos:] == self.image_token_id) | + (input_ids[current_pos:] == self.video_token_id)) .nonzero()[0] .item() ) - # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop - time_steps, height, width = image_grid_thw[image_index].clone() + + # Determine if current token is video or image + is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id + grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index] + + time_steps, height, width = grid_thw.clone() height //= self.spatial_merge_size width //= self.spatial_merge_size - # calculate the length of the text and image tokens - text_length = next_image_pos - current_pos - start_idx = ( - llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 - ) + # Calculate lengths and indices same as before + text_length = next_vision_pos - current_pos + start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 - # text position ids + # Text position ids text_pos_ids = torch.arange(text_length, device=d) text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx llm_pos_ids_list.append(text_pos_ids) - # image position ids + # Vision position ids t_indices = torch.arange(time_steps, device=d).repeat_interleave( height * width ) @@ -469,24 +479,21 @@ class Qwen2VLForConditionalGeneration(nn.Module): .repeat_interleave(width) .repeat(time_steps) ) - w_indices = torch.arange(width, device=d).repeat( - height * time_steps - ) + w_indices = torch.arange(width, device=d).repeat(height * time_steps) - image_pos_ids = ( + vision_pos_ids = ( torch.stack([t_indices, h_indices, w_indices]) + text_length + start_idx ) - llm_pos_ids_list.append(image_pos_ids) + 
+                    llm_pos_ids_list.append(vision_pos_ids)

-                    current_pos = next_image_pos + time_steps * height * width
-                    image_index += 1
+                    current_pos = next_vision_pos + time_steps * height * width
+                    vision_index += 1

+                # Handle remaining text if any
                 if current_pos < batch_input_ids.size(1):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                    )
+                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                     text_len = batch_input_ids.size(1) - current_pos
                     llm_pos_ids_list.append(
                         torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
                     )
@@ -527,11 +534,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
-            if pixel_values is not None:
-                image_embeds = self.visual(
-                    pixel_values, grid_thw=image_grid_thw
-                ).squeeze(0)
-                inputs_embeds[input_ids == self.image_token_id] = image_embeds
+            vision_embeds = self.visual(
+                pixel_values,
+                grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
+            ).squeeze(0)
+
+            # Apply embeddings to both image and video tokens
+            vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
+            inputs_embeds[vision_token_mask] = vision_embeds

         hidden_states = self.text_model(
             inputs_embeds=inputs_embeds,
@@ -550,3 +560,21 @@
         hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
+
+class QwenVideoProcessor:
+    """Utility class to handle video processing specifically for Qwen models."""
+
+    @staticmethod
+    def prepare_video_inputs(messages: List[Dict]) -> Optional[torch.Tensor]:
+        """
+        Process messages containing video inputs for Qwen models.
+        Returns the processed inputs for the first video, or None if no video is present.
+        """
+        # Use Qwen's built-in video processing
+        vision_info = process_vision_info(messages)
+
+        if vision_info is not None:
+            _, video_inputs = vision_info
+            return video_inputs[0] if video_inputs else None
+
+        return None
\ No newline at end of file
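
Usage note: the new QwenVideoProcessor.prepare_video_inputs helper is a thin wrapper over qwen_vl_utils.process_vision_info, which extracts image and video payloads from a chat-style message list. The sketch below is illustrative only; the video path and prompt are placeholders, and the message layout follows the qwen_vl_utils convention rather than anything added by this diff.

from typing import Dict, List

from text_generation_server.models.custom_modeling.qwen2_vl import QwenVideoProcessor

# Chat-style messages in the qwen_vl_utils convention: each content item carries
# a "type" plus its payload. The video path below is a placeholder.
messages: List[Dict] = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "file:///tmp/example.mp4"},
            {"type": "text", "text": "Describe what happens in this clip."},
        ],
    }
]

# Returns the decoded inputs for the first video in the conversation,
# or None when the messages contain no video entries.
video_inputs = QwenVideoProcessor.prepare_video_inputs(messages)
if video_inputs is None:
    print("no video found in messages")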