refactoring
parent f7cf45dfde
commit cee1dea803
@@ -16,10 +16,14 @@
from typing import Dict, Optional, Tuple, List

import os
import tempfile
import requests
import torch
import torch.utils.checkpoint
from torch import nn

from text_generation_server.utils.import_utils import SYSTEM
from contextlib import contextmanager
from qwen_vl_utils import process_vision_info
@@ -561,20 +565,31 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits


class QwenVideoProcessor:
    """Utility class to handle video processing specifically for Qwen models"""

    @staticmethod
    def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]:
        """
        Process messages containing video inputs for Qwen models
        Returns a tuple of (processed_messages, video_pixels)
        """
        # Use Qwen's built-in video processing
        vision_info = process_vision_info(messages)

        if vision_info is not None:
            _, video_inputs = vision_info
            return video_inputs[0] if video_inputs else None

        return None


@contextmanager
def temp_video_download(url: str) -> str:
    """Downloads video to temporary file and cleans it up after use."""
    temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))

    try:
        with open(temp_path, 'wb') as tmp_file:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        tmp_file.write(chunk)
        yield temp_path
    finally:
        if os.path.exists(temp_path):
            os.unlink(temp_path)
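For orientation, here is a minimal usage sketch of the new temp_video_download context manager. It is not part of the commit; the URL is a placeholder, and the import assumes the TGI server package is on the path:

import os
from text_generation_server.models.custom_modeling.qwen2_vl import temp_video_download

video_url = "https://example.com/clip.mp4"  # hypothetical input URL

with temp_video_download(video_url) as local_path:
    # The downloaded file exists for the duration of the block and is
    # unlinked on exit, even if processing raises.
    print(local_path, os.path.getsize(local_path))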

def process_qwen_video(chunk_video: str):
    """Process video for Qwen2VL model"""
    vision_info = [{
        "type": "video",
        "video": chunk_video,
        "max_pixels": 360 * 420,
        "fps": 1.0
    }]
    return process_vision_info(vision_info)
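For reference, process_qwen_video wraps qwen_vl_utils.process_vision_info, whose documented calling convention takes a full chat-message structure and returns an (image_inputs, video_inputs) pair. A short sketch of that upstream usage, with a placeholder video path:

from qwen_vl_utils import process_vision_info

messages = [{
    "role": "user",
    "content": [
        {"type": "video", "video": "file:///tmp/qwen_videos/clip.mp4",
         "max_pixels": 360 * 420, "fps": 1.0},
        {"type": "text", "text": "Describe this video."},
    ],
}]

# video_inputs is a list with one tensor of sampled frames per video entry;
# image_inputs is None when the conversation contains no images.
image_inputs, video_inputs = process_vision_info(messages)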
@@ -1,11 +1,7 @@
import os
import torch
import tempfile
import requests
from PIL import Image
from io import BytesIO

from contextlib import contextmanager

from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
@@ -196,8 +192,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                    images.append([image])
                elif chunk_type == "video":
                    if config.model_type == "qwen2_vl":
                        # For now, treat video URLs as special tokens
                        # This will be processed in the text replacement section below
                        pass
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
@@ -222,13 +216,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                    )
                    image_id += 1
                elif chunk_type == "video" and config.model_type == "qwen2_vl":
                    # Download and process video in a temporary context
                    with cls.temp_video_download(chunk.video) as local_path:
                        # Now the video is available at local_path for processing
                        full_text += f"<video>{local_path}</video>"
                    from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
                    text, _ = process_qwen_video(chunk.video)
                    full_text += text

            full_text = image_text_replacement_fixup(config, full_text)

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)
@@ -277,27 +269,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    @staticmethod
    @contextmanager
    def temp_video_download(url: str) -> str:
        """Downloads video to a temporary file and cleans it up after use."""
        with tempfile.NamedTemporaryFile(suffix=os.path.splitext(url)[1], delete=False) as tmp_file:
            try:
                # Download video
                with requests.get(url, stream=True) as r:
                    r.raise_for_status()
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:
                            tmp_file.write(chunk)
                tmp_file.flush()
                yield tmp_file.name
            finally:
                # Clean up temp file
                try:
                    os.unlink(tmp_file.name)
                except OSError:
                    pass


class VlmCausalLM(FlashCausalLM):
    def __init__(