2022-11-04 11:03:04 -06:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
2023-02-13 05:02:45 -07:00
|
|
|
from opentelemetry import trace
|
2023-01-17 01:10:22 -07:00
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
|
2022-11-04 11:03:04 -06:00
|
|
|
from typing import Optional, Tuple, List, Type
|
|
|
|
|
|
|
|
from text_generation.models import Model
|
2023-01-31 09:04:00 -07:00
|
|
|
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
|
2022-11-04 11:03:04 -06:00
|
|
|
from text_generation.pb import generate_pb2
|
2023-01-30 07:36:16 -07:00
|
|
|
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
2022-11-04 11:03:04 -06:00
|
|
|
|
2023-02-13 05:02:45 -07:00
|
|
|
# Module-level OpenTelemetry tracer named after this module; used below to wrap
# `concatenate` and `generate_token` in named spans via `start_as_current_span`.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
2022-11-04 11:03:04 -06:00
|
|
|
|
|
|
|
@dataclass
class Seq2SeqLMBatch(Batch):
    """Batch of generation requests for an encoder-decoder (seq2seq) model.

    Holds the padded encoder/decoder tensors, the cached encoder output and
    attention key/values, and the per-request generation helpers needed to run
    one decoding step over every request in the batch.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Encoder values
    # `input_ids` is only populated by `from_pb`; `concatenate` and
    # `Seq2SeqLM.generate_token` build follow-up batches with `input_ids=None`.
    input_ids: Optional[torch.Tensor]
    attention_mask: torch.Tensor

    # Decoder values
    decoder_input_ids: torch.Tensor
    # None until the batch has been through `concatenate` at least once
    decoder_attention_mask: Optional[torch.Tensor]
    # None until the first forward pass has been run on the batch
    encoder_last_hidden_state: Optional[torch.Tensor]

    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
    past_key_values: Optional[List[Tuple]]

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    decoder_input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_input_length: int
    max_decoder_input_length: int
    # Number of columns reserved on the right of `decoder_attention_mask` for
    # tokens that have not been generated yet
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch

        Args:
            pb: protobuf batch whose requests are parsed.
            tokenizer: tokenizer used to encode the request inputs; its
                `bos_token_id` seeds every decoder sequence.
            device: device the resulting tensors are placed on.

        Returns:
            A batch on which no forward pass has been run yet
            (`encoder_last_hidden_state`, `past_key_values` and
            `decoder_attention_mask` are all None).
        """
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        decoder_input_ids = []
        decoder_input_lengths = []

        # Parse batch
        max_input_length = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            # Decoder sequence only contains the bos_token
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
            max_input_length = max(max_input_length, r.input_length)
            # Reserve room for the longest possible generation in the batch
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )

        # Tokenize batch
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
        ).to(device)
        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            # Reuse the maximum accumulated while parsing the requests
            max_input_length=max_input_length,
            max_decoder_input_length=1,
            padding_right_offset=padding_right_offset,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        """Concatenate multiple batches together by padding internal torch tensors

        Args:
            batches: batches to merge; each must already hold an
                `encoder_last_hidden_state` (i.e. have done at least one step).

        Returns:
            A single batch carrying the id of `batches[0]`, whose tensors are
            zero-padded on the left to the largest encoder/decoder lengths.

        Raises:
            ValueError: if a batch has no `encoder_last_hidden_state`.
        """

        # Used for padding
        total_batch_size = 0
        max_input_length = 0
        max_decoder_input_length = 0
        padding_right_offset = 0
        for batch in batches:
            total_batch_size += batch.size
            max_input_length = max(max_input_length, batch.max_input_length)
            max_decoder_input_length = max(
                max_decoder_input_length, batch.max_decoder_input_length
            )
            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0

        for batch in batches:
            # Extend all list attributes
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            # Create padded tensor
            if attention_mask is None:
                attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_input_length),
                )
            # Copy to correct indices
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            # Create padded tensor
            if decoder_input_ids is None:
                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                    (total_batch_size, max_decoder_input_length),
                )
            # Copy to correct indices
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            # Create padded tensor
            if decoder_attention_mask is None:
                # As decoder_attention_mask might not exist, we use
                # `batch.attention_mask` for device here
                decoder_attention_mask = batch.attention_mask.new_zeros(
                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                )
            # Columns left of `left_offset` stay zero: they are the left padding
            # of batches whose decoder sequences are shorter than the longest one
            left_offset = max_decoder_input_length - batch.max_decoder_input_length
            # If the decoder mask does not exist yet, all generations started at
            # the same time and we never concatenated this batch. All generations
            # are of length `batch.max_decoder_input_length`.
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = 1
            # If it exists, we need to index
            else:
                # Skip the left padding and the right reservation of the
                # batch's own mask
                batch_left_offset = (
                    batch.decoder_attention_mask.shape[1]
                    - batch.max_decoder_input_length
                    - batch.padding_right_offset
                )
                decoder_attention_mask[
                    start_index:end_index,
                    left_offset:-padding_right_offset,
                ] = batch.decoder_attention_mask[
                    :,
                    batch_left_offset : -batch.padding_right_offset,
                ]

            # Create padded tensor
            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                )

            # Copy to correct indices
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

            # Iterate over attention layers
            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Padded shapes are the same for every tensor of the layer
                decoder_past_shape = (
                    total_batch_size,
                    num_heads,
                    (max_decoder_input_length - 1),
                    head_dim,
                )
                encoder_past_shape = (
                    total_batch_size,
                    num_heads,
                    max_input_length,
                    head_dim,
                )

                # Decoder past (first two tensors of the layer)
                for k, t in enumerate(past[:2]):
                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(decoder_past_shape))

                    # We slice the past keys and values to remove the padding
                    # from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # Encoder past (remaining tensors of the layer)
                for idx, t in enumerate(past[2:], start=2):
                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(t.new_zeros(encoder_past_shape))

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
            padding_right_offset=padding_right_offset,
        )

    def __len__(self):
        """Return the number of requests in the batch"""
        return len(self.requests)
