This commit is contained in:
Miquel Farre 2024-11-14 13:40:04 +00:00 committed by drbh
parent cee1dea803
commit 6b4697e9d1
1 changed files with 2 additions and 29 deletions

View File

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch Qwen2 VL model.""" """PyTorch Qwen2 VL model."""
__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
from typing import Dict, Optional, Tuple, List from typing import Dict, Optional, Tuple, List
import os import os
@ -564,32 +566,3 @@ class Qwen2VLForConditionalGeneration(nn.Module):
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits return logits, speculative_logits
@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to temporary file and cleans it up after use."""
temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))
try:
with open(temp_path, 'wb') as tmp_file:
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
yield temp_path
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
def process_qwen_video(chunk_video: str):
"""Process video for Qwen2VL model"""
vision_info = [{
"type": "video",
"video": chunk_video,
"max_pixels": 360 * 420,
"fps": 1.0
}]
return process_vision_info(vision_info)