157 lines
5.3 KiB
Python
157 lines
5.3 KiB
Python
import re
|
|
import torch
|
|
import torch.distributed
|
|
|
|
|
|
from transformers import (
|
|
PreTrainedTokenizerBase,
|
|
)
|
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
|
from text_generation_server.pb import generate_pb2
|
|
from text_generation_server.utils import (
|
|
NextTokenChooser,
|
|
StoppingCriteria,
|
|
)
|
|
from text_generation_server.utils.chunks import concat_text_chunks
|
|
|
|
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

# We split individual characters inside special tokens like [START_DNA].
#
# Capture groups: (1) the start token, (2) the sequence type name, (3) the
# sequence body, matched lazily, and (4) the end token. The backreference \2
# forces the end token to name the same type as the start token.
# NOTE(review): no re.DOTALL flag is passed, so `.` does not match "\n" and a
# span whose body contains a newline is left unmatched (and thus unsplit).
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Token added to implement a custom sequence tokenization. This token is added at
# the corpus cleaning step and removed in pretokenization. The digits are added to
# increase the chance that they do not occur in the corpus. The digits are escaped
# so that the token does not appear literally in the source code in case we ever
# include it in the training data.
# The "escaping" is done with f-string replacement fields: each `{1}` renders as
# the digit 1 at runtime, so the full marker text never appears in this file.
# Keep the f-string form when editing this line.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
|
|
|
|
|
|
def _insert_split_marker(m: re.Match) -> str:
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].

    Every character of the matched sequence body is prefixed with
    ``SPLIT_MARKER``, and one extra marker is placed right before the end
    token, so the characters of the sequence are split individually.

    Parameters
    ----------
    m : re.Match
        A match of ``CUSTOM_SEQ_RE``. Its four groups are the start token,
        the sequence type name (unused here), the sequence body and the end
        token.

    Returns
    -------
    str
        The matched text with the split marker inserted.
    """
    start_token, _, sequence, end_token = m.groups()
    # Same result as re.sub(r"(.)", SPLIT_MARKER + r"\1", sequence, flags=re.DOTALL)
    # (every character, newlines included, gets a marker prefix) without
    # running a regex substitution per character.
    marked = "".join(SPLIT_MARKER + char for char in sequence)
    return f"{start_token}{marked}{SPLIT_MARKER}{end_token}"
|
|
|
|
|
|
def escape_custom_split_sequence(text):
    """
    Prepare ``text`` for GALILEO's custom tokenization.

    Every span delimited by a ``[START_<TYPE>]`` / ``[END_<TYPE>]`` pair
    recognised by ``CUSTOM_SEQ_RE`` (DNA, SMILES, I_SMILES, AMINO) is
    rewritten by ``_insert_split_marker``. Text outside those spans is
    returned unchanged.

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    -------
    str
        The text with the split token added.
    """
    # Pattern.sub invokes the callback once per matched span and splices the
    # returned string back in place of that span.
    escaped = CUSTOM_SEQ_RE.sub(repl=_insert_split_marker, string=text)
    return escaped
|
|
|
|
|
|
# END CREDIT
|
|
|
|
|
|
class GalacticaCausalLMBatch(CausalLMBatch):
    """CausalLMBatch whose prompts are escaped for Galactica before tokenizing.

    ``from_pb`` follows the CausalLMBatch construction logic, with
    ``escape_custom_split_sequence`` applied to each request's input text.
    """

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        """Build a batch from a protobuf ``Batch`` message.

        Parameters
        ----------
        pb : generate_pb2.Batch
            Incoming batch. Each request provides ``id``, ``input_chunks``,
            ``parameters``, ``stopping_parameters``, ``truncate`` and
            ``top_n_tokens``.
        tokenizer : PreTrainedTokenizerBase
            Tokenizer used to encode the escaped prompts.
        dtype : torch.dtype
            Not referenced in this method (presumably kept to match the base
            class signature -- TODO confirm).
        device : torch.device
            Device the tokenized tensors and ``top_n_tokens_tensor`` live on.

        Returns
        -------
        GalacticaCausalLMBatch
            A batch ready for the first forward pass (``past_key_values`` is
            ``None``).
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        top_n_tokens = []
        requests_idx_mapping = {}

        # Parse batch
        max_truncation = 0
        padding_right_offset = 0
        max_decode_tokens = 0
        for i, r in enumerate(pb.requests):
            requests_idx_mapping[r.id] = i
            # Add escape_custom_split_sequence to the CausalLMBatch logic
            inputs.append(
                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
            )
            next_token_choosers.append(
                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
            )
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)
            max_truncation = max(max_truncation, r.truncate)
            max_decode_tokens += stopping_criteria.max_new_tokens
            # Room to reserve on the right of the attention mask for decoding.
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=max_truncation,
        ).to(device)

        # shape[1] is the padded width shared by every row, so the offsets are
        # the same for every request; compute the width once.
        input_len = tokenized_inputs["input_ids"].shape[1]
        prefix_offsets = [0] * len(pb.requests)
        read_offsets = [input_len] * len(pb.requests)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        # Convert to a Python int once: it is used for tensor shapes/slicing
        # below and keeps ``max_tokens`` a plain int rather than a 0-d tensor
        # on ``device``.
        max_input_length = input_lengths.max().item()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        # Positions count real (mask == 1) tokens only; padded slots get the
        # placeholder value 1.
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        # One column tensor of token ids per request.
        all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        max_tokens = len(inputs) * max_input_length + max_decode_tokens

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            all_input_ids=list(all_input_ids),
            input_lengths=input_lengths.tolist(),
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            max_tokens=max_tokens,
        )
|