diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 2dfd80bf..71d67d82 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -121,24 +121,30 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
-    def _get_qweight(self, name: str):
+    def _get_qweight(self, name: str, blocks: int):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
-        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
-        single_size = total_size // 3
+        assert (
+            total_size % blocks == 0
+        ), f"Prepacked quantized matrix is not divisible by {blocks}"
+        single_size = total_size // blocks
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
         assert (
             single_size % world_size == 0
-        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
-        q = slice_[:, start:stop]
-        k = slice_[:, start + single_size : stop + single_size]
-        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
-        weight = torch.cat([q, k, v], dim=1)
+
+        weights = []
+        for block in range(blocks):
+            weights.append(
+                slice_[:, start + block * single_size : stop + block * single_size]
+            )
+
+        weight = torch.cat(weights, dim=1)
         weight = weight.to(device=self.device)
         return weight
 
@@ -157,7 +163,7 @@ class Weights:
             from text_generation_server.layers.gptq import GPTQWeight
 
             try:
-                qweight = self._get_qweight(f"{prefix}.qweight")
+                qweight = self._get_qweight(f"{prefix}.qweight", blocks)
             except RuntimeError:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@@ -165,8 +171,8 @@ class Weights:
 
             bits, groupsize, _, quant_method = self._get_gptq_params()
 
-            qzeros = self._get_qweight(f"{prefix}.qzeros")
-            scales = self._get_qweight(f"{prefix}.scales")
+            qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
+            scales = self._get_qweight(f"{prefix}.scales", blocks)
             scales = scales.to(dtype=self.dtype)
 
             if quantize == "gptq" and quant_method == "gptq":
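
For reviewers, a minimal standalone sketch of the sharding arithmetic the new loop implements. `shard_packed` is a hypothetical helper, not part of this PR; it mirrors `_get_qweight` on a plain tensor instead of a safetensors slice. Each rank takes its 1/world_size column slice from inside every packed block and re-concatenates them, so `blocks=3` reproduces the previous QKV-specific slicing exactly.

import torch


def shard_packed(
    tensor: torch.Tensor, blocks: int, world_size: int, rank: int
) -> torch.Tensor:
    # Mirrors _get_qweight: the tensor is `blocks` equal column-blocks
    # (e.g. [Q | K | V]) packed side by side along dim 1.
    total_size = tensor.shape[1]
    assert total_size % blocks == 0
    single_size = total_size // blocks
    assert single_size % world_size == 0

    block_size = single_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size

    # Take this rank's columns from inside every block, then re-concatenate,
    # so the shard keeps the same packed layout as the original tensor.
    return torch.cat(
        [tensor[:, start + b * single_size : stop + b * single_size] for b in range(blocks)],
        dim=1,
    )


# A [Q | K | V] matrix (blocks=3) split across 2 shards:
qkv = torch.arange(12).reshape(1, 12)
print(shard_packed(qkv, blocks=3, world_size=2, rank=0))  # tensor([[ 0,  1,  4,  5,  8,  9]])
print(shard_packed(qkv, blocks=3, world_size=2, rank=1))  # tensor([[ 2,  3,  6,  7, 10, 11]])

Concatenating per-block slices, rather than slicing the tensor once, keeps each shard's internal [block_0 | block_1 | ...] layout identical to the unsharded tensor, which downstream consumers of the packed weight rely on.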