marlin: support tp>1 when group_size==-1
This commit is contained in:
parent
4594e6faba
commit
0d96468ebb
|
@ -513,7 +513,13 @@ class Weights:
|
|||
"Cannot load `marlin` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when group_size == -1. Share
|
||||
# scales between all shards in this case.
|
||||
s = self.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue