fix
This commit is contained in:
parent
cee1dea803
commit
6b4697e9d1
|
@ -14,6 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch Qwen2 VL model."""
|
"""PyTorch Qwen2 VL model."""
|
||||||
|
|
||||||
|
__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
|
||||||
|
|
||||||
from typing import Dict, Optional, Tuple, List
|
from typing import Dict, Optional, Tuple, List
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
@ -564,32 +566,3 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def temp_video_download(url: str) -> str:
|
|
||||||
"""Downloads video to temporary file and cleans it up after use."""
|
|
||||||
temp_dir = os.path.join(tempfile.gettempdir(), "qwen_videos")
|
|
||||||
os.makedirs(temp_dir, exist_ok=True)
|
|
||||||
temp_path = os.path.abspath(os.path.join(temp_dir, os.path.basename(url)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(temp_path, 'wb') as tmp_file:
|
|
||||||
with requests.get(url, stream=True) as r:
|
|
||||||
r.raise_for_status()
|
|
||||||
for chunk in r.iter_content(chunk_size=8192):
|
|
||||||
if chunk:
|
|
||||||
tmp_file.write(chunk)
|
|
||||||
yield temp_path
|
|
||||||
finally:
|
|
||||||
if os.path.exists(temp_path):
|
|
||||||
os.unlink(temp_path)
|
|
||||||
|
|
||||||
def process_qwen_video(chunk_video: str):
|
|
||||||
"""Process video for Qwen2VL model"""
|
|
||||||
vision_info = [{
|
|
||||||
"type": "video",
|
|
||||||
"video": chunk_video,
|
|
||||||
"max_pixels": 360 * 420,
|
|
||||||
"fps": 1.0
|
|
||||||
}]
|
|
||||||
return process_vision_info(vision_info)
|
|
Loading…
Reference in New Issue