Fix position ids logic instantiation of idefics vision part (#1064)

Problem and fix is described here:
https://huggingface.co/HuggingFaceM4/idefics-9b/discussions/9

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Victor SANH 2023-09-26 15:41:15 +02:00 committed by GitHub
parent ae623b8d2d
commit 1fff6746ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 3 deletions

View File

@ -88,12 +88,10 @@ class IdeficsVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
# self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.position_embedding = TensorParallelEmbedding(
prefix="model.vision_model.embeddings.position_embedding", weights=weights
)
# self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
self.position_ids = weights.get_tensor(f"{prefix}.position_ids")
self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]