refactoring

This commit is contained in:
Miquel Farre 2024-11-14 12:30:59 +00:00 committed by drbh
parent f7cf45dfde
commit cee1dea803
2 changed files with 34 additions and 48 deletions

View File

@ -16,10 +16,14 @@
from typing import Dict, Optional, Tuple, List from typing import Dict, Optional, Tuple, List
import os
import tempfile
import requests
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
@ -561,20 +565,31 @@ class Qwen2VLForConditionalGeneration(nn.Module):
logits, speculative_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits return logits, speculative_logits
class QwenVideoProcessor: @contextmanager
"""Utility class to handle video processing specifically for Qwen models""" def temp_video_download(url: str) -> str:
"""Downloads video to temporary file and cleans it up after use."""
temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))
@staticmethod try:
def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]: with open(temp_path, 'wb') as tmp_file:
""" with requests.get(url, stream=True) as r:
Process messages containing video inputs for Qwen models r.raise_for_status()
Returns a tuple of (processed_messages, video_pixels) for chunk in r.iter_content(chunk_size=8192):
""" if chunk:
# Use Qwen's built-in video processing tmp_file.write(chunk)
vision_info = process_vision_info(messages) yield temp_path
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
if vision_info is not None: def process_qwen_video(chunk_video: str):
_, video_inputs = vision_info """Process video for Qwen2VL model"""
return video_inputs[0] if video_inputs else None vision_info = [{
"type": "video",
return None "video": chunk_video,
"max_pixels": 360 * 420,
"fps": 1.0
}]
return process_vision_info(vision_info)

View File

@ -1,11 +1,7 @@
import os
import torch import torch
import tempfile
import requests
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
from contextlib import contextmanager
from opentelemetry import trace from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict from typing import Iterable, Optional, Tuple, List, Type, Dict
@ -196,8 +192,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
images.append([image]) images.append([image])
elif chunk_type == "video": elif chunk_type == "video":
if config.model_type == "qwen2_vl": if config.model_type == "qwen2_vl":
# For now, treat video URLs as special tokens
# This will be processed in the text replacement section below
pass pass
else: else:
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")
@ -222,13 +216,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
) )
image_id += 1 image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl": elif chunk_type == "video" and config.model_type == "qwen2_vl":
# Download and process video in a temporary context from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
with cls.temp_video_download(chunk.video) as local_path: text, _ = process_qwen_video(chunk.video)
# Now the video is available at local_path for processing full_text += text
full_text += f"<video>{local_path}</video>"
full_text = image_text_replacement_fixup(config, full_text) full_text = image_text_replacement_fixup(config, full_text)
batch_inputs.append(full_text) batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate) max_truncation = max(max_truncation, r.truncate)
@ -278,27 +270,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = None batch.image_grid_thw = None
return batch return batch
@staticmethod
@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to a temporary file and cleans it up after use."""
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(url)[1], delete=False) as tmp_file:
try:
# Download video
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
tmp_file.flush()
yield tmp_file.name
finally:
# Clean up temp file
try:
os.unlink(tmp_file.name)
except OSError:
pass
class VlmCausalLM(FlashCausalLM): class VlmCausalLM(FlashCausalLM):
def __init__( def __init__(
self, self,