feat(server): Support AutoModelForSeq2SeqLM

This commit is contained in:
OlivierDehaene 2022-11-04 18:03:04 +01:00
parent c5665f5c8b
commit 427d7cc444
11 changed files with 892 additions and 376 deletions

View File

@ -15,12 +15,20 @@ A Rust and gRPC server for text generation inference.
- [Safetensors](https://github.com/huggingface/safetensors) weight loading - [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB - 45ms per token generation for BLOOM with 8xA100 80GB
## Supported models ## Officially supported models
- BLOOM - BLOOM
- BLOOMZ - BLOOMZ
- BLOOM-560m - BLOOM-560m
Other models are supported on a best effort basis using:
`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
or
`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
## Load Tests for BLOOM ## Load Tests for BLOOM
See `k6/load_test.js` See `k6/load_test.js`
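The best-effort path above is plain `transformers` auto-loading; a minimal sketch, assuming a seq2seq checkpoint (`t5-small` is only a placeholder) and that `accelerate` is installed for `device_map="auto"`:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "t5-small"  # placeholder; any Hub checkpoint with a seq2seq head
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto").eval()

inputs = tokenizer("translate English to German: Hello", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```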
@ -81,7 +89,6 @@ make router-dev
## TODO: ## TODO:
- [ ] Support AutoModelForSeq2SeqLM
- [ ] Add tests for the `server/model` logic - [ ] Add tests for the `server/model` logic
- [ ] Backport custom CUDA kernels to Transformers - [ ] Backport custom CUDA kernels to Transformers
- [ ] Install safetensors with pip - [ ] Install safetensors with pip

View File

@ -54,8 +54,6 @@ message Batch {
repeated Request requests = 2; repeated Request requests = 2;
/// Batch size (==len(requests)) /// Batch size (==len(requests))
uint32 size = 3; uint32 size = 3;
/// Length of the longest sequence within the batch (used for padding)
uint32 max_sequence_length = 4;
} }
message GeneratedText { message GeneratedText {

View File

@ -142,14 +142,10 @@ impl Db {
// Batch size // Batch size
let size = requests.len(); let size = requests.len();
// Longest input length for all requests in batch size
// Used for padding inside the inference server
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
let batch = Batch { let batch = Batch {
id: state.next_batch_id, id: state.next_batch_id,
requests, requests,
size: size as u32, size: size as u32,
max_sequence_length,
}; };
// Update next_batch_start_id to the last id in the batch + 1 // Update next_batch_start_id to the last id in the batch + 1
state.next_batch_start_id = ids.last().unwrap() + 1; state.next_batch_start_id = ids.last().unwrap() + 1;
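With `max_sequence_length` gone from the proto and the router, padding bookkeeping lives entirely in the Python batch classes; a toy sketch of that computation (the `Request` class here is only a stand-in for the generated proto type):

```python
class Request:
    """Stand-in for the generated proto Request; the router fills input_length."""
    def __init__(self, input_length: int):
        self.input_length = input_length

requests = [Request(7), Request(12), Request(3)]
# from_pb now derives the padding length itself instead of reading it from the proto:
max_sequence_length = max(r.input_length for r in requests)  # 12
```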

View File

@ -1,16 +1,18 @@
from typing import Dict, Optional from typing import Dict, Optional, TypeVar
from text_generation.models.types import Batch from text_generation.models.types import Batch
B = TypeVar("B", bound=Batch)
class Cache: class Cache:
def __init__(self): def __init__(self):
self.cache: Dict[int, Batch] = {} self.cache: Dict[int, B] = {}
def pop(self, batch_id: int) -> Optional[Batch]: def pop(self, batch_id: int) -> Optional[B]:
return self.cache.pop(batch_id, None) return self.cache.pop(batch_id, None)
def set(self, entry: Batch): def set(self, entry: B):
if entry is not None: if entry is not None:
self.cache[entry.batch_id] = entry self.cache[entry.batch_id] = entry
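A hedged usage sketch of the cache between decoding steps; `cache_step` is a hypothetical helper, not part of the server:

```python
from typing import Optional

from text_generation.cache import Cache
from text_generation.models.types import Batch

cache = Cache()

def cache_step(next_batch: Optional[Batch], batch_id: int) -> Batch:
    # After each generation step the model returns (generated_texts, next_batch);
    # next_batch is None once every request in the batch has finished.
    cache.set(next_batch)        # keyed by next_batch.batch_id; set() skips None
    # When a later request references a previously cached batch id:
    batch = cache.pop(batch_id)  # None if the id was never cached or already consumed
    if batch is None:
        raise ValueError(f"Batch ID {batch_id} not found in cache")
    return batch
```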

View File

@ -1,8 +1,9 @@
from text_generation.models.model import Model from text_generation.models.model import Model
from text_generation.models.bloom import BLOOMSharded
from text_generation.models.causal_lm import CausalLM from text_generation.models.causal_lm import CausalLM
from text_generation.models.bloom import BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM
__all__ = ["Model", "BLOOMSharded", "CausalLM"] __all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@ -18,4 +19,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")
if quantize: if quantize:
raise ValueError("quantize is not supported for AutoModel") raise ValueError("quantize is not supported for AutoModel")
try:
return CausalLM(model_name) return CausalLM(model_name)
except Exception as e:
return Seq2SeqLM(model_name)
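A hedged sketch of how the new fallback behaves; the model names are placeholders, and the sharded/quantize checks earlier in the function are elided:

```python
from text_generation.models import get_model

# Decoder-only checkpoints resolve through CausalLM; when AutoModelForCausalLM
# rejects the architecture (e.g. an encoder-decoder config), the except branch
# retries with Seq2SeqLM instead.
model = get_model("gpt2", sharded=False, quantize=False)      # -> CausalLM
model = get_model("t5-small", sharded=False, quantize=False)  # CausalLM raises -> Seq2SeqLM
```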

View File

@ -12,7 +12,7 @@ from transformers.models.bloom.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation.models import Model from text_generation.models import CausalLM
from text_generation.utils import ( from text_generation.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
@ -29,7 +29,7 @@ except Exception as e:
torch.manual_seed(0) torch.manual_seed(0)
class BLOOMSharded(Model): class BLOOMSharded(CausalLM):
def __init__(self, model_name: str, quantize: bool = False): def __init__(self, model_name: str, quantize: bool = False):
if not model_name.startswith("bigscience/bloom"): if not model_name.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_name} is not supported") raise ValueError(f"Model {model_name} is not supported")
@ -78,8 +78,11 @@ class BLOOMSharded(Model):
) )
self.model = model.eval().to(dtype) self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(BLOOMSharded, self).__init__(tokenizer=tokenizer, num_heads=config.n_head // self.process_group.size(), super(CausalLM, self).__init__(
device=device) tokenizer=tokenizer,
num_heads=config.n_head // self.process_group.size(),
device=device,
)
@staticmethod @staticmethod
def load_weights( def load_weights(

View File

@ -1,9 +1,211 @@
import torch import torch
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional, Tuple, List from typing import Optional, Tuple, List, Dict, Type
from text_generation.models import Model from text_generation.models import Model
from text_generation.models.types import GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
@dataclass
class CausalLMBatch:
batch_id: int
requests: List[generate_pb2.Request]
all_input_lengths: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
size: int
max_sequence_length: int
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
)
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
all_input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
all_input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
)
)
input_ids = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=max(all_input_lengths),
)
@classmethod
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length :
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
for j, past in enumerate(batch.input_ids["past_key_values"]):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
head_dim, padded_sequence_length = past[0].shape[-2:]
num_heads = (
past[0]
.view(batch.size, -1, head_dim, padded_sequence_length)
.shape[1]
)
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
input_ids["past_key_values"].append([])
# Decoder past
for k, t in enumerate(past):
# Needed because BLOOM past shapes are not the same for keys and values
# Keys: [batch_size * num_heads, head_dim, seq_length]
# Values: [batch_size * num_heads, seq_length, head_dim]
head_dim_last = False
if t.shape[-2] == head_dim:
t = t.view(
batch.size, num_heads, head_dim, padded_sequence_length
)
padded_t_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
)
elif t.shape[-1] == head_dim:
head_dim_last = True
t = t.view(
batch.size, num_heads, padded_sequence_length, head_dim
)
padded_t_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
)
else:
raise ValueError(f"shape {t.shape} is not valid")
# Initialize tensors
# This will run only once per layer and per past tensor
if k == len(input_ids["past_key_values"][j]):
input_ids["past_key_values"][j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
# We slice the past keys and values to remove the padding from previous batches
if not head_dim_last:
input_ids["past_key_values"][j][k][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = t[:, :, :, -(batch.max_sequence_length - 1) :]
else:
input_ids["past_key_values"][j][k][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
:,
] = t[:, :, -(batch.max_sequence_length - 1) :, :]
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
)
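The mask handling in `concatenate` is easier to see on a toy example (made-up sizes; inputs are left-padded, so each batch's mask is written into the rightmost columns of the merged tensor):

```python
import torch

# Batch A: 2 requests, max_sequence_length = 3; Batch B: 1 request, max_sequence_length = 5.
mask_a = torch.tensor([[1, 1, 1],
                       [0, 1, 1]])
mask_b = torch.tensor([[1, 1, 1, 1, 1]])

merged = torch.zeros((3, 5), dtype=mask_a.dtype)
merged[0:2, -3:] = mask_a[:, -3:]  # rows 0-1 come from batch A
merged[2:3, -5:] = mask_b[:, -5:]  # row 2 comes from batch B
# merged:
# tensor([[0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [1, 1, 1, 1, 1]])
```

The past key/value tensors follow the same right-aligned slicing, with one fewer position than the mask because the newest token has no cached past yet.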
class CausalLM(Model): class CausalLM(Model):
@ -23,7 +225,15 @@ class CausalLM(Model):
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() else None,
).eval() ).eval()
super(CausalLM, self).__init__(tokenizer=tokenizer, num_heads=self.model.config.num_attention_heads, device=device) super(CausalLM, self).__init__(
tokenizer=tokenizer,
num_heads=self.model.config.num_attention_heads,
device=device,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch
def forward( def forward(
self, input_ids, attention_mask, past_key_values: Optional = None self, input_ids, attention_mask, past_key_values: Optional = None
@ -36,3 +246,129 @@ class CausalLM(Model):
use_cache=True, use_cache=True,
) )
return outputs.logits, outputs.past_key_values return outputs.logits, outputs.past_key_values
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
logits, past = self.forward(**batch.input_ids)
# List of indices to cache
next_batch_keep_indices = []
# New input_ids for next forward
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0
next_batch_max_sequence_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(request, output, stopping_criteria.current_tokens)
)
# add to the next batch
else:
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_input_ids["past_key_values"] = [
[
t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
for t in layer
]
for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
[
next_batch_input_ids["attention_mask"],
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
)
return generated_texts, next_batch
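`next_batch_keep_indices` is ordinary fancy indexing over the cached tensors; a toy illustration with made-up values:

```python
import torch

attention_mask = torch.tensor([[0, 1, 1, 1],
                               [1, 1, 1, 1],
                               [0, 0, 1, 1]])
next_batch_keep_indices = [0, 2]  # request 1 finished this step and is dropped

print(attention_mask[next_batch_keep_indices])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1]])
```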

View File

@ -1,11 +1,13 @@
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Optional from typing import List, Tuple, Optional, TypeVar, Type
from tokenizers import Tokenizer from tokenizers import Tokenizer
from text_generation.models.types import Batch, GeneratedText from text_generation.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)
class Model(ABC): class Model(ABC):
def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device): def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
@ -13,127 +15,11 @@ class Model(ABC):
self.num_heads = num_heads self.num_heads = num_heads
self.device = device self.device = device
@property
@abstractmethod @abstractmethod
def forward(self, input_ids, attention_mask, past_key_values: Optional = None) -> Tuple[torch.Tensor, List[Tuple]]: def batch_type(self) -> Type[B]:
raise NotImplementedError raise NotImplementedError
def generate_token( @abstractmethod
self, batch: Batch def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
) -> Tuple[List[GeneratedText], Optional[Batch]]: raise NotImplementedError
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
logits, past = self.forward(**batch.input_ids)
# List of indices to cache
next_batch_keep_indices = []
# New input_ids for next forward
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0
next_batch_max_sequence_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(GeneratedText(request, output, stopping_criteria.current_tokens))
# add to the next batch
else:
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_input_ids["past_key_values"] = [
[t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices] for t in layer]
for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
[
next_batch_input_ids["attention_mask"],
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = Batch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
)
return generated_texts, next_batch
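After the removal above, the abstract base is small; a simplified sketch of the resulting interface (the constructor body is abridged to the attributes the subclasses rely on):

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Type, TypeVar

import torch
from tokenizers import Tokenizer

from text_generation.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
        self.tokenizer = tokenizer
        self.num_heads = num_heads
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        raise NotImplementedError
```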

View File

@ -0,0 +1,488 @@
import torch
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
@dataclass
class Seq2SeqLMBatch:
batch_id: int
requests: List[generate_pb2.Request]
input_ids: torch.Tensor
attention_mask: torch.Tensor
decoder_input_ids: torch.Tensor
decoder_attention_mask: Optional[torch.Tensor]
encoder_last_hidden_state: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]]
input_lengths: List[int]
decoder_input_lengths: List[int]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
size: int
max_input_length: int
max_decoder_input_length: int
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
)
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Seq2SeqLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
input_lengths = []
decoder_input_ids = []
decoder_input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
)
)
tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=None,
encoder_last_hidden_state=None,
past_key_values=None,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
max_input_length=max(input_lengths),
max_decoder_input_length=1,
)
@classmethod
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_input_length = max(batch.max_input_length for batch in batches)
max_decoder_input_length = max(
batch.max_decoder_input_length for batch in batches
)
# Batch attributes
requests = []
input_lengths = []
decoder_input_lengths = []
next_token_choosers = []
stopping_criterias = []
input_ids = None
attention_mask = None
decoder_input_ids = None
decoder_attention_mask = None
encoder_last_hidden_state = None
past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.encoder_last_hidden_state is None:
raise ValueError("Batch encoder_last_hidden_state cannot be None")
if input_ids is None:
input_ids = torch.zeros(
(total_batch_size, max_input_length),
dtype=batch.input_ids.dtype,
device=batch.input_ids.device,
)
input_ids[
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
if attention_mask is None:
attention_mask = torch.zeros(
(total_batch_size, max_input_length),
dtype=batch.attention_mask.dtype,
device=batch.attention_mask.device,
)
attention_mask[
start_index:end_index, -batch.max_input_length :
] = batch.attention_mask[:, -batch.max_input_length :]
if decoder_input_ids is None:
decoder_input_ids = torch.zeros(
(total_batch_size, max_decoder_input_length),
dtype=batch.decoder_input_ids.dtype,
device=batch.decoder_input_ids.device,
)
decoder_input_ids[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
if decoder_attention_mask is None:
decoder_attention_mask = torch.zeros(
(total_batch_size, max_decoder_input_length),
dtype=batch.attention_mask.dtype,
device=batch.attention_mask.device,
)
if batch.decoder_attention_mask is None:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = 1
else:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
if encoder_last_hidden_state is None:
encoder_last_hidden_state = torch.zeros(
(
total_batch_size,
max_input_length,
batch.encoder_last_hidden_state.shape[-1],
),
dtype=batch.encoder_last_hidden_state.dtype,
device=batch.encoder_last_hidden_state.device,
)
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
for j, past in enumerate(batch.past_key_values):
_, num_heads, _, head_dim = past[0].shape
# This will run only once per layer
if j == len(past_key_values):
past_key_values.append([])
# Decoder past
for k, t in enumerate(past[:2]):
padded_t_shape = (
total_batch_size,
num_heads,
(max_decoder_input_length - 1),
head_dim,
)
# Initialize tensors
# This will run only once per layer and per past tensor
if k == len(past_key_values[j]):
past_key_values[j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
# We slice the past keys and values to remove the padding from previous batches
past_key_values[j][k][
start_index:end_index,
:,
-(batch.max_decoder_input_length - 1) :,
:,
] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]
# encoder past
for k, t in enumerate(past[2:]):
padded_t_shape = (
total_batch_size,
num_heads,
max_input_length,
head_dim,
)
idx = k + 2
# Initialize tensors
# This will run only once per layer and per past tensor
if idx == len(past_key_values[j]):
past_key_values[j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
past_key_values[j][idx][
start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :]
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
)
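The split between `past[:2]` and `past[2:]` follows the layout transformers returns for encoder-decoder models: each layer's past is a 4-tuple of self-attention and cross-attention keys/values. A toy sketch of the shapes involved:

```python
import torch

batch_size, num_heads, head_dim = 2, 4, 8
decoder_len, encoder_len = 5, 11

# Per decoder layer: (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
layer_past = (
    torch.zeros(batch_size, num_heads, decoder_len, head_dim),  # self-attn key
    torch.zeros(batch_size, num_heads, decoder_len, head_dim),  # self-attn value
    torch.zeros(batch_size, num_heads, encoder_len, head_dim),  # cross-attn key
    torch.zeros(batch_size, num_heads, encoder_len, head_dim),  # cross-attn value
)

decoder_past = layer_past[:2]  # grows by one position per generated token
encoder_past = layer_past[2:]  # stays at the (padded) encoder length
```

This is why `concatenate` pads the first two tensors to `max_decoder_input_length - 1` and the last two to `max_input_length`.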
class Seq2SeqLM(Model):
def __init__(self, model_name: str):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__(
tokenizer=tokenizer,
num_heads=self.model.config.num_attention_heads,
device=device,
)
@property
def batch_type(self) -> Type[Seq2SeqLMBatch]:
return Seq2SeqLMBatch
def forward(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask: Optional,
encoder_last_hidden_state: Optional,
past_key_values: Optional = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Model Forward
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=[encoder_last_hidden_state]
if encoder_last_hidden_state is not None
else None,
past_key_values=past_key_values,
use_cache=True,
)
return (
outputs.logits,
outputs.encoder_last_hidden_state,
outputs.past_key_values,
)
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids,
batch.attention_mask,
batch.decoder_input_ids,
batch.decoder_attention_mask,
batch.encoder_last_hidden_state,
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New input_ids for next forward
next_batch_input_lengths = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
next_batch_size = 0
next_batch_max_input_length = 0
next_batch_max_decoder_input_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.decoder_input_lengths,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.input_ids,
batch.decoder_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
decoder_input_length,
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_tokens,
) in enumerate(iterator):
all_tokens = torch.cat([input_tokens, decoder_tokens])
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to decoder tokens
decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
# Evaluate stopping criteria
if stopping_criteria(decoder_tokens):
# Decode all tokens
output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
# Add to the list of finished generations with the original request
generated_texts.append(
GeneratedText(request, output, stopping_criteria.current_tokens)
)
# add to the next batch
else:
next_batch_keep_indices.append(i)
next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0))
next_batch_size += 1
new_decoder_input_length = decoder_input_length + 1
next_batch_input_lengths.append(input_length)
next_batch_decoder_input_lengths.append(new_decoder_input_length)
next_batch_max_input_length = max(
next_batch_max_input_length, input_length
)
next_batch_max_decoder_input_length = max(
next_batch_max_decoder_input_length, new_decoder_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
if generated_texts:
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
if batch.decoder_attention_mask is not None:
next_batch_decoder_attention_mask = batch.decoder_attention_mask[
next_batch_keep_indices
]
else:
next_batch_decoder_attention_mask = None
next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
next_batch_keep_indices
]
next_batch_past_key_values = [
[t[next_batch_keep_indices] for t in layer] for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids = batch.input_ids
next_batch_attention_mask = batch.attention_mask
next_batch_decoder_attention_mask = batch.decoder_attention_mask
next_batch_encoder_last_hidden_state = encoder_last_hidden_state
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask = torch.cat(
[
next_batch_decoder_attention_mask,
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = Seq2SeqLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
attention_mask=next_batch_attention_mask,
decoder_input_ids=next_batch_decoder_input_ids,
decoder_attention_mask=next_batch_decoder_attention_mask,
encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
decoder_input_lengths=next_batch_decoder_input_lengths,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length,
)
return generated_texts, next_batch
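A hypothetical end-to-end sketch of driving either model type through the shared interface; `run_to_completion` is an illustrative helper, and the model name and `pb_batch` argument are placeholders:

```python
from text_generation.models import Seq2SeqLM
from text_generation.pb import generate_pb2


def run_to_completion(model, pb_batch: generate_pb2.Batch):
    batch = model.batch_type.from_pb(pb_batch, model.tokenizer, model.device)

    # For Seq2SeqLM the first call runs the encoder plus one decoder step; later
    # calls are decoder-only because the batch carries encoder_last_hidden_state
    # and past_key_values.
    generated_texts, batch = model.generate_token(batch)
    while batch is not None:
        done, batch = model.generate_token(batch)
        generated_texts.extend(done)
    return generated_texts


texts = run_to_completion(Seq2SeqLM("t5-small"), pb_batch=...)  # placeholders
print([t.output for t in texts])
```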

View File

@ -1,237 +1,30 @@
import torch import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Dict from typing import List
from transformers import AutoTokenizer from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
@dataclass class Batch(ABC):
class Batch: @abstractmethod
batch_id: int def to_pb(self) -> generate_pb2.Batch:
requests: List[generate_pb2.Request] raise NotImplementedError
all_input_lengths: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
size: int
max_sequence_length: int
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
max_sequence_length=self.max_sequence_length,
)
@classmethod @classmethod
@abstractmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Batch": ) -> "Batch":
inputs = [] raise NotImplementedError
next_token_choosers = []
stopping_criterias = []
all_input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
all_input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(
StoppingCriteria(
eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
)
)
input_ids = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=pb.max_sequence_length,
)
@classmethod @classmethod
@abstractmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch": def concatenate(cls, batches: List["Batch"]) -> "Batch":
# Used for padding raise NotImplementedError
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Only needed for Seq2SeqLM
max_encoded_sequence_length = None
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length:
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:]
for j, past in enumerate(batch.input_ids["past_key_values"]):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
head_dim, padded_sequence_length = past[0].shape[-2:]
num_heads = (
past[0]
.view(batch.size, -1, head_dim, padded_sequence_length)
.shape[1]
)
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
input_ids["past_key_values"].append([])
# Decoder past
for k, t in enumerate(past[:2]):
# Needed because BLOOM past shapes are not the same for keys and values
# Keys: [batch_size * num_heads, head_dim, seq_length]
# Values: [batch_size * num_heads, seq_length, head_dim]
head_dim_last = False
if t.shape[-2] == head_dim:
t = t.view(
batch.size, num_heads, head_dim, padded_sequence_length
)
padded_t_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
)
elif t.shape[-1] == head_dim:
head_dim_last = True
t = t.view(
batch.size, num_heads, padded_sequence_length, head_dim
)
padded_t_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
)
else:
raise ValueError(f"shape {t.shape} is not valid")
# Initialize tensors
# This will run only once per layer and per past tensor
if k == len(input_ids["past_key_values"][j]):
input_ids["past_key_values"][j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
# We slice the past keys and values to remove the padding from previous batches
if not head_dim_last:
input_ids["past_key_values"][j][k][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1):,
] = t[:, :, :, -(batch.max_sequence_length - 1):]
else:
input_ids["past_key_values"][j][k][
start_index:end_index,
:,
-(batch.max_sequence_length - 1):,
:,
] = t[:, :, -(batch.max_sequence_length - 1):, :]
# Seq2SeqLM specific past (encoder past)
for k, t in enumerate(past[2:]):
if max_encoded_sequence_length is None:
max_encoded_sequence_length = max(max(batch.all_input_lengths) for batch in batches)
batch_max_encoded_sequence_length = max(batch.all_input_lengths)
padded_t_shape = (total_batch_size, num_heads, max_encoded_sequence_length, head_dim)
idx = k + 2
# Initialize tensors
# This will run only once per layer and per past tensor
if idx == len(input_ids["past_key_values"][j]):
input_ids["past_key_values"][j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
input_ids["past_key_values"][j][idx][
start_index:end_index,
:,
-batch_max_encoded_sequence_length:,
:
] = t[:, :, -batch_max_encoded_sequence_length:, :]
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
)
@dataclass @dataclass
@ -241,4 +34,6 @@ class GeneratedText:
tokens: int tokens: int
def to_pb(self) -> generate_pb2.GeneratedText: def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(request=self.request, output=self.output, tokens=self.tokens) return generate_pb2.GeneratedText(
request=self.request, output=self.output, tokens=self.tokens
)

View File

@ -9,7 +9,6 @@ from typing import List
from text_generation.cache import Cache from text_generation.cache import Cache
from text_generation.models import Model, get_model from text_generation.models import Model, get_model
from text_generation.models.types import Batch
from text_generation.pb import generate_pb2_grpc, generate_pb2 from text_generation.pb import generate_pb2_grpc, generate_pb2
@ -27,7 +26,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.ClearCacheResponse() return generate_pb2.ClearCacheResponse()
async def Generate(self, request, context): async def Generate(self, request, context):
batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device) batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.device
)
generated_texts, next_batch = self.model.generate_token(batch) generated_texts, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch) self.cache.set(next_batch)
@ -51,7 +52,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batches.append(batch) batches.append(batch)
if len(batches) > 1: if len(batches) > 1:
batch = Batch.concatenate(batches) batch = self.model.batch_type.concatenate(batches)
else: else:
batch = batches[0] batch = batches[0]