From 7e542d4d05513575f9eb950f961c5b4c574c9c29 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 24 Jan 2024 13:08:41 +0100 Subject: [PATCH] Fixing non divisible embeddings. (#1476) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- .../test_idefics/test_idefics.json | 55 +++-- .../test_idefics/test_idefics_load.json | 214 +++++++++--------- server/tests/utils/test_layers.py | 64 ++++++ server/text_generation_server/utils/layers.py | 4 +- .../text_generation_server/utils/weights.py | 2 +- 5 files changed, 199 insertions(+), 140 deletions(-) create mode 100644 server/tests/utils/test_layers.py diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics.json index 2c5d05f6..90fb6dcc 100644 --- a/integration-tests/models/__snapshots__/test_idefics/test_idefics.json +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics.json @@ -11,92 +11,92 @@ }, { "id": 4911, - "logprob": -5.7851562, + "logprob": -6.9765625, "text": "User" }, { "id": 29901, - "logprob": -0.006996155, + "logprob": -0.0059432983, "text": ":" }, { "id": 32000, - "logprob": -0.81347656, + "logprob": -0.8408203, "text": "" }, { "id": 32001, - "logprob": -6.687641e-05, + "logprob": -9.906292e-05, "text": "" }, { "id": 32000, - "logprob": -3.5762787e-07, + "logprob": -2.3841858e-07, "text": "" }, { "id": 1815, - "logprob": -4.2148438, + "logprob": -4.1679688, "text": "Can" }, { "id": 366, - "logprob": -0.014137268, + "logprob": -0.014099121, "text": "you" }, { "id": 2649, - "logprob": -4.4335938, + "logprob": -4.4609375, "text": "tell" }, { "id": 592, - "logprob": -0.2919922, + "logprob": -0.29882812, "text": "me" }, { "id": 263, - "logprob": -4.2070312, + "logprob": -4.1445312, "text": "a" }, { "id": 1407, - "logprob": -9.421875, + "logprob": -9.3828125, "text": "very" }, { "id": 3273, - "logprob": -1.8720703, + "logprob": -1.9736328, "text": "short" }, { "id": 5828, - "logprob": -0.26489258, + "logprob": -0.2800293, "text": "story" }, { "id": 2729, - "logprob": -3.7441406, + "logprob": -3.5625, "text": "based" }, { "id": 373, - "logprob": -0.0005393028, + "logprob": -0.0006427765, "text": "on" }, { "id": 278, - "logprob": -0.140625, + "logprob": -0.13952637, "text": "the" }, { "id": 1967, - "logprob": -0.06756592, + "logprob": -0.068115234, "text": "image" }, { "id": 29973, - "logprob": -0.15454102, + "logprob": -0.16357422, "text": "?" } ], @@ -104,25 +104,25 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019140244, + "logprob": -0.0026474, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.404255e-05, + "logprob": -8.547306e-05, "special": false, "text": " " }, { "id": 13, - "logprob": -1.7642975e-05, + "logprob": -1.7881393e-05, "special": false, "text": "\n" }, { "id": 7900, - "logprob": -2.9802322e-06, + "logprob": -3.0994415e-06, "special": false, "text": "Ass" }, @@ -140,30 +140,29 @@ }, { "id": 319, - "logprob": -0.91064453, + "logprob": -0.92529297, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2412109, + "logprob": -1.1269531, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.0002439022, + "logprob": -0.00029492378, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1630859, + "logprob": -1.1855469, "special": false, "text": " stands" } - ], - "top_tokens": null + ] }, "generated_text": " \nAssistant: A rooster stands" } diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json index f258e38d..21d6161b 100644 --- a/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json +++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json @@ -12,92 +12,92 @@ }, { "id": 4911, - "logprob": -5.7851562, + "logprob": -6.9804688, "text": "User" }, { "id": 29901, - "logprob": -0.006996155, + "logprob": -0.006122589, "text": ":" }, { "id": 32000, - "logprob": -0.81347656, + "logprob": -0.8417969, "text": "" }, { "id": 32001, - "logprob": -6.687641e-05, + "logprob": -9.918213e-05, "text": "" }, { "id": 32000, - "logprob": -3.5762787e-07, + "logprob": -2.3841858e-07, "text": "" }, { "id": 1815, - "logprob": -4.2148438, + "logprob": -4.1679688, "text": "Can" }, { "id": 366, - "logprob": -0.014137268, + "logprob": -0.014091492, "text": "you" }, { "id": 2649, - "logprob": -4.4335938, + "logprob": -4.4726562, "text": "tell" }, { "id": 592, - "logprob": -0.2919922, + "logprob": -0.2998047, "text": "me" }, { "id": 263, - "logprob": -4.2070312, + "logprob": -4.15625, "text": "a" }, { "id": 1407, - "logprob": -9.421875, + "logprob": -9.3828125, "text": "very" }, { "id": 3273, - "logprob": -1.8720703, + "logprob": -1.9716797, "text": "short" }, { "id": 5828, - "logprob": -0.26489258, + "logprob": -0.27734375, "text": "story" }, { "id": 2729, - "logprob": -3.7441406, + "logprob": -3.5605469, "text": "based" }, { "id": 373, - "logprob": -0.0005393028, + "logprob": -0.00064468384, "text": "on" }, { "id": 278, - "logprob": -0.140625, + "logprob": -0.14160156, "text": "the" }, { "id": 1967, - "logprob": -0.06756592, + "logprob": -0.06915283, "text": "image" }, { "id": 29973, - "logprob": -0.15454102, + "logprob": -0.16381836, "text": "?" } ], @@ -105,19 +105,19 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019140244, + "logprob": -0.0026664734, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.392334e-05, + "logprob": -8.583069e-05, "special": false, "text": " " }, { "id": 13, - "logprob": -1.7881393e-05, + "logprob": -1.8119812e-05, "special": false, "text": "\n" }, @@ -135,36 +135,35 @@ }, { "id": 29901, - "logprob": -3.0994415e-06, + "logprob": -3.2186508e-06, "special": false, "text": ":" }, { "id": 319, - "logprob": -0.9057617, + "logprob": -0.9301758, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2294922, + "logprob": -1.1279297, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.00024533272, + "logprob": -0.0002939701, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1640625, + "logprob": -1.1865234, "special": false, "text": " stands" } - ], - "top_tokens": null + ] }, "generated_text": " \nAssistant: A rooster stands" }, @@ -181,92 +180,92 @@ }, { "id": 4911, - "logprob": -5.7773438, + "logprob": -6.9804688, "text": "User" }, { "id": 29901, - "logprob": -0.0070114136, + "logprob": -0.006122589, "text": ":" }, { "id": 32000, - "logprob": -0.8208008, + "logprob": -0.8417969, "text": "" }, { "id": 32001, - "logprob": -6.699562e-05, + "logprob": -9.942055e-05, "text": "" }, { "id": 32000, - "logprob": -3.5762787e-07, + "logprob": -2.3841858e-07, "text": "" }, { "id": 1815, - "logprob": -4.2265625, + "logprob": -4.1679688, "text": "Can" }, { "id": 366, - "logprob": -0.014175415, + "logprob": -0.014091492, "text": "you" }, { "id": 2649, - "logprob": -4.4296875, + "logprob": -4.4726562, "text": "tell" }, { "id": 592, - "logprob": -0.29516602, + "logprob": -0.2998047, "text": "me" }, { "id": 263, - "logprob": -4.2109375, + "logprob": -4.15625, "text": "a" }, { "id": 1407, - "logprob": -9.4296875, + "logprob": -9.3828125, "text": "very" }, { "id": 3273, - "logprob": -1.8720703, + "logprob": -1.9716797, "text": "short" }, { "id": 5828, - "logprob": -0.26879883, + "logprob": -0.27734375, "text": "story" }, { "id": 2729, - "logprob": -3.7675781, + "logprob": -3.5605469, "text": "based" }, { "id": 373, - "logprob": -0.0005354881, + "logprob": -0.0006451607, "text": "on" }, { "id": 278, - "logprob": -0.13671875, + "logprob": -0.14160156, "text": "the" }, { "id": 1967, - "logprob": -0.06719971, + "logprob": -0.06915283, "text": "image" }, { "id": 29973, - "logprob": -0.15551758, + "logprob": -0.16381836, "text": "?" } ], @@ -274,19 +273,19 @@ "tokens": [ { "id": 32002, - "logprob": -0.0019130707, + "logprob": -0.0026664734, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.392334e-05, + "logprob": -8.571148e-05, "special": false, "text": " " }, { "id": 13, - "logprob": -1.7881393e-05, + "logprob": -1.8119812e-05, "special": false, "text": "\n" }, @@ -310,30 +309,29 @@ }, { "id": 319, - "logprob": -0.9013672, + "logprob": -0.9301758, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2324219, + "logprob": -1.1279297, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.0002477169, + "logprob": -0.0002939701, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1660156, + "logprob": -1.1865234, "special": false, "text": " stands" } - ], - "top_tokens": null + ] }, "generated_text": " \nAssistant: A rooster stands" }, @@ -350,92 +348,92 @@ }, { "id": 4911, - "logprob": -5.7773438, + "logprob": -6.9804688, "text": "User" }, { "id": 29901, - "logprob": -0.0070114136, + "logprob": -0.006122589, "text": ":" }, { "id": 32000, - "logprob": -0.8208008, + "logprob": -0.8417969, "text": "" }, { "id": 32001, - "logprob": -6.699562e-05, + "logprob": -9.918213e-05, "text": "" }, { "id": 32000, - "logprob": -3.5762787e-07, + "logprob": -2.3841858e-07, "text": "" }, { "id": 1815, - "logprob": -4.2265625, + "logprob": -4.1679688, "text": "Can" }, { "id": 366, - "logprob": -0.014175415, + "logprob": -0.014091492, "text": "you" }, { "id": 2649, - "logprob": -4.4296875, + "logprob": -4.4726562, "text": "tell" }, { "id": 592, - "logprob": -0.29516602, + "logprob": -0.2998047, "text": "me" }, { "id": 263, - "logprob": -4.2109375, + "logprob": -4.15625, "text": "a" }, { "id": 1407, - "logprob": -9.4296875, + "logprob": -9.3828125, "text": "very" }, { "id": 3273, - "logprob": -1.8720703, + "logprob": -1.9716797, "text": "short" }, { "id": 5828, - "logprob": -0.26879883, + "logprob": -0.27734375, "text": "story" }, { "id": 2729, - "logprob": -3.7675781, + "logprob": -3.5605469, "text": "based" }, { "id": 373, - "logprob": -0.0005354881, + "logprob": -0.00064468384, "text": "on" }, { "id": 278, - "logprob": -0.13671875, + "logprob": -0.14160156, "text": "the" }, { "id": 1967, - "logprob": -0.06719971, + "logprob": -0.06915283, "text": "image" }, { "id": 29973, - "logprob": -0.15551758, + "logprob": -0.16381836, "text": "?" } ], @@ -443,19 +441,19 @@ "tokens": [ { "id": 32002, - "logprob": -0.001912117, + "logprob": -0.0026664734, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.392334e-05, + "logprob": -8.59499e-05, "special": false, "text": " " }, { "id": 13, - "logprob": -1.7762184e-05, + "logprob": -1.8119812e-05, "special": false, "text": "\n" }, @@ -479,30 +477,29 @@ }, { "id": 319, - "logprob": -0.9013672, + "logprob": -0.9301758, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2324219, + "logprob": -1.1279297, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.0002477169, + "logprob": -0.0002939701, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1660156, + "logprob": -1.1865234, "special": false, "text": " stands" } - ], - "top_tokens": null + ] }, "generated_text": " \nAssistant: A rooster stands" }, @@ -519,92 +516,92 @@ }, { "id": 4911, - "logprob": -5.7773438, + "logprob": -6.9804688, "text": "User" }, { "id": 29901, - "logprob": -0.0070114136, + "logprob": -0.006122589, "text": ":" }, { "id": 32000, - "logprob": -0.8208008, + "logprob": -0.8417969, "text": "" }, { "id": 32001, - "logprob": -6.699562e-05, + "logprob": -9.942055e-05, "text": "" }, { "id": 32000, - "logprob": -3.5762787e-07, + "logprob": -2.3841858e-07, "text": "" }, { "id": 1815, - "logprob": -4.2265625, + "logprob": -4.1679688, "text": "Can" }, { "id": 366, - "logprob": -0.014175415, + "logprob": -0.014091492, "text": "you" }, { "id": 2649, - "logprob": -4.4296875, + "logprob": -4.4726562, "text": "tell" }, { "id": 592, - "logprob": -0.29516602, + "logprob": -0.2998047, "text": "me" }, { "id": 263, - "logprob": -4.2109375, + "logprob": -4.15625, "text": "a" }, { "id": 1407, - "logprob": -9.4296875, + "logprob": -9.3828125, "text": "very" }, { "id": 3273, - "logprob": -1.8720703, + "logprob": -1.9716797, "text": "short" }, { "id": 5828, - "logprob": -0.26879883, + "logprob": -0.27734375, "text": "story" }, { "id": 2729, - "logprob": -3.7675781, + "logprob": -3.5605469, "text": "based" }, { "id": 373, - "logprob": -0.0005354881, + "logprob": -0.0006451607, "text": "on" }, { "id": 278, - "logprob": -0.13671875, + "logprob": -0.14160156, "text": "the" }, { "id": 1967, - "logprob": -0.06719971, + "logprob": -0.06915283, "text": "image" }, { "id": 29973, - "logprob": -0.15551758, + "logprob": -0.16381836, "text": "?" } ], @@ -612,19 +609,19 @@ "tokens": [ { "id": 32002, - "logprob": -0.001912117, + "logprob": -0.0026664734, "special": true, "text": "" }, { "id": 29871, - "logprob": -8.392334e-05, + "logprob": -8.571148e-05, "special": false, "text": " " }, { "id": 13, - "logprob": -1.7762184e-05, + "logprob": -1.8119812e-05, "special": false, "text": "\n" }, @@ -648,30 +645,29 @@ }, { "id": 319, - "logprob": -0.9013672, + "logprob": -0.9301758, "special": false, "text": " A" }, { "id": 696, - "logprob": -1.2324219, + "logprob": -1.1279297, "special": false, "text": " ro" }, { "id": 15664, - "logprob": -0.0002477169, + "logprob": -0.0002939701, "special": false, "text": "oster" }, { "id": 15028, - "logprob": -1.1660156, + "logprob": -1.1865234, "special": false, "text": " stands" } - ], - "top_tokens": null + ] }, "generated_text": " \nAssistant: A rooster stands" } diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py new file mode 100644 index 00000000..0a9fecd1 --- /dev/null +++ b/server/tests/utils/test_layers.py @@ -0,0 +1,64 @@ +import torch +from text_generation_server.utils.layers import ( + TensorParallelEmbedding, +) + +class ProcessGroup: + def __init__(self, rank: int, world_size: int): + self._rank = rank + self.world_size = world_size + + def size(self)->int: + return self.world_size + + def rank(self)->int: + return self._rank + +class Weights: + def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int): + self.weight = torch.arange(vocab_size*hidden_dim).float().view(vocab_size, hidden_dim) + self.process_group = ProcessGroup(rank, world_size) + + + def get_partial_sharded(self, name:str, dim: int): + assert dim == 0 + + rank = self.process_group.rank() + world_size = self.process_group.size() + size = self.weight.shape[dim] + + block_size = (size + world_size - 1) // world_size + start = rank * block_size + stop = (rank + 1) * block_size + return self.weight[start:stop] + + def get_shape(self, name: str): + return self.weight.shape + +def test_weight_hub_files_offline_error(): + + vocab_size= 17 + weights = Weights(rank=0, world_size=1, vocab_size = vocab_size,hidden_dim = 256) + embeddings = TensorParallelEmbedding("", weights) + + input_ids = torch.arange(vocab_size) + output = embeddings.forward(input_ids) + assert embeddings.min_id == 0 + assert embeddings.max_id == 17 + torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256)) + + weights_0_2 = Weights(rank=0, world_size=2, vocab_size = vocab_size,hidden_dim = 256) + weights_1_2 = Weights(rank=1, world_size=2, vocab_size = vocab_size,hidden_dim = 256) + embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False) + assert embeddings_0_2.min_id == 0 + assert embeddings_0_2.max_id == 9 + torch.testing.assert_close(embeddings_0_2.weight , torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0).view(10, 256).float()) + embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False) + assert embeddings_1_2.min_id == 9 + assert embeddings_1_2.max_id == 17 + torch.testing.assert_close(embeddings_1_2.weight , torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0).view(9, 256).float()) + output_tp_0 = embeddings_0_2.forward(input_ids) + output_tp_1 = embeddings_1_2.forward(input_ids) + + torch.testing.assert_close(output, output_tp_0 + output_tp_1) + diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index d4fa2559..5a0de0d7 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -507,10 +507,10 @@ class TensorParallelEmbedding(nn.Module): world_size = process_group.size() rank = process_group.rank() - block_size = num_embeddings // world_size + block_size = (num_embeddings + world_size - 1) // world_size self.min_id = rank * block_size self.max_id = min(num_embeddings, (rank + 1) * block_size) - self.null_idx = block_size + self.null_idx = weight.shape[0] # Usually block_size, might be less in non even vocab_size. self.process_group = weights.process_group self.reduce = reduce diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index c4e82a6d..186733f3 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -92,7 +92,7 @@ class Weights: rank = self.process_group.rank() size = slice_.get_shape()[dim] - block_size = size // world_size + block_size = (size + world_size - 1) // world_size start = rank * block_size stop = (rank + 1) * block_size