import torch
from text_generation_server.layers import (
    TensorParallelEmbedding,
)


class ProcessGroup:
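    """Minimal stand-in for a torch.distributed process group, exposing only the
    rank() and size() accessors used by the sharding code below."""
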
    def __init__(self, rank: int, world_size: int):
        self._rank = rank
        self.world_size = world_size

    def size(self) -> int:
        return self.world_size

    def rank(self) -> int:
        return self._rank


class Weights:
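    """Mock of the server-side weight loader: a deterministic arange weight makes
    the contents of each shard easy to predict in the assertions below."""
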
    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
        self.weight = (
            torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
        )
        self.process_group = ProcessGroup(rank, world_size)

    def get_partial_sharded(self, name: str, dim: int):
        # Only sharding along dimension 0 (the vocabulary axis) is supported here.
        assert dim == 0

        rank = self.process_group.rank()
        world_size = self.process_group.size()
        size = self.weight.shape[dim]

        # Ceil division: every rank takes block_size rows, so the last rank's
        # slice is shorter when the size does not divide evenly.
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        return self.weight[start:stop]

    def get_shape(self, name: str):
        return self.weight.shape


def test_weight_hub_files_offline_error():
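    """TensorParallelEmbedding should shard the vocabulary across ranks and, once
    the partial outputs are summed, reproduce the unsharded lookup."""
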
    vocab_size = 17
    weights = Weights(
        rank=0,
        world_size=1,
        vocab_size=vocab_size,
        hidden_dim=256,
    )
    embeddings = TensorParallelEmbedding("", weights)

    input_ids = torch.arange(vocab_size)
    output = embeddings.forward(input_ids)
    assert embeddings.min_id == 0
    assert embeddings.max_id == 17
    torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))

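    # With world_size=2, block_size = ceil(17 / 2) = 9: rank 0 holds rows 0-8, rank 1
    # holds rows 9-16, and each expected shard weight below carries one trailing zero
    # row, presumably the target for ids outside the shard's [min_id, max_id) range.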
    weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
    weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
    embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
    assert embeddings_0_2.min_id == 0
    assert embeddings_0_2.max_id == 9
    torch.testing.assert_close(
        embeddings_0_2.weight,
        torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
        .view(10, 256)
        .float(),
    )
    embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
    assert embeddings_1_2.min_id == 9
    assert embeddings_1_2.max_id == 17
    torch.testing.assert_close(
        embeddings_1_2.weight,
        torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
        .view(9, 256)
        .float(),
    )

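    # With reduce=False each rank returns only its partial lookup (no all-reduce), so
    # summing the two partial outputs should match the single-rank result above.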
    output_tp_0 = embeddings_0_2.forward(input_ids)
    output_tp_1 = embeddings_1_2.forward(input_ids)

    torch.testing.assert_close(output, output_tp_0 + output_tp_1)