2022-11-04 07:22:47 -06:00
|
|
|
import torch
|
|
|
|
|
2022-11-04 11:03:04 -06:00
|
|
|
from dataclasses import dataclass
|
2022-11-04 07:22:47 -06:00
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
2022-11-07 04:53:56 -07:00
|
|
|
from typing import Optional, Tuple, List, Type
|
2022-11-04 07:22:47 -06:00
|
|
|
|
|
|
|
from text_generation.models import Model
|
2022-11-04 11:03:04 -06:00
|
|
|
from text_generation.models.types import GeneratedText
|
|
|
|
from text_generation.pb import generate_pb2
|
|
|
|
from text_generation.utils import NextTokenChooser, StoppingCriteria
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CausalLMBatch:
    """Requests of one batch plus the tensors and helpers used to decode them.

    Instances are built from a protobuf batch with ``from_pb``, merged with
    ``concatenate`` and converted back to protobuf with ``to_pb``.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    past_key_values: Optional[List[Tuple]]

    # All tokens
    all_input_ids: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    # Metadata used for padding
    size: int
    max_sequence_length: int

    def to_pb(self):
        """Return a ``generate_pb2.Batch`` carrying this batch's id, requests and size."""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=self.size
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "CausalLMBatch":
        """Create a fresh batch (no past key/values yet) from a protobuf batch.

        Every request contributes its input text, its declared input length,
        a ``NextTokenChooser`` built from its sampling parameters and a
        ``StoppingCriteria`` built from the tokenizer EOS id and the request's
        ``max_new_tokens``. The texts are tokenized together, padded, and moved
        to ``device``.
        """
        prompts = []
        lengths = []
        choosers = []
        criterias = []

        # Parse batch
        for request in pb.requests:
            prompts.append(request.inputs)
            lengths.append(request.input_length)

            params = request.parameters
            chooser = NextTokenChooser(
                temperature=params.temperature,
                top_k=params.top_k,
                top_p=params.top_p,
                do_sample=params.do_sample,
            )
            choosers.append(chooser)

            criteria = StoppingCriteria(
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=request.max_new_tokens,
            )
            criterias.append(criteria)

        encoded = tokenizer(
            prompts, return_tensors="pt", padding=True, pad_to_multiple_of=8
        ).to(device)
        token_ids = encoded["input_ids"]

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=token_ids,
            attention_mask=encoded["attention_mask"],
            past_key_values=None,
            # Same ids with a trailing unit dimension added
            all_input_ids=token_ids.unsqueeze(-1),
            input_lengths=lengths,
            next_token_choosers=choosers,
            stopping_criterias=criterias,
            size=pb.size,
            max_sequence_length=max(lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        """Merge several batches into one, right-aligning masks and pasts.

        Each source batch must have ``input_ids`` with a second dimension of
        size 1 (i.e. it already went through at least one decoding step);
        otherwise ``ValueError`` is raised. The merged batch takes the id of
        the first batch, the sum of all sizes and the largest
        ``max_sequence_length``.
        """
        # Used for padding
        total_batch_size = sum(b.size for b in batches)
        max_sequence_length = max(b.max_sequence_length for b in batches)

        # Batch attributes
        merged_requests = []
        merged_input_lengths = []
        merged_all_input_ids = []
        merged_choosers = []
        merged_criterias = []

        # Batch tensors, allocated lazily from the first batch's dtype/device
        input_ids = None
        attention_mask = None
        past_key_values = []

        # Row offset of the current batch inside the merged tensors
        # (a running sum of the batch sizes)
        offset = 0
        for batch in batches:
            merged_requests.extend(batch.requests)
            merged_input_lengths.extend(batch.input_lengths)
            merged_all_input_ids.extend(batch.all_input_ids)
            merged_choosers.extend(batch.next_token_choosers)
            merged_criterias.extend(batch.stopping_criterias)

            # Rows owned by this batch in the merged tensors
            stop = offset + batch.size
            rows = slice(offset, stop)

            # We only concatenate batches that did at least one step
            if batch.input_ids.shape[1] > 1:
                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")

            # input_ids is always of shape [batch_size, 1], so no padding is needed
            if input_ids is None:
                input_ids = torch.empty(
                    (total_batch_size, 1),
                    dtype=batch.input_ids.dtype,
                    device=batch.input_ids.device,
                )
            input_ids[rows] = batch.input_ids

            # Zero-initialised mask: positions that are never written stay masked
            if attention_mask is None:
                attention_mask = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=batch.attention_mask.dtype,
                    device=batch.attention_mask.device,
                )

            # Keep only the trailing columns of the mask to drop padding
            # accumulated in previous steps, and right-align them
            mask_width = batch.max_sequence_length
            attention_mask[rows, -mask_width:] = batch.attention_mask[:, -mask_width:]

            # The past is one position shorter than the mask
            past_width = batch.max_sequence_length - 1

            for layer, (keys, values) in enumerate(batch.past_key_values):
                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                keys = keys.view(batch.size, -1, *keys.shape[-2:])
                values = values.view(batch.size, -1, *values.shape[-2:])

                _, num_heads, head_dim, _ = keys.shape

                keys_shape = (
                    total_batch_size,
                    num_heads,
                    head_dim,
                    max_sequence_length - 1,
                )

                # head_dim is last for BLOOM
                if values.shape[-1] == head_dim:
                    head_dim_last = True
                    values_shape = (
                        total_batch_size,
                        num_heads,
                        max_sequence_length - 1,
                        head_dim,
                    )
                elif values.shape[-2] == head_dim:
                    head_dim_last = False
                    values_shape = keys_shape
                else:
                    raise ValueError(f"past_values shape {values.shape} is not valid")

                # Allocate the padded tensors the first time a layer is seen
                if layer == len(past_key_values):
                    new_keys = torch.zeros(
                        keys_shape, dtype=keys.dtype, device=keys.device
                    )
                    new_values = torch.zeros(
                        values_shape, dtype=values.dtype, device=values.device
                    )
                    past_key_values.append((new_keys, new_values))

                padded_keys, padded_values = past_key_values[layer]

                # Slice the past keys and values to remove the padding from previous batches
                padded_keys[rows, :, :, -past_width:] = keys[:, :, :, -past_width:]
                if head_dim_last:
                    padded_values[rows, :, -past_width:, :] = values[
                        :, :, -past_width:, :
                    ]
                else:
                    padded_values[rows, :, :, -past_width:] = values[
                        :, :, :, -past_width:
                    ]

            offset = stop

        return cls(
            batch_id=batches[0].batch_id,
            requests=merged_requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            all_input_ids=merged_all_input_ids,
            input_lengths=merged_input_lengths,
            next_token_choosers=merged_choosers,
            stopping_criterias=merged_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
        )
