From 427d7cc44459b5d06421300c8d74a07e1db1773b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 4 Nov 2022 18:03:04 +0100 Subject: [PATCH] feat(server): Support AutoModelForSeq2SeqLM --- README.md | 11 +- proto/generate.proto | 2 - router/src/db.rs | 4 - server/text_generation/cache.py | 10 +- server/text_generation/models/__init__.py | 10 +- server/text_generation/models/bloom.py | 31 +- server/text_generation/models/causal_lm.py | 342 +++++++++++++- server/text_generation/models/model.py | 130 +----- server/text_generation/models/seq2seq_lm.py | 488 ++++++++++++++++++++ server/text_generation/models/types.py | 233 +--------- server/text_generation/server.py | 7 +- 11 files changed, 892 insertions(+), 376 deletions(-) create mode 100644 server/text_generation/models/seq2seq_lm.py diff --git a/README.md b/README.md index 29e5835..eadf5e5 100644 --- a/README.md +++ b/README.md @@ -15,12 +15,20 @@ A Rust and gRPC server for text generation inference. - [Safetensors](https://github.com/huggingface/safetensors) weight loading - 45ms per token generation for BLOOM with 8xA100 80GB -## Supported models +## Officialy supported models - BLOOM - BLOOMZ - BLOOM-560m +Other models are supported on a best effort basis using: + +`AutoModelForCausalLM.from_pretrained(, device_map="auto")` + +or + +`AutoModelForSeq2SeqLM.from_pretrained(, device_map="auto")` + ## Load Tests for BLOOM See `k6/load_test.js` @@ -81,7 +89,6 @@ make router-dev ## TODO: -- [ ] Support AutoModelForSeq2SeqLM - [ ] Add tests for the `server/model` logic - [ ] Backport custom CUDA kernels to Transformers - [ ] Install safetensors with pip \ No newline at end of file diff --git a/proto/generate.proto b/proto/generate.proto index 68dfa15..14f6f66 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -54,8 +54,6 @@ message Batch { repeated Request requests = 2; /// Batch size (==len(requests)) uint32 size = 3; - /// Length of the longest sequence within the batch (used for padding) - uint32 max_sequence_length = 4; } message GeneratedText { diff --git a/router/src/db.rs b/router/src/db.rs index af36614..0701206 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -142,14 +142,10 @@ impl Db { // Batch size let size = requests.len(); - // Longest input length for all requests in batch size - // Used for padding inside the inference server - let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap(); let batch = Batch { id: state.next_batch_id, requests, size: size as u32, - max_sequence_length, }; // Update next_batch_start_id to the last id in the batch + 1 state.next_batch_start_id = ids.last().unwrap() + 1; diff --git a/server/text_generation/cache.py b/server/text_generation/cache.py index 65ec3e7..5a3a8d3 100644 --- a/server/text_generation/cache.py +++ b/server/text_generation/cache.py @@ -1,16 +1,18 @@ -from typing import Dict, Optional +from typing import Dict, Optional, TypeVar from text_generation.models.types import Batch +B = TypeVar("B", bound=Batch) + class Cache: def __init__(self): - self.cache: Dict[int, Batch] = {} + self.cache: Dict[int, B] = {} - def pop(self, batch_id: int) -> Optional[Batch]: + def pop(self, batch_id: int) -> Optional[B]: return self.cache.pop(batch_id, None) - def set(self, entry: Batch): + def set(self, entry: B): if entry is not None: self.cache[entry.batch_id] = entry diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index 0371ab4..ade22d4 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -1,8 +1,9 @@ from text_generation.models.model import Model -from text_generation.models.bloom import BLOOMSharded from text_generation.models.causal_lm import CausalLM +from text_generation.models.bloom import BLOOMSharded +from text_generation.models.seq2seq_lm import Seq2SeqLM -__all__ = ["Model", "BLOOMSharded", "CausalLM"] +__all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"] def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: @@ -18,4 +19,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: raise ValueError("sharded is not supported for AutoModel") if quantize: raise ValueError("quantize is not supported for AutoModel") - return CausalLM(model_name) + try: + return CausalLM(model_name) + except Exception as e: + return Seq2SeqLM(model_name) diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 37f55f2..730958c 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -12,7 +12,7 @@ from transformers.models.bloom.parallel_layers import ( TensorParallelRowLinear, ) -from text_generation.models import Model +from text_generation.models import CausalLM from text_generation.utils import ( initialize_torch_distributed, weight_files, @@ -29,7 +29,7 @@ except Exception as e: torch.manual_seed(0) -class BLOOMSharded(Model): +class BLOOMSharded(CausalLM): def __init__(self, model_name: str, quantize: bool = False): if not model_name.startswith("bigscience/bloom"): raise ValueError(f"Model {model_name} is not supported") @@ -78,22 +78,25 @@ class BLOOMSharded(Model): ) self.model = model.eval().to(dtype) torch.distributed.barrier(group=self.process_group) - super(BLOOMSharded, self).__init__(tokenizer=tokenizer, num_heads=config.n_head // self.process_group.size(), - device=device) + super(CausalLM, self).__init__( + tokenizer=tokenizer, + num_heads=config.n_head // self.process_group.size(), + device=device, + ) @staticmethod def load_weights( - model, - filenames: List[str], - quantize: bool, - device: torch.device, - rank: int, - world_size: int, + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, ): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if not quantize else "cpu" ) as f: for name in f.keys(): full_name = f"transformer.{name}" @@ -156,9 +159,9 @@ class BLOOMSharded(Model): ) if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" ): tensor = Int8Params( tensor.transpose(1, 0), diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 39bb450..b07537b 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -1,9 +1,211 @@ import torch +from dataclasses import dataclass from transformers import AutoTokenizer, AutoModelForCausalLM -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Dict, Type from text_generation.models import Model +from text_generation.models.types import GeneratedText +from text_generation.pb import generate_pb2 +from text_generation.utils import NextTokenChooser, StoppingCriteria + + +@dataclass +class CausalLMBatch: + batch_id: int + requests: List[generate_pb2.Request] + all_input_lengths: List[int] + input_ids: Dict[str, torch.Tensor] + all_input_ids: List[torch.Tensor] + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + size: int + max_sequence_length: int + + def to_pb(self): + return generate_pb2.Batch( + id=self.batch_id, + requests=self.requests, + size=self.size, + ) + + @classmethod + def from_pb( + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + ) -> "CausalLMBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + all_input_lengths = [] + + # Parse batch + for r in pb.requests: + inputs.append(r.inputs) + all_input_lengths.append(r.input_length) + next_token_choosers.append( + NextTokenChooser( + temperature=r.parameters.temperature, + top_k=r.parameters.top_k, + top_p=r.parameters.top_p, + do_sample=r.parameters.do_sample, + ) + ) + stopping_criterias.append( + StoppingCriteria( + eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens + ) + ) + + input_ids = tokenizer( + inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 + ).to(device) + all_input_ids = input_ids["input_ids"].unsqueeze(-1) + + return cls( + batch_id=pb.id, + requests=pb.requests, + all_input_lengths=all_input_lengths, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=pb.size, + max_sequence_length=max(all_input_lengths), + ) + + @classmethod + def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": + # Used for padding + total_batch_size = sum(batch.size for batch in batches) + max_sequence_length = max(batch.max_sequence_length for batch in batches) + + # Batch attributes + input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} + requests = [] + all_input_lengths = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + all_input_lengths.extend(batch.all_input_lengths) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Slicing end index for this batch + end_index = start_index + batch.size + + # We only concatenate batches that did at least one step + if batch.input_ids["input_ids"].shape[1] > 1: + raise ValueError("Batch input_ids should be of shape (batch_size, 1)") + + # Initialize tensors + if i == 0: + input_ids["input_ids"] = torch.empty( + (total_batch_size, 1), + dtype=batch.input_ids["input_ids"].dtype, + device=batch.input_ids["input_ids"].device, + ) + input_ids["attention_mask"] = torch.zeros( + (total_batch_size, max_sequence_length), + dtype=batch.input_ids["attention_mask"].dtype, + device=batch.input_ids["attention_mask"].device, + ) + + # input_ids["input_ids"] is always of shape [batch_size, 1] + # We do not need to pad it + input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] + + # We need to slice the attention mask to remove padding from previous steps + input_ids["attention_mask"][ + start_index:end_index, -batch.max_sequence_length : + ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] + + for j, past in enumerate(batch.input_ids["past_key_values"]): + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...] + head_dim, padded_sequence_length = past[0].shape[-2:] + num_heads = ( + past[0] + .view(batch.size, -1, head_dim, padded_sequence_length) + .shape[1] + ) + + # This will run only once per layer + if j == len(input_ids["past_key_values"]): + input_ids["past_key_values"].append([]) + + # Decoder past + for k, t in enumerate(past): + # Needed because BLOOM past shapes are not the same for keys and values + # Keys: [batch_size * num_heads, head_dim, seq_length] + # Values: [batch_size * num_heads, seq_length, head_dim] + head_dim_last = False + if t.shape[-2] == head_dim: + t = t.view( + batch.size, num_heads, head_dim, padded_sequence_length + ) + padded_t_shape = ( + total_batch_size, + num_heads, + head_dim, + max_sequence_length - 1, + ) + elif t.shape[-1] == head_dim: + head_dim_last = True + t = t.view( + batch.size, num_heads, padded_sequence_length, head_dim + ) + padded_t_shape = ( + total_batch_size, + num_heads, + max_sequence_length - 1, + head_dim, + ) + else: + raise ValueError(f"shape {t.shape} is not valid") + + # Initialize tensors + # This will run only once per layer and per past tensor + if k == len(input_ids["past_key_values"][j]): + input_ids["past_key_values"][j].append( + torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) + ) + + # We slice the past keys and values to remove the padding from previous batches + if not head_dim_last: + input_ids["past_key_values"][j][k][ + start_index:end_index, + :, + :, + -(batch.max_sequence_length - 1) :, + ] = t[:, :, :, -(batch.max_sequence_length - 1) :] + else: + input_ids["past_key_values"][j][k][ + start_index:end_index, + :, + -(batch.max_sequence_length - 1) :, + :, + ] = t[:, :, -(batch.max_sequence_length - 1) :, :] + + start_index += batch.size + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + all_input_lengths=all_input_lengths, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=total_batch_size, + max_sequence_length=max_sequence_length, + ) class CausalLM(Model): @@ -23,10 +225,18 @@ class CausalLM(Model): device_map="auto" if torch.cuda.is_available() else None, ).eval() - super(CausalLM, self).__init__(tokenizer=tokenizer, num_heads=self.model.config.num_attention_heads, device=device) + super(CausalLM, self).__init__( + tokenizer=tokenizer, + num_heads=self.model.config.num_attention_heads, + device=device, + ) + + @property + def batch_type(self) -> Type[CausalLMBatch]: + return CausalLMBatch def forward( - self, input_ids, attention_mask, past_key_values: Optional = None + self, input_ids, attention_mask, past_key_values: Optional = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward outputs = self.model.forward( @@ -36,3 +246,129 @@ class CausalLM(Model): use_cache=True, ) return outputs.logits, outputs.past_key_values + + def generate_token( + self, batch: CausalLMBatch + ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]: + # For some reason, inference_mode does not work well with GLOO which we use on CPU + context_manager = ( + torch.no_grad if self.device.type == "cpu" else torch.inference_mode + ) + with context_manager(): + logits, past = self.forward(**batch.input_ids) + + # List of indices to cache + next_batch_keep_indices = [] + + # New input_ids for next forward + next_batch_input_ids = [] + next_batch_all_input_ids = [] + next_all_input_lengths = [] + + next_batch_size = 0 + next_batch_max_sequence_length = 0 + + # Finished requests + generated_texts: List[GeneratedText] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.all_input_lengths, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + logits, + next_token_chooser, + stopping_criteria, + all_tokens, + ) in enumerate(iterator): + # Select next token + next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + + # Append next token to all tokens + all_tokens = torch.cat([all_tokens, next_token]) + + # Evaluate stopping criteria + if stopping_criteria(all_tokens): + # Decode all tokens + output = self.tokenizer.decode( + all_tokens.squeeze(-1), skip_special_tokens=True + ) + # Add to the list of finished generations with the original request + generated_texts.append( + GeneratedText(request, output, stopping_criteria.current_tokens) + ) + # add to the next batch + else: + next_batch_keep_indices.append(i) + next_batch_input_ids.append(next_token) + next_batch_all_input_ids.append(all_tokens) + next_batch_size += 1 + new_input_length = input_length + 1 + next_all_input_lengths.append(new_input_length) + next_batch_max_sequence_length = max( + next_batch_max_sequence_length, new_input_length + ) + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generated_texts, None + + # If we finished at least one generation + next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} + if generated_texts: + # Apply indices to attention mask, past key values and other items that need to be cached + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ + next_batch_keep_indices + ] + # Force past to be of dim [batch_size, num_heads, ...] for easy indexing + next_batch_input_ids["past_key_values"] = [ + [ + t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices] + for t in layer + ] + for layer in past + ] + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] + next_batch_input_ids["past_key_values"] = past + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Update attention_mask with padding as we added a new token to input_ids + next_batch_input_ids["attention_mask"] = torch.cat( + [ + next_batch_input_ids["attention_mask"], + torch.ones((next_batch_size, 1)).to(self.device), + ], + dim=1, + ) + + next_batch = CausalLMBatch( + batch_id=batch.batch_id, + requests=next_batch_requests, + all_input_lengths=next_all_input_lengths, + input_ids=next_batch_input_ids, + all_input_ids=next_batch_all_input_ids, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + size=next_batch_size, + max_sequence_length=next_batch_max_sequence_length, + ) + return generated_texts, next_batch diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py index d6b3c73..7fb8142 100644 --- a/server/text_generation/models/model.py +++ b/server/text_generation/models/model.py @@ -1,11 +1,13 @@ import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, TypeVar, Type from tokenizers import Tokenizer from text_generation.models.types import Batch, GeneratedText +B = TypeVar("B", bound=Batch) + class Model(ABC): def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device): @@ -13,127 +15,11 @@ class Model(ABC): self.num_heads = num_heads self.device = device + @property @abstractmethod - def forward(self, input_ids, attention_mask, past_key_values: Optional = None) -> Tuple[torch.Tensor, List[Tuple]]: + def batch_type(self) -> Type[B]: raise NotImplementedError - def generate_token( - self, batch: Batch - ) -> Tuple[List[GeneratedText], Optional[Batch]]: - # For some reason, inference_mode does not work well with GLOO which we use on CPU - context_manager = ( - torch.no_grad if self.device.type == "cpu" else torch.inference_mode - ) - with context_manager(): - logits, past = self.forward(**batch.input_ids) - - # List of indices to cache - next_batch_keep_indices = [] - - # New input_ids for next forward - next_batch_input_ids = [] - next_batch_all_input_ids = [] - next_all_input_lengths = [] - - next_batch_size = 0 - next_batch_max_sequence_length = 0 - - # Finished requests - generated_texts: List[GeneratedText] = [] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.all_input_lengths, - logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - logits, - next_token_chooser, - stopping_criteria, - all_tokens, - ) in enumerate(iterator): - # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) - - # Append next token to all tokens - all_tokens = torch.cat([all_tokens, next_token]) - - # Evaluate stopping criteria - if stopping_criteria(all_tokens): - # Decode all tokens - output = self.tokenizer.decode( - all_tokens.squeeze(-1), skip_special_tokens=True - ) - # Add to the list of finished generations with the original request - generated_texts.append(GeneratedText(request, output, stopping_criteria.current_tokens)) - # add to the next batch - else: - next_batch_keep_indices.append(i) - next_batch_input_ids.append(next_token) - next_batch_all_input_ids.append(all_tokens) - next_batch_size += 1 - new_input_length = input_length + 1 - next_all_input_lengths.append(new_input_length) - next_batch_max_sequence_length = max( - next_batch_max_sequence_length, new_input_length - ) - - # We finished all generations in the batch; there is no next batch - if not next_batch_keep_indices: - return generated_texts, None - - # If we finished at least one generation - next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} - if generated_texts: - # Apply indices to attention mask, past key values and other items that need to be cached - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ - next_batch_keep_indices - ] - # Force past to be of dim [batch_size, num_heads, ...] for easy indexing - next_batch_input_ids["past_key_values"] = [ - [t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices] for t in layer] - for layer in past - ] - next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] - next_batch_next_token_choosers = [ - batch.next_token_choosers[i] for i in next_batch_keep_indices - ] - next_batch_stopping_criterias = [ - batch.stopping_criterias[i] for i in next_batch_keep_indices - ] - else: - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] - next_batch_input_ids["past_key_values"] = past - next_batch_requests = batch.requests - next_batch_next_token_choosers = batch.next_token_choosers - next_batch_stopping_criterias = batch.stopping_criterias - - # Update attention_mask with padding as we added a new token to input_ids - next_batch_input_ids["attention_mask"] = torch.cat( - [ - next_batch_input_ids["attention_mask"], - torch.ones((next_batch_size, 1)).to(self.device), - ], - dim=1, - ) - - next_batch = Batch( - batch_id=batch.batch_id, - requests=next_batch_requests, - all_input_lengths=next_all_input_lengths, - input_ids=next_batch_input_ids, - all_input_ids=next_batch_all_input_ids, - next_token_choosers=next_batch_next_token_choosers, - stopping_criterias=next_batch_stopping_criterias, - size=next_batch_size, - max_sequence_length=next_batch_max_sequence_length, - ) - return generated_texts, next_batch + @abstractmethod + def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: + raise NotImplementedError diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py new file mode 100644 index 0000000..0607b3d --- /dev/null +++ b/server/text_generation/models/seq2seq_lm.py @@ -0,0 +1,488 @@ +import torch + +from dataclasses import dataclass +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM +from typing import Optional, Tuple, List, Type + +from text_generation.models import Model +from text_generation.models.types import GeneratedText +from text_generation.pb import generate_pb2 +from text_generation.utils import NextTokenChooser, StoppingCriteria + + +@dataclass +class Seq2SeqLMBatch: + batch_id: int + requests: List[generate_pb2.Request] + + input_ids: torch.Tensor + attention_mask: torch.Tensor + + decoder_input_ids: torch.Tensor + decoder_attention_mask: Optional[torch.Tensor] + encoder_last_hidden_state: Optional[torch.Tensor] + + past_key_values: Optional[List[Tuple]] + + input_lengths: List[int] + decoder_input_lengths: List[int] + + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + size: int + max_input_length: int + max_decoder_input_length: int + + def to_pb(self): + return generate_pb2.Batch( + id=self.batch_id, + requests=self.requests, + size=self.size, + ) + + @classmethod + def from_pb( + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + ) -> "Seq2SeqLMBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + input_lengths = [] + + decoder_input_ids = [] + decoder_input_lengths = [] + + # Parse batch + for r in pb.requests: + inputs.append(r.inputs) + input_lengths.append(r.input_length) + decoder_input_ids.append(tokenizer.bos_token_id) + decoder_input_lengths.append(1) + next_token_choosers.append( + NextTokenChooser( + temperature=r.parameters.temperature, + top_k=r.parameters.top_k, + top_p=r.parameters.top_p, + do_sample=r.parameters.do_sample, + ) + ) + stopping_criterias.append( + StoppingCriteria( + eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens + ) + ) + + tokenized_inputs = tokenizer( + inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 + ).to(device) + decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1) + + return cls( + batch_id=pb.id, + requests=pb.requests, + input_ids=tokenized_inputs["input_ids"], + attention_mask=tokenized_inputs["attention_mask"], + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=None, + encoder_last_hidden_state=None, + past_key_values=None, + input_lengths=input_lengths, + decoder_input_lengths=decoder_input_lengths, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=len(pb.requests), + max_input_length=max(input_lengths), + max_decoder_input_length=1, + ) + + @classmethod + def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": + # Used for padding + total_batch_size = sum(batch.size for batch in batches) + max_input_length = max(batch.max_input_length for batch in batches) + max_decoder_input_length = max( + batch.max_decoder_input_length for batch in batches + ) + + # Batch attributes + requests = [] + input_lengths = [] + decoder_input_lengths = [] + next_token_choosers = [] + stopping_criterias = [] + + input_ids = None + attention_mask = None + decoder_input_ids = None + decoder_attention_mask = None + encoder_last_hidden_state = None + past_key_values = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + decoder_input_lengths.extend(batch.decoder_input_lengths) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Slicing end index for this batch + end_index = start_index + batch.size + + # We only concatenate batches that did at least one step + if batch.encoder_last_hidden_state is None: + raise ValueError("Batch encoder_last_hidden_state cannot be None") + + if input_ids is None: + input_ids = torch.zeros( + (total_batch_size, max_input_length), + dtype=batch.input_ids.dtype, + device=batch.input_ids.device, + ) + input_ids[ + start_index:end_index, -batch.max_input_length : + ] = batch.input_ids[:, -batch.max_input_length :] + + if attention_mask is None: + attention_mask = torch.zeros( + (total_batch_size, max_input_length), + dtype=batch.attention_mask.dtype, + device=batch.attention_mask.device, + ) + attention_mask[ + start_index:end_index, -batch.max_input_length : + ] = batch.attention_mask[:, -batch.max_input_length :] + + if decoder_input_ids is None: + decoder_input_ids = torch.zeros( + (total_batch_size, max_decoder_input_length), + dtype=batch.decoder_input_ids.dtype, + device=batch.decoder_input_ids.device, + ) + decoder_input_ids[ + start_index:end_index, -batch.max_decoder_input_length : + ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :] + + if decoder_attention_mask is None: + decoder_attention_mask = torch.zeros( + (total_batch_size, max_decoder_input_length), + dtype=batch.attention_mask.dtype, + device=batch.attention_mask.device, + ) + if batch.decoder_attention_mask is None: + decoder_attention_mask[ + start_index:end_index, -batch.max_decoder_input_length : + ] = 1 + else: + decoder_attention_mask[ + start_index:end_index, -batch.max_decoder_input_length : + ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :] + + if encoder_last_hidden_state is None: + encoder_last_hidden_state = torch.zeros( + ( + total_batch_size, + max_input_length, + batch.encoder_last_hidden_state.shape[-1], + ), + dtype=batch.encoder_last_hidden_state.dtype, + device=batch.encoder_last_hidden_state.device, + ) + + encoder_last_hidden_state[ + start_index:end_index, -batch.max_decoder_input_length :, : + ] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :] + + for j, past in enumerate(batch.past_key_values): + _, num_heads, _, head_dim = past[0].shape + + # This will run only once per layer + if j == len(past_key_values): + past_key_values.append([]) + + # Decoder past + for k, t in enumerate(past[:2]): + padded_t_shape = ( + total_batch_size, + num_heads, + (max_decoder_input_length - 1), + head_dim, + ) + + # Initialize tensors + # This will run only once per layer and per past tensor + if k == len(past_key_values[j]): + past_key_values[j].append( + torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) + ) + + # We slice the past keys and values to remove the padding from previous batches + past_key_values[j][k][ + start_index:end_index, + :, + -(batch.max_decoder_input_length - 1) :, + :, + ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :] + + # encoder past + for k, t in enumerate(past[2:]): + padded_t_shape = ( + total_batch_size, + num_heads, + max_input_length, + head_dim, + ) + + idx = k + 2 + + # Initialize tensors + # This will run only once per layer and per past tensor + if idx == len(past_key_values[j]): + past_key_values[j].append( + torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) + ) + + past_key_values[j][idx][ + start_index:end_index, :, -batch.max_input_length :, : + ] = t[:, :, -batch.max_input_length :, :] + + start_index += batch.size + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_last_hidden_state=encoder_last_hidden_state, + past_key_values=past_key_values, + input_lengths=input_lengths, + decoder_input_lengths=decoder_input_lengths, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=total_batch_size, + max_input_length=max_input_length, + max_decoder_input_length=max_decoder_input_length, + ) + + +class Seq2SeqLM(Model): + def __init__(self, model_name: str): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + else: + device = torch.device("cpu") + dtype = torch.float32 + + self.model = AutoModelForSeq2SeqLM.from_pretrained( + model_name, + torch_dtype=dtype, + device_map="auto" if torch.cuda.is_available() else None, + ).eval() + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer.bos_token_id = self.model.config.decoder_start_token_id + + super(Seq2SeqLM, self).__init__( + tokenizer=tokenizer, + num_heads=self.model.config.num_attention_heads, + device=device, + ) + + @property + def batch_type(self) -> Type[Seq2SeqLMBatch]: + return Seq2SeqLMBatch + + def forward( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask: Optional, + encoder_last_hidden_state: Optional, + past_key_values: Optional = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], + ]: + # Model Forward + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1) + + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=[encoder_last_hidden_state] + if encoder_last_hidden_state is not None + else None, + past_key_values=past_key_values, + use_cache=True, + ) + return ( + outputs.logits, + outputs.encoder_last_hidden_state, + outputs.past_key_values, + ) + + def generate_token( + self, batch: Seq2SeqLMBatch + ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]: + # For some reason, inference_mode does not work well with GLOO which we use on CPU + context_manager = ( + torch.no_grad if self.device.type == "cpu" else torch.inference_mode + ) + with context_manager(): + logits, encoder_last_hidden_state, past = self.forward( + batch.input_ids, + batch.attention_mask, + batch.decoder_input_ids, + batch.decoder_attention_mask, + batch.encoder_last_hidden_state, + batch.past_key_values, + ) + + # List of indices to cache + next_batch_keep_indices = [] + + # New input_ids for next forward + next_batch_input_lengths = [] + next_batch_decoder_input_ids = [] + next_batch_decoder_input_lengths = [] + + next_batch_size = 0 + next_batch_max_input_length = 0 + next_batch_max_decoder_input_length = 0 + + # Finished requests + generated_texts: List[GeneratedText] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.decoder_input_lengths, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.input_ids, + batch.decoder_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + decoder_input_length, + logits, + next_token_chooser, + stopping_criteria, + input_tokens, + decoder_tokens, + ) in enumerate(iterator): + all_tokens = torch.cat([input_tokens, decoder_tokens]) + # Select next token + next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + + # Append next token to decoder tokens + decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)]) + + # Evaluate stopping criteria + if stopping_criteria(decoder_tokens): + # Decode all tokens + output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True) + # Add to the list of finished generations with the original request + generated_texts.append( + GeneratedText(request, output, stopping_criteria.current_tokens) + ) + # add to the next batch + else: + next_batch_keep_indices.append(i) + next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0)) + next_batch_size += 1 + new_decoder_input_length = decoder_input_length + 1 + next_batch_input_lengths.append(input_length) + next_batch_decoder_input_lengths.append(new_decoder_input_length) + next_batch_max_input_length = max( + next_batch_max_input_length, input_length + ) + next_batch_max_decoder_input_length = max( + next_batch_max_decoder_input_length, new_decoder_input_length + ) + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generated_texts, None + + # If we finished at least one generation + next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids) + if generated_texts: + next_batch_input_ids = batch.input_ids[next_batch_keep_indices] + next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] + + if batch.decoder_attention_mask is not None: + next_batch_decoder_attention_mask = batch.decoder_attention_mask[ + next_batch_keep_indices + ] + else: + next_batch_decoder_attention_mask = None + + next_batch_encoder_last_hidden_state = encoder_last_hidden_state[ + next_batch_keep_indices + ] + + next_batch_past_key_values = [ + [t[next_batch_keep_indices] for t in layer] for layer in past + ] + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_input_ids = batch.input_ids + next_batch_attention_mask = batch.attention_mask + next_batch_decoder_attention_mask = batch.decoder_attention_mask + next_batch_encoder_last_hidden_state = encoder_last_hidden_state + next_batch_past_key_values = past + + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Update attention_mask with padding as we added a new token to input_ids + if next_batch_decoder_attention_mask is not None: + next_batch_decoder_attention_mask = torch.cat( + [ + next_batch_decoder_attention_mask, + torch.ones((next_batch_size, 1)).to(self.device), + ], + dim=1, + ) + + next_batch = Seq2SeqLMBatch( + batch_id=batch.batch_id, + requests=next_batch_requests, + input_ids=next_batch_input_ids, + attention_mask=next_batch_attention_mask, + decoder_input_ids=next_batch_decoder_input_ids, + decoder_attention_mask=next_batch_decoder_attention_mask, + encoder_last_hidden_state=next_batch_encoder_last_hidden_state, + past_key_values=next_batch_past_key_values, + input_lengths=next_batch_input_lengths, + decoder_input_lengths=next_batch_decoder_input_lengths, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + size=next_batch_size, + max_input_length=next_batch_max_input_length, + max_decoder_input_length=next_batch_max_decoder_input_length, + ) + return generated_texts, next_batch diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 0c17e55..7c25bf6 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -1,237 +1,30 @@ import torch +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Dict +from typing import List from transformers import AutoTokenizer from text_generation.pb import generate_pb2 -from text_generation.utils import NextTokenChooser, StoppingCriteria -@dataclass -class Batch: - batch_id: int - requests: List[generate_pb2.Request] - all_input_lengths: List[int] - input_ids: Dict[str, torch.Tensor] - all_input_ids: List[torch.Tensor] - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - size: int - max_sequence_length: int - - def to_pb(self): - return generate_pb2.Batch( - id=self.batch_id, - requests=self.requests, - size=self.size, - max_sequence_length=self.max_sequence_length, - ) +class Batch(ABC): + @abstractmethod + def to_pb(self) -> generate_pb2.Batch: + raise NotImplementedError @classmethod + @abstractmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device ) -> "Batch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - all_input_lengths = [] - - # Parse batch - for r in pb.requests: - inputs.append(r.inputs) - all_input_lengths.append(r.input_length) - next_token_choosers.append( - NextTokenChooser( - temperature=r.parameters.temperature, - top_k=r.parameters.top_k, - top_p=r.parameters.top_p, - do_sample=r.parameters.do_sample, - ) - ) - stopping_criterias.append( - StoppingCriteria( - eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens - ) - ) - - input_ids = tokenizer( - inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 - ).to(device) - all_input_ids = input_ids["input_ids"].unsqueeze(-1) - - return cls( - batch_id=pb.id, - requests=pb.requests, - all_input_lengths=all_input_lengths, - input_ids=input_ids, - all_input_ids=all_input_ids, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - size=pb.size, - max_sequence_length=pb.max_sequence_length, - ) + raise NotImplementedError @classmethod + @abstractmethod def concatenate(cls, batches: List["Batch"]) -> "Batch": - # Used for padding - total_batch_size = sum(batch.size for batch in batches) - max_sequence_length = max(batch.max_sequence_length for batch in batches) - # Only needed for Seq2SeqLM - max_encoded_sequence_length = None - - # Batch attributes - input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} - requests = [] - all_input_lengths = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - all_input_lengths.extend(batch.all_input_lengths) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - # Slicing end index for this batch - end_index = start_index + batch.size - - # We only concatenate batches that did at least one step - if batch.input_ids["input_ids"].shape[1] > 1: - raise ValueError("Batch input_ids should be of shape (batch_size, 1)") - - # Initialize tensors - if i == 0: - input_ids["input_ids"] = torch.empty( - (total_batch_size, 1), - dtype=batch.input_ids["input_ids"].dtype, - device=batch.input_ids["input_ids"].device, - ) - input_ids["attention_mask"] = torch.zeros( - (total_batch_size, max_sequence_length), - dtype=batch.input_ids["attention_mask"].dtype, - device=batch.input_ids["attention_mask"].device, - ) - - # input_ids["input_ids"] is always of shape [batch_size, 1] - # We do not need to pad it - input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] - - # We need to slice the attention mask to remove padding from previous steps - input_ids["attention_mask"][ - start_index:end_index, -batch.max_sequence_length: - ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:] - - for j, past in enumerate(batch.input_ids["past_key_values"]): - # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...] - head_dim, padded_sequence_length = past[0].shape[-2:] - num_heads = ( - past[0] - .view(batch.size, -1, head_dim, padded_sequence_length) - .shape[1] - ) - - # This will run only once per layer - if j == len(input_ids["past_key_values"]): - input_ids["past_key_values"].append([]) - - # Decoder past - for k, t in enumerate(past[:2]): - # Needed because BLOOM past shapes are not the same for keys and values - # Keys: [batch_size * num_heads, head_dim, seq_length] - # Values: [batch_size * num_heads, seq_length, head_dim] - head_dim_last = False - if t.shape[-2] == head_dim: - t = t.view( - batch.size, num_heads, head_dim, padded_sequence_length - ) - padded_t_shape = ( - total_batch_size, - num_heads, - head_dim, - max_sequence_length - 1, - ) - elif t.shape[-1] == head_dim: - head_dim_last = True - t = t.view( - batch.size, num_heads, padded_sequence_length, head_dim - ) - padded_t_shape = ( - total_batch_size, - num_heads, - max_sequence_length - 1, - head_dim, - ) - else: - raise ValueError(f"shape {t.shape} is not valid") - - # Initialize tensors - # This will run only once per layer and per past tensor - if k == len(input_ids["past_key_values"][j]): - input_ids["past_key_values"][j].append( - torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) - ) - - # We slice the past keys and values to remove the padding from previous batches - if not head_dim_last: - input_ids["past_key_values"][j][k][ - start_index:end_index, - :, - :, - -(batch.max_sequence_length - 1):, - ] = t[:, :, :, -(batch.max_sequence_length - 1):] - else: - input_ids["past_key_values"][j][k][ - start_index:end_index, - :, - -(batch.max_sequence_length - 1):, - :, - ] = t[:, :, -(batch.max_sequence_length - 1):, :] - - # Seq2SeqLM specific past (encoder past) - for k, t in enumerate(past[2:]): - if max_encoded_sequence_length is None: - max_encoded_sequence_length = max(max(batch.all_input_lengths) for batch in batches) - batch_max_encoded_sequence_length = max(batch.all_input_lengths) - - padded_t_shape = (total_batch_size, num_heads, max_encoded_sequence_length, head_dim) - - idx = k + 2 - - # Initialize tensors - # This will run only once per layer and per past tensor - if idx == len(input_ids["past_key_values"][j]): - input_ids["past_key_values"][j].append( - torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) - ) - - input_ids["past_key_values"][j][idx][ - start_index:end_index, - :, - -batch_max_encoded_sequence_length:, - : - ] = t[:, :, -batch_max_encoded_sequence_length:, :] - - start_index += batch.size - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - all_input_lengths=all_input_lengths, - input_ids=input_ids, - all_input_ids=all_input_ids, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - size=total_batch_size, - max_sequence_length=max_sequence_length, - ) + raise NotImplementedError @dataclass @@ -241,4 +34,6 @@ class GeneratedText: tokens: int def to_pb(self) -> generate_pb2.GeneratedText: - return generate_pb2.GeneratedText(request=self.request, output=self.output, tokens=self.tokens) + return generate_pb2.GeneratedText( + request=self.request, output=self.output, tokens=self.tokens + ) diff --git a/server/text_generation/server.py b/server/text_generation/server.py index fffeb0b..699cef0 100644 --- a/server/text_generation/server.py +++ b/server/text_generation/server.py @@ -9,7 +9,6 @@ from typing import List from text_generation.cache import Cache from text_generation.models import Model, get_model -from text_generation.models.types import Batch from text_generation.pb import generate_pb2_grpc, generate_pb2 @@ -27,7 +26,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.ClearCacheResponse() async def Generate(self, request, context): - batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device) + batch = self.model.batch_type.from_pb( + request.batch, self.model.tokenizer, self.model.device + ) generated_texts, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) @@ -51,7 +52,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batches.append(batch) if len(batches) > 1: - batch = Batch.concatenate(batches) + batch = self.model.batch_type.concatenate(batches) else: batch = batches[0]