From 123749a3c999e32db798667041a4a9589d217c8e Mon Sep 17 00:00:00 2001
From: Vincent Brouwers
Date: Thu, 21 Sep 2023 08:15:59 +0200
Subject: [PATCH] Fix missing arguments in Galactica's from_pb (#1022)

# What does this PR do?

Fixes #1004

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 server/text_generation_server/models/galactica.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index cfd5e658..b296c96e 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -80,6 +80,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         next_token_choosers = []
         stopping_criterias = []
         prefix_offsets = []
+        top_n_tokens = []
         read_offsets = []
         requests_idx_mapping = {}
 
@@ -96,6 +97,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -129,6 +131,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * max_input_length + max_decode_tokens
 
@@ -146,6 +151,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
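
For reviewers unfamiliar with the failure mode in #1004: the sketch below is a minimal, hypothetical illustration, not the actual TGI classes. `Batch` is a stand-in assuming `top_n_tokens` and `top_n_tokens_tensor` are required dataclass fields on the batch, which is why a `from_pb` override that omits them fails at construction time.

```python
from dataclasses import dataclass
from typing import List

import torch


# Hypothetical stand-in for the batch class, assuming it declares
# top_n_tokens and top_n_tokens_tensor as required dataclass fields.
@dataclass
class Batch:
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor


# Before this patch: the Galactica override built the batch without the
# new fields, so the generated __init__ rejects the call outright.
try:
    Batch()  # type: ignore[call-arg]
except TypeError as err:
    print(err)  # __init__() missing 2 required positional arguments: ...

# After this patch: the values are collected per request and forwarded.
top_n_tokens = [5, 0]
batch = Batch(
    top_n_tokens=top_n_tokens,
    top_n_tokens_tensor=torch.tensor(top_n_tokens, dtype=torch.int64),
)
print(batch.top_n_tokens_tensor)  # tensor([5, 0])
```

The patch takes the second path: it appends `r.top_n_tokens` per request and builds the tensor once on the target device, mirroring the base `CausalLMBatch.from_pb`.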