Fix position ids logic instantiation of idefics vision part (#1064)
The problem and its fix are described here: https://huggingface.co/HuggingFaceM4/idefics-9b/discussions/9 --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
ae623b8d2d
commit
1fff6746ab
|
@ -88,12 +88,10 @@ class IdeficsVisionEmbeddings(nn.Module):
|
|||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
# self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix="model.vision_model.embeddings.position_embedding", weights=weights
|
||||
)
|
||||
# self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
self.position_ids = weights.get_tensor(f"{prefix}.position_ids")
|
||||
self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
|
Loading…
Reference in New Issue