marlin: support tp>1 when group_size==-1

This commit is contained in:
Daniël de Kok 2024-06-06 11:51:52 +00:00 committed by Daniël de Kok
parent 4594e6faba
commit 0d96468ebb
1 changed files with 7 additions and 1 deletions

View File

@ -513,7 +513,13 @@ class Weights:
"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
s = self.get_sharded(f"{prefix}.s", dim=0)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when group_size == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
else: