feat(server): Implements sharding for non divisible `vocab_size`. (#583)
- The code is relatively easy (just disable the checks on Embedding and Head) This cannot be done in the same easy fashion for hidden_dim/head_dim. It's relatively easy on some models (classic MHA) but it would make the other models (MQA) much more complex, and GPTQ quantization another quite hairy piece of code.
This commit is contained in:
parent
2c4bf88268
commit
67347950b7
|
@ -174,13 +174,25 @@ class SuperLayer(nn.Module):
|
|||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
def __init__(self, linear, process_group, should_gather: bool):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.should_gather = should_gather
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
# If the vocab size is not divisible by number of shards
|
||||
# just load the entire thing.
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
else:
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
|
||||
# GPTQ doesn't quantize heads (nor embeddings)
|
||||
if config.quantize == "gptq":
|
||||
|
@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
|
|||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None, quantize=quantize),
|
||||
process_group=weights.process_group,
|
||||
should_gather=should_gather,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
world_size = self.process_group.size()
|
||||
if world_size == 1:
|
||||
if not self.should_gather:
|
||||
return super().forward(input)
|
||||
|
||||
world_size = self.process_group.size()
|
||||
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
|
||||
out_dim = self.linear.weight.shape[0]
|
||||
|
||||
|
@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
|
|||
class TensorParallelEmbedding(nn.Module):
|
||||
def __init__(self, prefix: str, weights, reduce=True):
|
||||
super().__init__()
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
|
||||
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
|
||||
|
||||
process_group = weights.process_group
|
||||
|
|
|
@ -69,7 +69,7 @@ class Weights:
|
|||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
@ -81,10 +81,6 @@ class Weights:
|
|||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
assert (
|
||||
size % world_size == 0
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
|
||||
if dim == 0:
|
||||
tensor = slice_[start:stop]
|
||||
elif dim == 1:
|
||||
|
@ -98,6 +94,17 @@ class Weights:
|
|||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
size = slice_.get_shape()[dim]
|
||||
assert (
|
||||
size % world_size == 0
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
return self.get_partial_sharded(tensor_name, dim)
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue