refactoring
This commit is contained in:
parent
f7cf45dfde
commit
cee1dea803
|
@ -16,10 +16,14 @@
|
||||||
|
|
||||||
from typing import Dict, Optional, Tuple, List
|
from typing import Dict, Optional, Tuple, List
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from contextlib import contextmanager
|
||||||
from qwen_vl_utils import process_vision_info
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
|
|
||||||
|
@ -561,20 +565,31 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
class QwenVideoProcessor:
|
@contextmanager
|
||||||
"""Utility class to handle video processing specifically for Qwen models"""
|
def temp_video_download(url: str) -> str:
|
||||||
|
"""Downloads video to temporary file and cleans it up after use."""
|
||||||
@staticmethod
|
temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
|
||||||
def prepare_video_inputs(messages: List[Dict]) -> Tuple[Dict, Optional[torch.Tensor]]:
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
"""
|
temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))
|
||||||
Process messages containing video inputs for Qwen models
|
|
||||||
Returns a tuple of (processed_messages, video_pixels)
|
|
||||||
"""
|
|
||||||
# Use Qwen's built-in video processing
|
|
||||||
vision_info = process_vision_info(messages)
|
|
||||||
|
|
||||||
if vision_info is not None:
|
try:
|
||||||
_, video_inputs = vision_info
|
with open(temp_path, 'wb') as tmp_file:
|
||||||
return video_inputs[0] if video_inputs else None
|
with requests.get(url, stream=True) as r:
|
||||||
|
r.raise_for_status()
|
||||||
return None
|
for chunk in r.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
tmp_file.write(chunk)
|
||||||
|
yield temp_path
|
||||||
|
finally:
|
||||||
|
if os.path.exists(temp_path):
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def process_qwen_video(chunk_video: str):
|
||||||
|
"""Process video for Qwen2VL model"""
|
||||||
|
vision_info = [{
|
||||||
|
"type": "video",
|
||||||
|
"video": chunk_video,
|
||||||
|
"max_pixels": 360 * 420,
|
||||||
|
"fps": 1.0
|
||||||
|
}]
|
||||||
|
return process_vision_info(vision_info)
|
|
@ -1,11 +1,7 @@
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
import tempfile
|
|
||||||
import requests
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
||||||
|
@ -196,8 +192,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
images.append([image])
|
images.append([image])
|
||||||
elif chunk_type == "video":
|
elif chunk_type == "video":
|
||||||
if config.model_type == "qwen2_vl":
|
if config.model_type == "qwen2_vl":
|
||||||
# For now, treat video URLs as special tokens
|
|
||||||
# This will be processed in the text replacement section below
|
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||||
|
@ -222,13 +216,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
)
|
)
|
||||||
image_id += 1
|
image_id += 1
|
||||||
elif chunk_type == "video" and config.model_type == "qwen2_vl":
|
elif chunk_type == "video" and config.model_type == "qwen2_vl":
|
||||||
# Download and process video in a temporary context
|
from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
|
||||||
with cls.temp_video_download(chunk.video) as local_path:
|
text, _ = process_qwen_video(chunk.video)
|
||||||
# Now the video is available at local_path for processing
|
full_text += text
|
||||||
full_text += f"<video>{local_path}</video>"
|
|
||||||
|
|
||||||
full_text = image_text_replacement_fixup(config, full_text)
|
full_text = image_text_replacement_fixup(config, full_text)
|
||||||
|
|
||||||
batch_inputs.append(full_text)
|
batch_inputs.append(full_text)
|
||||||
max_truncation = max(max_truncation, r.truncate)
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
|
||||||
|
@ -277,27 +269,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
batch.image_sizes = None
|
batch.image_sizes = None
|
||||||
batch.image_grid_thw = None
|
batch.image_grid_thw = None
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@contextmanager
|
|
||||||
def temp_video_download(url: str) -> str:
|
|
||||||
"""Downloads video to a temporary file and cleans it up after use."""
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(url)[1], delete=False) as tmp_file:
|
|
||||||
try:
|
|
||||||
# Download video
|
|
||||||
with requests.get(url, stream=True) as r:
|
|
||||||
r.raise_for_status()
|
|
||||||
for chunk in r.iter_content(chunk_size=8192):
|
|
||||||
if chunk:
|
|
||||||
tmp_file.write(chunk)
|
|
||||||
tmp_file.flush()
|
|
||||||
yield tmp_file.name
|
|
||||||
finally:
|
|
||||||
# Clean up temp file
|
|
||||||
try:
|
|
||||||
os.unlink(tmp_file.name)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class VlmCausalLM(FlashCausalLM):
|
class VlmCausalLM(FlashCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
Loading…
Reference in New Issue