Support GPTQ models with column-packed up/gate tensor (#2006)

# What does this PR do?

The GPTQ code path for column-packed tensors assumed that such a tensor is
always a fused QKV matrix. However, models such as Phi-3 also have
column-packed MLP up/gate matrices, so the number of packed blocks can no
longer be hard-coded to 3.
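
As a rough sketch of the two layouts (the shapes and tensor names below are made up for illustration, not taken from Phi-3 or from this repository):

```python
import torch

hidden, intermediate = 4, 6

# Fused attention weight: query, key and value packed side by side -> 3 blocks.
q = k = v = torch.zeros(hidden, intermediate)
qkv = torch.cat([q, k, v], dim=1)        # shape (hidden, 3 * intermediate)

# Phi-3-style fused MLP weight: gate and up packed the same way -> 2 blocks.
gate, up = torch.zeros(hidden, intermediate), torch.ones(hidden, intermediate)
gate_up = torch.cat([gate, up], dim=1)   # shape (hidden, 2 * intermediate)

# The old loader hard-coded 3 blocks, so a 2-block gate/up tensor tripped the
# "divisible by 3" assertion during sharding.
```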

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
Commit d14eaacaca (parent 757223b352), authored by Daniël de Kok on 2024-06-04 19:37:49 +02:00 and committed by GitHub.
1 changed file with 17 additions and 11 deletions


@@ -121,24 +121,30 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
-    def _get_qweight(self, name: str):
+    def _get_qweight(self, name: str, blocks: int):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
-        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
-        single_size = total_size // 3
+        assert (
+            total_size % blocks == 0
+        ), f"Prepacked quantized matrix is not divisible by {blocks}"
+        single_size = total_size // blocks
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
         assert (
             single_size % world_size == 0
-        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
-        q = slice_[:, start:stop]
-        k = slice_[:, start + single_size : stop + single_size]
-        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
-        weight = torch.cat([q, k, v], dim=1)
+
+        weights = []
+        for block in range(blocks):
+            weights.append(
+                slice_[:, start + block * single_size : stop + block * single_size]
+            )
+
+        weight = torch.cat(weights, dim=1)
         weight = weight.to(device=self.device)
         return weight
 
@@ -157,7 +163,7 @@ class Weights:
             from text_generation_server.layers.gptq import GPTQWeight
 
             try:
-                qweight = self._get_qweight(f"{prefix}.qweight")
+                qweight = self._get_qweight(f"{prefix}.qweight", blocks)
             except RuntimeError:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@@ -165,8 +171,8 @@ class Weights:
             bits, groupsize, _, quant_method = self._get_gptq_params()
 
-            qzeros = self._get_qweight(f"{prefix}.qzeros")
-            scales = self._get_qweight(f"{prefix}.scales")
+            qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
+            scales = self._get_qweight(f"{prefix}.scales", blocks)
             scales = scales.to(dtype=self.dtype)
 
             if quantize == "gptq" and quant_method == "gptq":
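
For readers skimming the diff above, here is a self-contained sketch of the column sharding that the generalized `_get_qweight` now performs for an arbitrary block count; `shard_packed_columns` is a hypothetical stand-in for illustration, not a function from this repository:

```python
import torch


def shard_packed_columns(packed: torch.Tensor, blocks: int, rank: int, world_size: int) -> torch.Tensor:
    """Take this rank's column slice from each of the `blocks` packed sub-matrices."""
    total_size = packed.shape[1]
    assert total_size % blocks == 0, f"Packed width {total_size} is not divisible by {blocks}"
    single_size = total_size // blocks
    assert single_size % world_size == 0, f"Cannot shard {single_size} columns across {world_size} ranks"
    block_size = single_size // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    parts = [packed[:, start + b * single_size : stop + b * single_size] for b in range(blocks)]
    return torch.cat(parts, dim=1)


# blocks=3 covers the original QKV case, blocks=2 the Phi-3 gate/up case.
qkv = torch.arange(2 * 12).reshape(2, 12)
print(shard_packed_columns(qkv, blocks=3, rank=0, world_size=2).shape)      # torch.Size([2, 6])
gate_up = torch.arange(2 * 8).reshape(2, 8)
print(shard_packed_columns(gate_up, blocks=2, rank=1, world_size=2).shape)  # torch.Size([2, 4])
```

Slicing each packed block for the current rank before concatenating keeps the per-shard result in the same q/k/v (or gate/up) order as the unsharded tensor.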