feat(server): Implement sharding for non-divisible `vocab_size`. (#583)

- The change itself is relatively easy: just relax the divisibility checks on the Embedding and the Head, and fall back to loading the full tensor when the vocab cannot be split evenly.

This cannot be done in the same easy fashion for hidden_dim/head_dim. It is relatively easy on some models (classic MHA), but it would make the other models (MQA) much more complex, and it would turn GPTQ quantization into another quite hairy piece of code.
Authored by Nicolas Patry on 2023-07-12 16:43:31 +02:00 (committed by GitHub)
parent 2c4bf88268
commit 67347950b7
2 changed files with 30 additions and 10 deletions
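
The core idea, sketched as a standalone snippet before the diff (this is an illustration, not the committed code, and the helper name `shard_or_replicate` is made up): shard the vocab dimension when it divides evenly across the ranks, otherwise keep the full tensor on every rank and remember to skip the gather later.

import torch

def shard_or_replicate(weight: torch.Tensor, rank: int, world_size: int, dim: int = 0):
    """Return (local_weight, should_gather) for one rank."""
    size = weight.shape[dim]
    if size % world_size == 0:
        # Even split: every rank holds size // world_size rows, and the partial
        # outputs have to be gathered back together later.
        block = size // world_size
        return weight.narrow(dim, rank * block, block), True
    # Non-divisible vocab: replicate the whole tensor and skip the gather.
    return weight, False

vocab = torch.randn(10, 4)                                       # divisible by 2, not by 3
print(shard_or_replicate(vocab, rank=1, world_size=2)[0].shape)  # torch.Size([5, 4])
print(shard_or_replicate(vocab, rank=1, world_size=3)[1])        # False -> replicated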


@@ -174,13 +174,25 @@ class SuperLayer(nn.Module):
 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather

     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False

         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)

+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
@@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

         process_group = weights.process_group

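Why the new `should_gather` flag exists, shown as a single-process sketch (assumed toy shapes, plain tensor slicing instead of the real distributed all-gather): a sharded lm_head only produces its slice of the logits, so the slices must be gathered back, whereas a replicated head already yields the full logits locally and can return early.

import torch

hidden = torch.randn(2, 8)      # (batch, hidden_dim)
lm_head = torch.randn(10, 8)    # (vocab_size, hidden_dim); 10 splits evenly across 2 ranks

# Simulated world_size = 2: each "rank" holds half of the vocab rows.
shards = lm_head.chunk(2, dim=0)
partial = [hidden @ s.T for s in shards]     # per rank: (batch, vocab_size / 2)
gathered = torch.cat(partial, dim=1)         # what the gather path reconstructs

replicated = hidden @ lm_head.T              # should_gather=False: full logits locally
print(torch.allclose(gathered, replicated))  # True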

@@ -69,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size

-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -98,6 +94,17 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
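
A toy restatement of the new split in `Weights`, using in-memory tensors instead of safetensors slices (the floor-division `block_size` here is an assumption; the real computation sits above the shown hunk). `get_sharded` keeps the divisibility assert and defers to `get_partial_sharded`, which simply slices whatever block falls to the rank; under this assumption the remainder rows land on no rank, which is why callers such as `TensorParallelHead.load` catch the `AssertionError` and load the full tensor instead.

import torch

def get_partial_sharded(tensor: torch.Tensor, rank: int, world_size: int, dim: int = 0):
    size = tensor.shape[dim]
    block_size = size // world_size              # assumption: floor division
    start, stop = rank * block_size, (rank + 1) * block_size
    return tensor[start:stop] if dim == 0 else tensor[:, start:stop]

def get_sharded(tensor: torch.Tensor, rank: int, world_size: int, dim: int = 0):
    size = tensor.shape[dim]
    assert size % world_size == 0, f"{size} cannot be split into {world_size} even shards"
    return get_partial_sharded(tensor, rank, world_size, dim)

vocab = torch.arange(10.0).unsqueeze(1)          # pretend vocab_size = 10
print(get_partial_sharded(vocab, rank=2, world_size=3).shape)  # torch.Size([3, 1]); row 9 is on no rank
# get_sharded(vocab, rank=2, world_size=3)       # raises AssertionError (10 % 3 != 0)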