import torch
from text_generation_server.layers import (
    TensorParallelEmbedding,
)


class ProcessGroup:
    """Minimal stand-in for torch.distributed.ProcessGroup."""

    def __init__(self, rank: int, world_size: int):
        self._rank = rank
        self.world_size = world_size

    def size(self) -> int:
        return self.world_size

    def rank(self) -> int:
        return self._rank


class Weights:
    """Mock weights source whose embedding table is a deterministic arange."""

    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
        self.weight = (
            torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
        )
        self.process_group = ProcessGroup(rank, world_size)

    def get_partial_sharded(self, name: str, dim: int):
        assert dim == 0

        rank = self.process_group.rank()
        world_size = self.process_group.size()
        size = self.weight.shape[dim]

        # Each rank owns a contiguous block of ceil(size / world_size) rows;
        # slicing past the end is safe and simply yields a shorter last shard.
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        return self.weight[start:stop]

    def get_shape(self, name: str):
        return self.weight.shape


def test_tensor_parallel_embedding():
    vocab_size = 17
    weights = Weights(
        rank=0,
        world_size=1,
        vocab_size=vocab_size,
        hidden_dim=256,
    )
    embeddings = TensorParallelEmbedding("", weights)

    input_ids = torch.arange(vocab_size)
    output = embeddings.forward(input_ids)
    # With a single rank the layer holds the full table.
    assert embeddings.min_id == 0
    assert embeddings.max_id == 17
    torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))

    # Two-way sharding: vocab_size=17 splits into blocks of ceil(17 / 2) = 9.
    weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
    weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)

    # Rank 0 owns ids [0, 9); its shard carries one extra zero row used for
    # ids that fall outside its range.
    embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
    assert embeddings_0_2.min_id == 0
    assert embeddings_0_2.max_id == 9
    torch.testing.assert_close(
        embeddings_0_2.weight,
        torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
        .view(10, 256)
        .float(),
    )

    # Rank 1 owns ids [9, 17); again padded with a trailing zero row.
    embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
    assert embeddings_1_2.min_id == 9
    assert embeddings_1_2.max_id == 17
    torch.testing.assert_close(
        embeddings_1_2.weight,
        torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
        .view(9, 256)
        .float(),
    )

    # With reduce=False the partial outputs must sum to the unsharded result.
    output_tp_0 = embeddings_0_2.forward(input_ids)
    output_tp_1 = embeddings_1_2.forward(input_ids)

    torch.testing.assert_close(output, output_tp_0 + output_tp_1)
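

# A minimal sketch of the partition arithmetic exercised above, assuming only
# the ceil-division block scheme implemented in Weights.get_partial_sharded.
# The helper and test below are hypothetical additions for illustration, not
# part of text_generation_server.
def _expected_shard_bounds(vocab_size: int, world_size: int, rank: int):
    # Contiguous block of ceil(vocab_size / world_size) ids per rank, with the
    # last shard clipped to the vocabulary size.
    block_size = (vocab_size + world_size - 1) // world_size
    start = rank * block_size
    stop = min((rank + 1) * block_size, vocab_size)
    return start, stop


def test_expected_shard_bounds():
    # Matches the min_id/max_id values asserted for the two-way shard above.
    assert _expected_shard_bounds(17, 2, 0) == (0, 9)
    assert _expected_shard_bounds(17, 2, 1) == (9, 17)
    # Every id is covered exactly once, for any world size.
    for world_size in range(1, 6):
        covered = [
            i
            for rank in range(world_size)
            for i in range(*_expected_shard_bounds(17, world_size, rank))
        ]
        assert covered == list(range(17))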