Fix Phi-2 with `tp>1` (#2003)

# What does this PR do?

We were using the wrong parallelism for the up-projection: Phi's `fc1` was loaded with `TensorParallelRowLinear` instead of `TensorParallelColumnLinear`, which broke Phi-2 whenever tensor parallelism was enabled (`tp>1`). This change loads `fc1` column-parallel, the standard pairing with a row-parallel down-projection.
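
For intuition, here is a minimal sketch in plain PyTorch of why the up-projection must be column-parallel and the down-projection row-parallel (this is not TGI's `TensorParallel*` API; the sizes, names, and two-way split below are made up for illustration):

```python
# Megatron-style MLP sharding, simulated on one device.
import torch

torch.manual_seed(0)
hidden, intermediate, tp = 8, 16, 2  # illustrative sizes; tp = world size

x = torch.randn(1, hidden)
w_up = torch.randn(intermediate, hidden)    # fc1: hidden -> intermediate
w_down = torch.randn(hidden, intermediate)  # fc2: intermediate -> hidden

ref = torch.relu(x @ w_up.T) @ w_down.T     # unsharded reference output

# Column-parallel fc1: shard w_up along its *output* rows, so each rank
# holds a distinct slice of the intermediate activations and the
# activation function can be applied locally, with no communication.
up_shards = w_up.chunk(tp, dim=0)
# Row-parallel fc2: shard w_down along its *input* columns; each rank
# produces a partial output and the partials are summed across ranks.
down_shards = w_down.chunk(tp, dim=1)

# The sum() below stands in for the single all-reduce after fc2.
out = sum(torch.relu(x @ u.T) @ d.T for u, d in zip(up_shards, down_shards))

assert torch.allclose(out, ref, atol=1e-5)  # sharded == unsharded
```

Loading `fc1` row-parallel instead shards the wrong axis of the weight, so the two halves of the MLP no longer compose under `tp>1`.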


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
Commit 9b52f0e2dc (parent df71aafdcc), authored by Daniël de Kok on 2024-06-04 14:26:07 +02:00, committed by GitHub.
1 changed file with 1 addition and 1 deletion

```diff
@@ -238,7 +238,7 @@ class PhiMLP(nn.Module):
         )
         # llama weights are up_proj and down_proj and bias=False
-        self.up_proj = TensorParallelRowLinear.load(
+        self.up_proj = TensorParallelColumnLinear.load(
             config,
             prefix=f"{prefix}.fc1",
             weights=weights,
```