|
2022-11-04 07:22:47 -06:00
|
|
|
|
|
|
|
|
|
|
|
class CausalLM(Model):
    """``Model`` implementation backed by a ``transformers`` causal language model."""

    def __init__(self, model_name: str, quantize=False):
        """Load the tokenizer and model for ``model_name``.

        Uses CUDA when available (bfloat16 if supported, float32 otherwise)
        and falls back to CPU with float32. ``quantize`` is forwarded to
        ``from_pretrained`` as ``load_in_8bit``.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        # Left padding keeps the most recent token of every sequence in the last column
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        # NOTE(review): this registers "[PAD]" unconditionally and the model
        # embeddings are not resized here -- confirm the resulting pad id is
        # valid for every model this class is expected to load.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
            load_in_8bit=quantize,
        ).eval()

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            num_heads=self.model.config.num_attention_heads,
            device=device,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model."""
        return CausalLMBatch

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional[List[Tuple]] = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the wrapped model with caching enabled.

        Returns the output logits together with the updated past key/values.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
        """Run one decoding step for every request in ``batch``.

        Returns the generations that finished during this step and the batch
        to use for the next step, or ``None`` when every request finished.
        """
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, past = self.forward(
                batch.input_ids, batch.attention_mask, batch.past_key_values
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_lengths = []
        next_batch_input_ids = []
        next_batch_all_input_ids = []

        # Metadata
        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            request_logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token from the logits of the last position
            next_token = next_token_chooser(
                all_tokens, request_logits.unsqueeze(0)[:, -1]
            )

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(request, output, stopping_criteria.current_tokens)
                )
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                new_input_length = input_length + 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
        # If we finished at least one generation, we need to evict the indices of the
        # generations that finished from the values of the next batch
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_past_key_values = [
                [
                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_attention_mask = batch.attention_mask
            next_batch_past_key_values = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask as we added a new token to input_ids.
        # new_ones() inherits the mask's own dtype and device; torch.ones() would
        # create a default-float tensor and make torch.cat promote the whole mask
        # to floating point after the first step.
        next_batch_attention_mask = torch.cat(
            [
                next_batch_attention_mask,
                next_batch_attention_mask.new_ones((next_batch_size, 1)),
            ],
            dim=1,
        )

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            past_key_values=next_batch_past_key_values,
            all_input_ids=next_batch_all_input_ids,
            input_lengths=next_batch_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
        )
        return generated_texts, next_batch
|