|
|
|
|
|
2022-11-04 11:03:04 -06:00
|
|
|
|
|
|
|
class Seq2SeqLM(Model):
    """Encoder-decoder language model served through `Seq2SeqLMBatch` batches."""

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Load the model and its tokenizer.

        Args:
            model_id: identifier passed to `from_pretrained` for both the model
                and the tokenizer.
            revision: optional revision passed to `from_pretrained`.
            quantize: forwarded as `load_in_8bit`; only allowed when CUDA is
                available.

        Raises:
            ValueError: if `quantize` is requested without CUDA.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        # `Seq2SeqLMBatch.from_pb` seeds every decoder sequence with
        # `tokenizer.bos_token_id`, so point it at the decoder start token
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super().__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class handled by this model."""
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[List[torch.Tensor]],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one forward pass of the underlying model with caching enabled.

        `encoder_last_hidden_state` is forwarded as `encoder_outputs`;
        `generate_token` passes it wrapped in a list (or None on the first
        step).

        Returns:
            The `logits`, `encoder_last_hidden_state` and `past_key_values`
            attributes of the model output.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
        """Generate one token for every request of `batch`.

        Returns:
            A `Generation` per request of the input batch, and the batch to use
            for the next step restricted to the requests that have not stopped
            (None when every request stopped).
        """
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape: drop the columns reserved for tokens
            # that have not been generated yet
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None

        # check if first forward or not
        if batch.past_key_values is not None:
            # Only take the last token
            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
        else:
            decoder_input_ids = batch.decoder_input_ids

        # Wrap `encoder_last_hidden_state` because for some reason, Transformers
        # does a `encoder_last_hidden_state[0]` internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = batch.encoder_last_hidden_state

        logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        # Metadata
        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        # Per-request names are kept distinct from the batch-level `logits` and
        # `decoder_input_ids` tensors defined above
        for i, (
            request,
            input_length,
            decoder_input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            request_decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                request_decoder_input_ids.view(1, -1), request_logits
            )

            # Append next token to decoder tokens
            request_decoder_input_ids = torch.cat(
                [request_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
                output_text = self.decode(
                    request_decoder_input_ids[-new_decoder_input_length:]
                )

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(
                    request_decoder_input_ids.unsqueeze(0)
                )
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill: only reported together with the first generated token
            if stopping_criteria.current_tokens == 1:
                prefill_token_ids = request_decoder_input_ids[
                    -new_decoder_input_length:-1
                ]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, [float("nan")], prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        # If we finished at least one generation, we need to evict the indices of
        # the generations that finished from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to decoder_attention mask, past key values and other
            # items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update decoder_attention_mask as we added a new token to input_ids:
        # the first reserved column on the right now holds a real token
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=None,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
            # One reserved column was consumed by the token generated above
            padding_right_offset=batch.padding_right_offset - 1,
        )
        return generations, next_batch
|