fix: create position ids for text only input (#2714)
* fix: create position ids for text only input * fix: prefer repeat over expand to avoid clone
This commit is contained in:
parent
01dacf8e8f
commit
6e3220529d
|
@ -468,7 +468,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
position_ids[:, i, :] = llm_positions.to(position_ids.device)
|
||||
|
||||
else:
|
||||
position_ids = (
|
||||
torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
|
||||
.view(1, 1, -1)
|
||||
.repeat(3, batch_input_ids.shape[0], 1)
|
||||
)
|
||||
return position_ids
|
||||
|
||||
def forward(
|
||||
|
|
Loading…
Reference in New Issue