fix: create position ids for text only input (#2714)
* fix: create position ids for text only input * fix: prefer repeat over expand to avoid clone
This commit is contained in:
parent
01dacf8e8f
commit
6e3220529d
|
@ -468,7 +468,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
|
|
||||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
position_ids[:, i, :] = llm_positions.to(position_ids.device)
|
position_ids[:, i, :] = llm_positions.to(position_ids.device)
|
||||||
|
else:
|
||||||
|
position_ids = (
|
||||||
|
torch.arange(batch_input_ids.shape[1], device=batch_input_ids.device)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.repeat(3, batch_input_ids.shape[0], 1)
|
||||||
|
)
|
||||||
return position_ids
|
return position_ids
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
Loading…
Reference in New Issue