fix: create position ids for text only input (#2714)

* fix: create position ids for text only input

* fix: prefer repeat over expand to avoid clone
This commit is contained in:
drbh 2024-11-01 20:40:05 -04:00 committed by GitHub
parent 01dacf8e8f
commit 6e3220529d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 6 additions and 1 deletion

View File

@@ -468,7 +468,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[:, i, :] = llm_positions.to(position_ids.device) position_ids[:, i, :] = llm_positions.to(position_ids.device)
else:
position_ids = (
torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
.view(1, 1, -1)
.repeat(3, batch_input_ids.shape[0], 1)
)
return position_ids return position_ids
def forward( def forward(