Factor out sharding of packed tensors

For Phi-3-Small I need to shard a packed QKV bias tensor, for which
I implemented the `Weights.get_packed_sharded` method. Beyond that use
case, this method also replaces the `Weights._get_qweight` method and
the custom sharding code in `Weights.get_weights_col_packed`.
Daniël de Kok 2024-06-12 16:20:51 +02:00
parent 376a0b7ada
commit c0f201c9d3
1 changed file with 50 additions and 37 deletions
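
As a rough illustration of how the new method is meant to be used for the
packed QKV bias mentioned above (a sketch only: the tensor name, the head
counts, and the surrounding `Weights` instance are illustrative assumptions,
not taken from this commit):

# Sketch only: the prefix, the head counts, and the `weights` object are
# illustrative assumptions; they are not part of this commit.
num_heads, num_kv_heads = 32, 8

# `weights` is assumed to be an already constructed `Weights` instance.
# The bias packs [query | key | value] along dim 0; `block_sizes` gives the
# proportions of the packed dimension, so every rank takes its slice of each
# block instead of slicing the packed tensor as one contiguous piece.
qkv_bias = weights.get_packed_sharded(
    "model.layers.0.self_attn.query_key_value.bias",  # hypothetical tensor name
    dim=0,
    block_sizes=[num_heads, num_kv_heads, num_kv_heads],
)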


@@ -120,29 +120,57 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
-    def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
-        slice_ = self._get_slice(name)
-        total_size = slice_.get_shape()[1]
+    def get_packed_sharded(
+        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
+    ) -> torch.Tensor:
+        """
+        Get a shard from a tensor that packs multiple tensors.
+
+        When a tensor packs multiple tensors (such as QKV or an up
+        projection + gate projection), sharding with `get_sharded` is not
+        safe since it would not split the packed tensors across shards.
+
+        This method shards a tensor, such that the packed tensors are
+        split across shards.
+
+        The columns are split in equally sized blocks when blocks is an `int`, or
+        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
+        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
+        convenient for e.g. splitting QKV without knowing the storage details of
+        quantized weights.
+        """
+        slice_ = self._get_slice(tensor_name)
+        total_size = slice_.get_shape()[dim]
         block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
 
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
-        weights = []
+        tensors = []
         block_offset = 0
         for block_size in block_sizes:
             assert (
                 block_size % world_size == 0
-            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
+            ), f"Prepacked tensor cannot be sharded across {world_size} shards"
             shard_block_size = block_size // world_size
             start = rank * shard_block_size
             stop = (rank + 1) * shard_block_size
-            weights.append(slice_[:, block_offset + start : block_offset + stop])
+            if dim == 0:
+                tensor = slice_[block_offset + start : block_offset + stop]
+            elif dim == 1:
+                tensor = slice_[:, block_offset + start : block_offset + stop]
+            else:
+                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
+            tensors.append(tensor)
             block_offset += block_size
+        tensor = torch.cat(tensors, dim=dim)
+        tensor = tensor.to(device=self.device)
 
-        weight = torch.cat(weights, dim=1)
-        weight = weight.to(device=self.device)
-        return weight
+        # Avoid casting quantizer dtypes.
+        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+            tensor = tensor.to(dtype=self.dtype)
+
+        return tensor
 
     def get_weights_col_packed_qkv(
         self,
@@ -175,7 +203,9 @@ class Weights:
             from text_generation_server.layers.gptq import GPTQWeight
 
             try:
-                qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
+                qweight = self.get_packed_sharded(
+                    f"{prefix}.qweight", dim=1, block_sizes=block_sizes
+                )
             except RuntimeError:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@@ -183,8 +213,12 @@ class Weights:
 
             bits, groupsize, _, quant_method = self._get_gptq_params()
 
-            qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
-            scales = self._get_qweight(f"{prefix}.scales", block_sizes)
+            qzeros = self.get_packed_sharded(
+                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
+            )
+            scales = self.get_packed_sharded(
+                f"{prefix}.scales", dim=1, block_sizes=block_sizes
+            )
             scales = scales.to(dtype=self.dtype)
 
             if quantize == "gptq" and quant_method == "gptq":
@@ -217,34 +251,13 @@ class Weights:
         elif quantize == "marlin":
             from text_generation_server.layers.marlin import MarlinWeight
 
-            B = self._get_qweight(f"{prefix}.B", block_sizes)
-            s = self._get_qweight(f"{prefix}.s", block_sizes)
+            B = self.get_packed_sharded(f"{prefix}.B", dim=1, block_sizes=block_sizes)
+            s = self.get_packed_sharded(f"{prefix}.s", dim=1, block_sizes=block_sizes)
             weight = MarlinWeight(B=B, s=s)
         else:
-            slice_ = self._get_slice(f"{prefix}.weight")
-            total_size = slice_.get_shape()[0]
-            block_sizes = _blocks_to_block_sizes(
-                total_size=total_size, blocks=block_sizes
+            weight = self.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
             )
-
-            world_size = self.process_group.size()
-            rank = self.process_group.rank()
-
-            tensors = []
-            block_offset = 0
-            for block_size in block_sizes:
-                assert (
-                    block_size % world_size == 0
-                ), f"Prepacked weights cannot be sharded across {world_size} shards"
-                shard_block_size = block_size // world_size
-                start = rank * shard_block_size
-                stop = (rank + 1) * shard_block_size
-                tensor = slice_[block_offset + start : block_offset + stop]
-                tensors.append(tensor)
-                block_offset += block_size
-
-            weight = torch.cat(tensors, dim=0)
-            weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
         return weight
 
     def get_weights_col(self, prefix: str, quantize: str):
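
For reference, both the old and the new code delegate the block layout to the
existing `_blocks_to_block_sizes` helper, which turns the `block_sizes`
argument into absolute sizes along the packed dimension. Below is a minimal
sketch of the behaviour described in the docstring above; it is not
necessarily the exact upstream implementation.

from typing import List, Union


def blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    # Sketch of the splitting rule from the docstring; the real
    # `_blocks_to_block_sizes` helper may differ in its details.
    if isinstance(blocks, list):
        # Proportional split, e.g. [2, 1, 1] over 1024 -> [512, 256, 256].
        total_blocks = sum(blocks)
        assert (
            total_size % total_blocks == 0
        ), f"Cannot split {total_size} into proportional blocks {blocks}"
        unit = total_size // total_blocks
        return [unit * block for block in blocks]
    else:
        # Equal split, e.g. 3 blocks over 1536 -> [512, 512, 512].
        assert (
            total_size % blocks == 0
        ), f"Cannot split {total_size} into {blocks} equal blocks"
        return [total_size // blocks] * blocks


assert blocks_to_block_sizes(1024, [2, 1, 1]) == [512, 256, 256]
assert blocks_to_block_sizes(1536, 3) == [512, 512, 512]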