refactoring

This commit is contained in:
Miquel Farre 2024-11-14 12:30:59 +00:00 committed by drbh
parent f7cf45dfde
commit cee1dea803
2 changed files with 34 additions and 48 deletions

View File

@ -16,10 +16,14 @@
from typing import Dict, Optional, Tuple, List from typing import Dict, Optional, Tuple, List
import os
import tempfile
import requests
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
@ -561,20 +565,31 @@ class Qwen2VLForConditionalGeneration(nn.Module):
logits, speculative_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits return logits, speculative_logits
class QwenVideoProcessor: @contextmanager
"""Utility class to handle video processing specifically for Qwen models""" def temp_video_download(url: str) -> str:
"""Downloads video to temporary file and cleans it up after use."""
@staticmethod temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]: os.makedirs(temp_dir, exist_ok=True)
""" temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))
Process messages containing video inputs for Qwen models
Returns a tuple of (processed_messages, video_pixels)
"""
# Use Qwen's built-in video processing
vision_info = process_vision_info(messages)
if vision_info is not None: try:
_, video_inputs = vision_info with open(temp_path, 'wb') as tmp_file:
return video_inputs[0] if video_inputs else None with requests.get(url, stream=True) as r:
r.raise_for_status()
return None for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
yield temp_path
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
def process_qwen_video(chunk_video: str):
    """Run Qwen's vision preprocessing on a single video reference.

    Wraps ``chunk_video`` in one ``"video"`` entry (pixel budget of
    ``360 * 420`` and 1 frame per second) and returns whatever
    ``process_vision_info`` produces for a one-element list holding it.
    """
    # NOTE(review): presumably process_vision_info expects chat-style
    # messages carrying a "content" list -- confirm that a bare list of
    # vision entries is accepted before relying on this helper.
    video_entry = {
        "type": "video",
        "video": chunk_video,
        "max_pixels": 360 * 420,
        "fps": 1.0,
    }
    return process_vision_info([video_entry])

View File

@ -1,11 +1,7 @@
import os
import torch import torch
import tempfile
import requests
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
from contextlib import contextmanager
from opentelemetry import trace from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict from typing import Iterable, Optional, Tuple, List, Type, Dict
@ -196,8 +192,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append([image]) images.append([image])
elif chunk_type == "video": elif chunk_type == "video":
if config.model_type == "qwen2_vl": if config.model_type == "qwen2_vl":
# For now, treat video URLs as special tokens
# This will be processed in the text replacement section below
pass pass
else: else:
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")
@ -222,13 +216,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
) )
image_id += 1 image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl": elif chunk_type == "video" and config.model_type == "qwen2_vl":
# Download and process video in a temporary context from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
with cls.temp_video_download(chunk.video) as local_path: text, _ = process_qwen_video(chunk.video)
# Now the video is available at local_path for processing full_text += text
full_text += f"<video>{local_path}</video>"
full_text = image_text_replacement_fixup(config, full_text) full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text) batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate) max_truncation = max(max_truncation, r.truncate)
@ -277,27 +269,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_sizes = None batch.image_sizes = None
batch.image_grid_thw = None batch.image_grid_thw = None
return batch return batch
@staticmethod
@contextmanager
def temp_video_download(url: str, timeout: float = 30.0) -> str:
    """Download a video to a temporary file and delete it after use.

    Used as a context manager: the ``with`` target is the path of the
    downloaded file, which is removed when the block exits (normally or
    through an exception).

    Args:
        url: HTTP(S) location of the video.
        timeout: Seconds passed to ``requests.get`` so a stalled server
            cannot block the caller indefinitely.

    Yields:
        Path of the temporary file holding the downloaded bytes.

    Raises:
        requests.RequestException: on connection problems, timeouts, or a
            non-success HTTP status (via ``raise_for_status``).
    """
    from urllib.parse import urlparse

    # Take the extension from the URL *path* only, so a query string or
    # fragment ("clip.mp4?token=...") does not end up in the file name.
    suffix = os.path.splitext(urlparse(url).path)[1]
    tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        # Close the handle before yielding: the data is fully flushed and
        # the consumer can reopen the path on every platform.
        with tmp_file:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        tmp_file.write(chunk)
        yield tmp_file.name
    finally:
        # Best-effort cleanup; the file may already be gone.
        try:
            os.unlink(tmp_file.name)
        except OSError:
            pass
class VlmCausalLM(FlashCausalLM): class VlmCausalLM(FlashCausalLM):
def __init__( def __init__(