Factor out sharding of packed tensors
For Phi-3-Small I need to shard a packed QKV bias tensor, for which I implemented the `Weights.get_packed_sharded` method. However, this method can also replace the `Weights._get_qweight` method and the custom sharding code from `Weights.get_weights_col_packed`.
This commit is contained in:
parent
376a0b7ada
commit
c0f201c9d3
|
@ -120,29 +120,57 @@ class Weights:
|
||||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||||
return self.get_partial_sharded(tensor_name, dim)
|
return self.get_partial_sharded(tensor_name, dim)
|
||||||
|
|
||||||
def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
|
def get_packed_sharded(
|
||||||
slice_ = self._get_slice(name)
|
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
|
||||||
total_size = slice_.get_shape()[1]
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Get a shard from a tensor that packs multiple tensors.
|
||||||
|
|
||||||
|
When a tensor packs multiple tensors (such as QKV or an up
|
||||||
|
projection + gate projection), sharding with `get_sharded` is not
|
||||||
|
safe since it would not split the packed tensors across shards.
|
||||||
|
|
||||||
|
This method shards a tensor, such that the packed tensors are
|
||||||
|
split across shards.
|
||||||
|
|
||||||
|
The columns are split in equally sized blocks when blocks is an `int`, or
|
||||||
|
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
|
||||||
|
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
|
||||||
|
convenient for e.g. splitting QKV without knowing the storage details of
|
||||||
|
quantized weights.
|
||||||
|
"""
|
||||||
|
slice_ = self._get_slice(tensor_name)
|
||||||
|
total_size = slice_.get_shape()[dim]
|
||||||
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
|
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
|
||||||
|
|
||||||
world_size = self.process_group.size()
|
world_size = self.process_group.size()
|
||||||
rank = self.process_group.rank()
|
rank = self.process_group.rank()
|
||||||
|
|
||||||
weights = []
|
tensors = []
|
||||||
block_offset = 0
|
block_offset = 0
|
||||||
for block_size in block_sizes:
|
for block_size in block_sizes:
|
||||||
assert (
|
assert (
|
||||||
block_size % world_size == 0
|
block_size % world_size == 0
|
||||||
), f"Prepacked qkv cannot be sharded across {world_size} shards"
|
), f"Prepacked tensor cannot be sharded across {world_size} shards"
|
||||||
shard_block_size = block_size // world_size
|
shard_block_size = block_size // world_size
|
||||||
start = rank * shard_block_size
|
start = rank * shard_block_size
|
||||||
stop = (rank + 1) * shard_block_size
|
stop = (rank + 1) * shard_block_size
|
||||||
weights.append(slice_[:, block_offset + start : block_offset + stop])
|
if dim == 0:
|
||||||
|
tensor = slice_[block_offset + start : block_offset + stop]
|
||||||
|
elif dim == 1:
|
||||||
|
tensor = slice_[:, block_offset + start : block_offset + stop]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
|
||||||
|
tensors.append(tensor)
|
||||||
block_offset += block_size
|
block_offset += block_size
|
||||||
|
tensor = torch.cat(tensors, dim=dim)
|
||||||
|
tensor = tensor.to(device=self.device)
|
||||||
|
|
||||||
weight = torch.cat(weights, dim=1)
|
# Avoid casting quantizer dtypes.
|
||||||
weight = weight.to(device=self.device)
|
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
|
||||||
return weight
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
def get_weights_col_packed_qkv(
|
def get_weights_col_packed_qkv(
|
||||||
self,
|
self,
|
||||||
|
@ -175,7 +203,9 @@ class Weights:
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
|
qweight = self.get_packed_sharded(
|
||||||
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||||
|
@ -183,8 +213,12 @@ class Weights:
|
||||||
|
|
||||||
bits, groupsize, _, quant_method = self._get_gptq_params()
|
bits, groupsize, _, quant_method = self._get_gptq_params()
|
||||||
|
|
||||||
qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
|
qzeros = self.get_packed_sharded(
|
||||||
scales = self._get_qweight(f"{prefix}.scales", block_sizes)
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
scales = self.get_packed_sharded(
|
||||||
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
scales = scales.to(dtype=self.dtype)
|
scales = scales.to(dtype=self.dtype)
|
||||||
|
|
||||||
if quantize == "gptq" and quant_method == "gptq":
|
if quantize == "gptq" and quant_method == "gptq":
|
||||||
|
@ -217,34 +251,13 @@ class Weights:
|
||||||
elif quantize == "marlin":
|
elif quantize == "marlin":
|
||||||
from text_generation_server.layers.marlin import MarlinWeight
|
from text_generation_server.layers.marlin import MarlinWeight
|
||||||
|
|
||||||
B = self._get_qweight(f"{prefix}.B", block_sizes)
|
B = self.get_packed_sharded(f"{prefix}.B", dim=1, block_sizes=block_sizes)
|
||||||
s = self._get_qweight(f"{prefix}.s", block_sizes)
|
s = self.get_packed_sharded(f"{prefix}.s", dim=1, block_sizes=block_sizes)
|
||||||
weight = MarlinWeight(B=B, s=s)
|
weight = MarlinWeight(B=B, s=s)
|
||||||
else:
|
else:
|
||||||
slice_ = self._get_slice(f"{prefix}.weight")
|
weight = self.get_packed_sharded(
|
||||||
total_size = slice_.get_shape()[0]
|
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||||
block_sizes = _blocks_to_block_sizes(
|
|
||||||
total_size=total_size, blocks=block_sizes
|
|
||||||
)
|
)
|
||||||
|
|
||||||
world_size = self.process_group.size()
|
|
||||||
rank = self.process_group.rank()
|
|
||||||
|
|
||||||
tensors = []
|
|
||||||
block_offset = 0
|
|
||||||
for block_size in block_sizes:
|
|
||||||
assert (
|
|
||||||
block_size % world_size == 0
|
|
||||||
), f"Prepacked weights cannot be sharded across {world_size} shards"
|
|
||||||
shard_block_size = block_size // world_size
|
|
||||||
start = rank * shard_block_size
|
|
||||||
stop = (rank + 1) * shard_block_size
|
|
||||||
tensor = slice_[block_offset + start : block_offset + stop]
|
|
||||||
tensors.append(tensor)
|
|
||||||
block_offset += block_size
|
|
||||||
weight = torch.cat(tensors, dim=0)
|
|
||||||
weight = weight.to(device=self.device)
|
|
||||||
weight = weight.to(dtype=self.dtype)
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def get_weights_col(self, prefix: str, quantize: str):
|
def get_weights_col(self, prefix: str, quantize: str):
|
||||||
|
|
Loading…
Reference in New Issue