refactoring

This commit is contained in:
Miquel Farre 2024-11-14 12:30:59 +00:00 committed by drbh
parent f7cf45dfde
commit cee1dea803
2 changed files with 34 additions and 48 deletions

View File

@ -16,10 +16,14 @@
from typing import Dict, Optional, Tuple, List
import os
import tempfile
import requests
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info
@ -561,20 +565,31 @@ class Qwen2VLForConditionalGeneration(nn.Module):
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
class QwenVideoProcessor:
"""Utility class to handle video processing specifically for Qwen models"""
@staticmethod
def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]:
"""
Process messages containing video inputs for Qwen models
Returns a tuple of (processed_messages, video_pixels)
"""
# Use Qwen's built-in video processing
vision_info = process_vision_info(messages)
@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to temporary file and cleans it up after use."""
temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))
if vision_info is not None:
_, video_inputs = vision_info
return video_inputs[0] if video_inputs else None
return None
try:
with open(temp_path, 'wb') as tmp_file:
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
yield temp_path
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
def process_qwen_video(chunk_video: str):
"""Process video for Qwen2VL model"""
vision_info = [{
"type": "video",
"video": chunk_video,
"max_pixels": 360 * 420,
"fps": 1.0
}]
return process_vision_info(vision_info)

View File

@ -1,11 +1,7 @@
import os
import torch
import tempfile
import requests
from PIL import Image
from io import BytesIO
from contextlib import contextmanager
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
@ -196,8 +192,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append([image])
elif chunk_type == "video":
if config.model_type == "qwen2_vl":
# For now, treat video URLs as special tokens
# This will be processed in the text replacement section below
pass
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
@ -222,13 +216,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
)
image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl":
# Download and process video in a temporary context
with cls.temp_video_download(chunk.video) as local_path:
# Now the video is available at local_path for processing
full_text += f"<video>{local_path}</video>"
from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
text, _ = process_qwen_video(chunk.video)
full_text += text
full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
@ -277,27 +269,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@staticmethod
@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to a temporary file and cleans it up after use."""
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(url)[1], delete=False) as tmp_file:
try:
# Download video
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
tmp_file.flush()
yield tmp_file.name
finally:
# Clean up temp file
try:
os.unlink(tmp_file.name)
except OSError:
pass
class VlmCausalLM(FlashCausalLM):
def __init__(