diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index cfd5e658..b296c96e 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -80,6 +80,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         next_token_choosers = []
         stopping_criterias = []
         prefix_offsets = []
+        top_n_tokens = []
         read_offsets = []
         requests_idx_mapping = {}

@@ -96,6 +97,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -129,6 +131,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )

         max_tokens = len(inputs) * max_input_length + max_decode_tokens

@@ -146,6 +151,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
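
For context, the point of materializing `top_n_tokens` as a tensor alongside the Python list is to let the server gather each request's top-n logprobs in one batched operation rather than looping per request. The sketch below illustrates that idea; it is a minimal, self-contained assumption of how such a tensor could be used, not the repository's actual downstream implementation.

# --- Illustrative sketch (not part of the patch above) ---
# A per-request top_n_tokens list is batched into a tensor; one topk sized
# for the largest request plus a per-row mask then recovers each request's
# top-n token ids. Shapes and names here are illustrative assumptions.
import torch

top_n_tokens = [0, 2, 3]  # one value per request; 0 means "not requested"
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)  # [batch, vocab]

# Single batched topk, then trim each row to that request's n.
max_n = int(top_n_tokens_tensor.max())
values, indices = logprobs.topk(k=max_n, dim=-1)
mask = torch.arange(max_n) < top_n_tokens_tensor.unsqueeze(1)

for i, n in enumerate(top_n_tokens):
    kept = indices[i][mask[i]].tolist()
    print(f"request {i}: top-{n} token ids {kept}")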