From c5665f5c8be801237b0f7bac34da42371a43b9ad Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 4 Nov 2022 14:22:47 +0100 Subject: [PATCH] feat(server): Support generic AutoModelForCausalLM --- README.md | 1 + proto/generate.proto | 2 + router/src/batcher.rs | 2 + router/src/server.rs | 2 +- server/text_generation/models/__init__.py | 13 +- server/text_generation/models/bloom.py | 315 +-------------------- server/text_generation/models/causal_lm.py | 38 +++ server/text_generation/models/model.py | 140 ++++++++- server/text_generation/models/types.py | 171 ++++++++++- server/text_generation/server.py | 4 +- server/text_generation/utils.py | 18 +- 11 files changed, 373 insertions(+), 333 deletions(-) create mode 100644 server/text_generation/models/causal_lm.py diff --git a/README.md b/README.md index c7733d0..29e5835 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ A Rust and gRPC server for text generation inference. ## Supported models - BLOOM +- BLOOMZ - BLOOM-560m ## Load Tests for BLOOM diff --git a/proto/generate.proto b/proto/generate.proto index 45afca8..68dfa15 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -63,6 +63,8 @@ message GeneratedText { Request request = 1; /// Output string output = 2; + /// Number of generated tokens + uint32 tokens = 3; } message GenerateRequest { diff --git a/router/src/batcher.rs b/router/src/batcher.rs index e381986..074c54b 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -190,6 +190,7 @@ fn send_generated(finished: Vec, db: &Db) { .expect("ID not found in db. This is a bug."); let response = InferResponse { output: output.output, + tokens: output.tokens, queued: entry.time, start: entry.batch_time.unwrap(), // unwrap is always valid end: Instant::now(), @@ -202,6 +203,7 @@ fn send_generated(finished: Vec, db: &Db) { #[derive(Debug)] pub(crate) struct InferResponse { pub(crate) output: String, + pub(crate) tokens: u32, pub(crate) queued: Instant, pub(crate) start: Instant, pub(crate) end: Instant, diff --git a/router/src/server.rs b/router/src/server.rs index d31c949..72b720e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -116,7 +116,7 @@ async fn generate( let validation_time = response.queued - start_time; let queue_time = response.start - response.queued; let inference_time = response.end - response.start; - let time_per_token = inference_time / req.parameters.max_new_tokens; + let time_per_token = inference_time / response.tokens; // Headers let mut headers = HeaderMap::new(); diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index 55315ec..0371ab4 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -1,7 +1,8 @@ from text_generation.models.model import Model -from text_generation.models.bloom import BLOOM, BLOOMSharded +from text_generation.models.bloom import BLOOMSharded +from text_generation.models.causal_lm import CausalLM -__all__ = ["Model", "BLOOM", "BLOOMSharded"] +__all__ = ["Model", "BLOOMSharded", "CausalLM"] def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: @@ -11,6 +12,10 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: else: if quantize: raise ValueError("quantization is not supported for non-sharded BLOOM") - return BLOOM(model_name) + return CausalLM(model_name) else: - raise ValueError(f"model {model_name} is not supported yet") + if sharded: + raise 
ValueError("sharded is not supported for AutoModel") + if quantize: + raise ValueError("quantize is not supported for AutoModel") + return CausalLM(model_name) diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index ac2c616..37f55f2 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -1,7 +1,7 @@ import torch import torch.distributed -from typing import List, Optional, Tuple, Type +from typing import List, Optional from accelerate import init_empty_weights from safetensors import safe_open @@ -11,10 +11,8 @@ from transformers.models.bloom.parallel_layers import ( TensorParallelEmbedding, TensorParallelRowLinear, ) -from transformers.modeling_outputs import CausalLMOutputWithPast from text_generation.models import Model -from text_generation.models.types import Batch, GeneratedText from text_generation.utils import ( initialize_torch_distributed, weight_files, @@ -31,322 +29,26 @@ except Exception as e: torch.manual_seed(0) -class BloomBatch(Batch): - @classmethod - def concatenate(cls, batches: List["Batch"]) -> "BloomBatch": - # Used for padding - total_batch_size = sum(batch.size for batch in batches) - max_sequence_length = max(batch.max_sequence_length for batch in batches) - - # Batch attributes - input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} - requests = [] - all_input_lengths = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - all_input_lengths.extend(batch.all_input_lengths) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - # Slicing end index for this batch - end_index = start_index + batch.size - - # We only concatenate batches that did at least one step - if batch.input_ids["input_ids"].shape[1] > 1: - raise ValueError("Batch input_ids should be of shape (batch_size, 1)") - - # Initialize tensors - if i == 0: - input_ids["input_ids"] = torch.empty( - (total_batch_size, 1), - dtype=batch.input_ids["input_ids"].dtype, - device=batch.input_ids["input_ids"].device, - ) - input_ids["attention_mask"] = torch.zeros( - (total_batch_size, max_sequence_length), - dtype=batch.input_ids["attention_mask"].dtype, - device=batch.input_ids["attention_mask"].device, - ) - - # input_ids["input_ids"] is always of shape [batch_size, 1] - # We do not need to pad it - input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] - - # We need to slice the attention mask to remove padding from previous steps - input_ids["attention_mask"][ - start_index:end_index, -batch.max_sequence_length: - ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:] - - for j, past in enumerate(batch.input_ids["past_key_values"]): - past_keys = past[0] - past_values = past[1] - - _, head_dim, padded_sequence_length = past_keys.shape - - # Reshape the tensors to make slicing easier - past_keys = past_keys.view( - batch.size, -1, head_dim, padded_sequence_length - ) - past_values = past_values.view( - batch.size, -1, padded_sequence_length, head_dim - ) - num_heads = past_keys.shape[1] - - # Initialize tensors - # This will run only once per layer - if j == len(input_ids["past_key_values"]): - padded_past_keys = torch.zeros( - ( - total_batch_size, - 
num_heads, - head_dim, - max_sequence_length - 1, - ), - dtype=past_keys.dtype, - device=past_keys.device, - ) - padded_past_values = torch.zeros( - ( - total_batch_size, - num_heads, - max_sequence_length - 1, - head_dim, - ), - dtype=past_values.dtype, - device=past_values.device, - ) - input_ids["past_key_values"].append( - [padded_past_keys, padded_past_values] - ) - - # We slice the past keys and values to remove the padding from previous batches - input_ids["past_key_values"][j][0][ - start_index:end_index, :, :, -(batch.max_sequence_length - 1): - ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):] - - input_ids["past_key_values"][j][1][ - start_index:end_index, :, -(batch.max_sequence_length - 1):, : - ] = past_values[:, :, -(batch.max_sequence_length - 1):, :] - - # If we are on the last batch, we need to reshape the tensors - if (i + 1) == len(batches): - input_ids["past_key_values"][j][0] = input_ids["past_key_values"][ - j - ][0].view(total_batch_size * num_heads, head_dim, -1) - input_ids["past_key_values"][j][1] = input_ids["past_key_values"][ - j - ][1].view(total_batch_size * num_heads, -1, head_dim) - - start_index += batch.size - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - all_input_lengths=all_input_lengths, - input_ids=input_ids, - all_input_ids=all_input_ids, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - size=total_batch_size, - max_sequence_length=max_sequence_length, - ) - - -class BLOOM(Model): - def __init__(self, model_name: str): - if not model_name.startswith("bigscience/bloom"): - raise ValueError(f"Model {model_name} is not supported") - - if torch.cuda.is_available(): - self.device = torch.device("cuda") - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - else: - self.device = torch.device("cpu") - dtype = torch.float32 - - self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") - self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None - ).eval() - - self.num_heads = self.model.config.num_attention_heads - - @property - def batch_type(self) -> Type[BloomBatch]: - return BloomBatch - - def forward( - self, input_ids, attention_mask, past_key_values: Optional = None - ) -> CausalLMOutputWithPast: - # Model Forward - return self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - def generate_token( - self, batch: BloomBatch - ) -> Tuple[List[GeneratedText], Optional[BloomBatch]]: - # For some reason, inference_mode does not work well with GLOO which we use on CPU - context_manager = ( - torch.no_grad if self.device.type == "cpu" else torch.inference_mode - ) - with context_manager(): - outputs = self.forward(**batch.input_ids) - - # List of indices to cache - next_batch_keep_indices = [] - next_batch_past_keep_indices = [] - - # New input_ids for next forward - next_batch_input_ids = [] - next_batch_all_input_ids = [] - next_all_input_lengths = [] - - next_batch_size = 0 - next_batch_max_sequence_length = 0 - - # Finished requests - generated_texts: List[GeneratedText] = [] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.all_input_lengths, - outputs.logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - 
input_length, - logits, - next_token_chooser, - stopping_criteria, - all_tokens, - ) in enumerate(iterator): - # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) - - # Append next token to all tokens - all_tokens = torch.cat([all_tokens, next_token]) - - # Evaluate stopping criteria - if stopping_criteria(all_tokens): - # Decode all tokens - output = self.tokenizer.decode( - all_tokens.squeeze(-1), skip_special_tokens=True - ) - # Add to the list of finished generations with the original request - generated_texts.append(GeneratedText(request, output)) - # add to the next batch - else: - next_batch_keep_indices.append(i) - # past_key_values is of shape [batch_size * num_heads, ...] - # so we need to take into account the `num_heads` stride here - next_batch_past_keep_indices.extend( - [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)] - ) - next_batch_input_ids.append(next_token) - next_batch_all_input_ids.append(all_tokens) - next_batch_size += 1 - new_input_length = input_length + 1 - next_all_input_lengths.append(new_input_length) - next_batch_max_sequence_length = max( - next_batch_max_sequence_length, new_input_length - ) - - # We finished all generations in the batch; there is no next batch - if not next_batch_keep_indices: - return generated_texts, None - - # If we finished at least one generation - next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} - if generated_texts: - # Apply indices to attention mask, past key values and other items that need to be cached - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ - next_batch_keep_indices - ] - next_batch_input_ids["past_key_values"] = [ - ( - keys[next_batch_past_keep_indices], - values[next_batch_past_keep_indices], - ) - for keys, values in outputs["past_key_values"] - ] - next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] - next_batch_next_token_choosers = [ - batch.next_token_choosers[i] for i in next_batch_keep_indices - ] - next_batch_stopping_criterias = [ - batch.stopping_criterias[i] for i in next_batch_keep_indices - ] - else: - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] - next_batch_input_ids["past_key_values"] = outputs["past_key_values"] - next_batch_requests = batch.requests - next_batch_next_token_choosers = batch.next_token_choosers - next_batch_stopping_criterias = batch.stopping_criterias - - # Update attention_mask with padding as we added a new token to input_ids - next_batch_input_ids["attention_mask"] = torch.cat( - [ - next_batch_input_ids["attention_mask"], - torch.ones((next_batch_size, 1)).to(self.device), - ], - dim=1, - ) - - next_batch = BloomBatch( - batch_id=batch.batch_id, - requests=next_batch_requests, - all_input_lengths=next_all_input_lengths, - input_ids=next_batch_input_ids, - all_input_ids=next_batch_all_input_ids, - next_token_choosers=next_batch_next_token_choosers, - stopping_criterias=next_batch_stopping_criterias, - size=next_batch_size, - max_sequence_length=next_batch_max_sequence_length, - ) - return generated_texts, next_batch - - -class BLOOMSharded(BLOOM): +class BLOOMSharded(Model): def __init__(self, model_name: str, quantize: bool = False): - super(Model, self).__init__() if not model_name.startswith("bigscience/bloom"): raise ValueError(f"Model {model_name} is not supported") self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.master = self.rank == 0 if torch.cuda.is_available(): - 
self.device = torch.device(f"cuda:{self.rank}") + device = torch.device(f"cuda:{self.rank}") dtype = torch.float16 else: - self.device = torch.device("cpu") + device = torch.device("cpu") dtype = torch.float32 - self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") config = AutoConfig.from_pretrained( model_name, slow_but_exact=False, tp_parallel=True ) config.pad_token_id = 3 - self.num_heads = config.n_head // self.process_group.size() # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -370,12 +72,14 @@ class BLOOMSharded(BLOOM): model, filenames, quantize=quantize, - device=self.device, + device=device, rank=self.rank, world_size=self.world_size, ) self.model = model.eval().to(dtype) torch.distributed.barrier(group=self.process_group) + super(BLOOMSharded, self).__init__(tokenizer=tokenizer, num_heads=config.n_head // self.process_group.size(), + device=device) @staticmethod def load_weights( @@ -526,5 +230,4 @@ class BLOOMSharded(BLOOM): torch.distributed.all_gather(logits, logits_shard, group=self.process_group) logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size) - outputs.logits = logits - return outputs + return logits, outputs.past_key_values diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py new file mode 100644 index 0000000..39bb450 --- /dev/null +++ b/server/text_generation/models/causal_lm.py @@ -0,0 +1,38 @@ +import torch + +from transformers import AutoTokenizer, AutoModelForCausalLM +from typing import Optional, Tuple, List + +from text_generation.models import Model + + +class CausalLM(Model): + def __init__(self, model_name: str): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + else: + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=dtype, + device_map="auto" if torch.cuda.is_available() else None, + ).eval() + + super(CausalLM, self).__init__(tokenizer=tokenizer, num_heads=self.model.config.num_attention_heads, device=device) + + def forward( + self, input_ids, attention_mask, past_key_values: Optional = None + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + return outputs.logits, outputs.past_key_values diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py index ad09419..d6b3c73 100644 --- a/server/text_generation/models/model.py +++ b/server/text_generation/models/model.py @@ -1,19 +1,139 @@ +import torch + from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type +from typing import List, Tuple, Optional +from tokenizers import Tokenizer from text_generation.models.types import Batch, GeneratedText -B = TypeVar("B", bound=Batch) - class Model(ABC): - @property - @abstractmethod - def batch_type(self) -> Type[B]: - raise NotImplementedError + def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device): + self.tokenizer = tokenizer + self.num_heads = 
num_heads + self.device = device @abstractmethod - def generate_token( - self, batch: B - ) -> Tuple[List[GeneratedText], Optional[B]]: + def forward(self, input_ids, attention_mask, past_key_values: Optional = None) -> Tuple[torch.Tensor, List[Tuple]]: raise NotImplementedError + + def generate_token( + self, batch: Batch + ) -> Tuple[List[GeneratedText], Optional[Batch]]: + # For some reason, inference_mode does not work well with GLOO which we use on CPU + context_manager = ( + torch.no_grad if self.device.type == "cpu" else torch.inference_mode + ) + with context_manager(): + logits, past = self.forward(**batch.input_ids) + + # List of indices to cache + next_batch_keep_indices = [] + + # New input_ids for next forward + next_batch_input_ids = [] + next_batch_all_input_ids = [] + next_all_input_lengths = [] + + next_batch_size = 0 + next_batch_max_sequence_length = 0 + + # Finished requests + generated_texts: List[GeneratedText] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.all_input_lengths, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + logits, + next_token_chooser, + stopping_criteria, + all_tokens, + ) in enumerate(iterator): + # Select next token + next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + + # Append next token to all tokens + all_tokens = torch.cat([all_tokens, next_token]) + + # Evaluate stopping criteria + if stopping_criteria(all_tokens): + # Decode all tokens + output = self.tokenizer.decode( + all_tokens.squeeze(-1), skip_special_tokens=True + ) + # Add to the list of finished generations with the original request + generated_texts.append(GeneratedText(request, output, stopping_criteria.current_tokens)) + # add to the next batch + else: + next_batch_keep_indices.append(i) + next_batch_input_ids.append(next_token) + next_batch_all_input_ids.append(all_tokens) + next_batch_size += 1 + new_input_length = input_length + 1 + next_all_input_lengths.append(new_input_length) + next_batch_max_sequence_length = max( + next_batch_max_sequence_length, new_input_length + ) + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generated_texts, None + + # If we finished at least one generation + next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} + if generated_texts: + # Apply indices to attention mask, past key values and other items that need to be cached + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ + next_batch_keep_indices + ] + # Force past to be of dim [batch_size, num_heads, ...] 
for easy indexing + next_batch_input_ids["past_key_values"] = [ + [t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices] for t in layer] + for layer in past + ] + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] + next_batch_input_ids["past_key_values"] = past + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Update attention_mask with padding as we added a new token to input_ids + next_batch_input_ids["attention_mask"] = torch.cat( + [ + next_batch_input_ids["attention_mask"], + torch.ones((next_batch_size, 1)).to(self.device), + ], + dim=1, + ) + + next_batch = Batch( + batch_id=batch.batch_id, + requests=next_batch_requests, + all_input_lengths=next_all_input_lengths, + input_ids=next_batch_input_ids, + all_input_ids=next_batch_all_input_ids, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + size=next_batch_size, + max_sequence_length=next_batch_max_sequence_length, + ) + return generated_texts, next_batch diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index c2647a8..0c17e55 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -1,6 +1,5 @@ import torch -from abc import abstractmethod from dataclasses import dataclass from typing import List, Dict @@ -32,7 +31,7 @@ class Batch: @classmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device ) -> "Batch": inputs = [] next_token_choosers = [] @@ -51,7 +50,11 @@ class Batch: do_sample=r.parameters.do_sample, ) ) - stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) + stopping_criterias.append( + StoppingCriteria( + eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens + ) + ) input_ids = tokenizer( inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 @@ -71,15 +74,171 @@ class Batch: ) @classmethod - @abstractmethod def concatenate(cls, batches: List["Batch"]) -> "Batch": - raise NotImplementedError + # Used for padding + total_batch_size = sum(batch.size for batch in batches) + max_sequence_length = max(batch.max_sequence_length for batch in batches) + # Only needed for Seq2SeqLM + max_encoded_sequence_length = None + + # Batch attributes + input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} + requests = [] + all_input_lengths = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + all_input_lengths.extend(batch.all_input_lengths) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Slicing end index for this batch + end_index = start_index + batch.size + + # We only concatenate batches that did at least one step + if 
batch.input_ids["input_ids"].shape[1] > 1: + raise ValueError("Batch input_ids should be of shape (batch_size, 1)") + + # Initialize tensors + if i == 0: + input_ids["input_ids"] = torch.empty( + (total_batch_size, 1), + dtype=batch.input_ids["input_ids"].dtype, + device=batch.input_ids["input_ids"].device, + ) + input_ids["attention_mask"] = torch.zeros( + (total_batch_size, max_sequence_length), + dtype=batch.input_ids["attention_mask"].dtype, + device=batch.input_ids["attention_mask"].device, + ) + + # input_ids["input_ids"] is always of shape [batch_size, 1] + # We do not need to pad it + input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] + + # We need to slice the attention mask to remove padding from previous steps + input_ids["attention_mask"][ + start_index:end_index, -batch.max_sequence_length: + ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:] + + for j, past in enumerate(batch.input_ids["past_key_values"]): + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape + # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...] + head_dim, padded_sequence_length = past[0].shape[-2:] + num_heads = ( + past[0] + .view(batch.size, -1, head_dim, padded_sequence_length) + .shape[1] + ) + + # This will run only once per layer + if j == len(input_ids["past_key_values"]): + input_ids["past_key_values"].append([]) + + # Decoder past + for k, t in enumerate(past[:2]): + # Needed because BLOOM past shapes are not the same for keys and values + # Keys: [batch_size * num_heads, head_dim, seq_length] + # Values: [batch_size * num_heads, seq_length, head_dim] + head_dim_last = False + if t.shape[-2] == head_dim: + t = t.view( + batch.size, num_heads, head_dim, padded_sequence_length + ) + padded_t_shape = ( + total_batch_size, + num_heads, + head_dim, + max_sequence_length - 1, + ) + elif t.shape[-1] == head_dim: + head_dim_last = True + t = t.view( + batch.size, num_heads, padded_sequence_length, head_dim + ) + padded_t_shape = ( + total_batch_size, + num_heads, + max_sequence_length - 1, + head_dim, + ) + else: + raise ValueError(f"shape {t.shape} is not valid") + + # Initialize tensors + # This will run only once per layer and per past tensor + if k == len(input_ids["past_key_values"][j]): + input_ids["past_key_values"][j].append( + torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) + ) + + # We slice the past keys and values to remove the padding from previous batches + if not head_dim_last: + input_ids["past_key_values"][j][k][ + start_index:end_index, + :, + :, + -(batch.max_sequence_length - 1):, + ] = t[:, :, :, -(batch.max_sequence_length - 1):] + else: + input_ids["past_key_values"][j][k][ + start_index:end_index, + :, + -(batch.max_sequence_length - 1):, + :, + ] = t[:, :, -(batch.max_sequence_length - 1):, :] + + # Seq2SeqLM specific past (encoder past) + for k, t in enumerate(past[2:]): + if max_encoded_sequence_length is None: + max_encoded_sequence_length = max(max(batch.all_input_lengths) for batch in batches) + batch_max_encoded_sequence_length = max(batch.all_input_lengths) + + padded_t_shape = (total_batch_size, num_heads, max_encoded_sequence_length, head_dim) + + idx = k + 2 + + # Initialize tensors + # This will run only once per layer and per past tensor + if idx == len(input_ids["past_key_values"][j]): + input_ids["past_key_values"][j].append( + torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) + ) + + input_ids["past_key_values"][j][idx][ + start_index:end_index, + :, 
+ -batch_max_encoded_sequence_length:, + : + ] = t[:, :, -batch_max_encoded_sequence_length:, :] + + start_index += batch.size + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + all_input_lengths=all_input_lengths, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=total_batch_size, + max_sequence_length=max_sequence_length, + ) @dataclass class GeneratedText: request: generate_pb2.Request output: str + tokens: int def to_pb(self) -> generate_pb2.GeneratedText: - return generate_pb2.GeneratedText(request=self.request, output=self.output) + return generate_pb2.GeneratedText(request=self.request, output=self.output, tokens=self.tokens) diff --git a/server/text_generation/server.py b/server/text_generation/server.py index b2b34cb..fffeb0b 100644 --- a/server/text_generation/server.py +++ b/server/text_generation/server.py @@ -27,7 +27,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.ClearCacheResponse() async def Generate(self, request, context): - batch = self.model.batch_type.from_pb(request.batch, self.model.tokenizer, self.model.device) + batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device) generated_texts, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) @@ -51,7 +51,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batches.append(batch) if len(batches) > 1: - batch = self.model.batch_type.concatenate(batches) + batch = Batch.concatenate(batches) else: batch = batches[0] diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index 50cd412..0e2d9ae 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -58,7 +58,8 @@ class NextTokenChooser: class StoppingCriteria: - def __init__(self, max_new_tokens=20): + def __init__(self, eos_token_id, max_new_tokens=20): + self.eos_token_id = eos_token_id self.max_new_tokens = max_new_tokens self.current_tokens = 0 @@ -66,6 +67,8 @@ class StoppingCriteria: self.current_tokens += 1 if self.current_tokens >= self.max_new_tokens: return True + if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id: + return True return False @@ -124,11 +127,18 @@ def download_weights(model_name, extension=".safetensors"): filenames = weight_hub_files(model_name, extension) download_function = partial( - hf_hub_download, repo_id=model_name, local_files_only=False + hf_hub_download, + repo_id=model_name, + local_files_only=False, ) executor = ThreadPoolExecutor(max_workers=5) - futures = [executor.submit(download_function, filename=filename) for filename in filenames] - files = [file for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))] + futures = [ + executor.submit(download_function, filename=filename) for filename in filenames + ] + files = [ + file + for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures)) + ] return files
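Usage note (illustrative, not part of the diff): a minimal sketch of how the pieces introduced by this patch fit together, assuming the module paths shown above. It drives Model.generate_token directly and skips the gRPC service and batch cache; the helper name generate_all is hypothetical, and pb_batch stands in for a generate_pb2.Batch built by the router.

# Illustrative sketch only (not part of the patch): exercises the generic
# CausalLM path end to end, without the gRPC service or the batch cache.
from typing import List

from text_generation.models import get_model
from text_generation.models.types import Batch, GeneratedText


def generate_all(model_name: str, pb_batch) -> List[GeneratedText]:
    # Non-sharded checkpoints now go through the generic CausalLM class;
    # sharded BLOOM still routes to BLOOMSharded inside get_model.
    model = get_model(model_name, sharded=False, quantize=False)

    # Same conversion the service performs on an incoming request batch;
    # pb_batch is assumed to be a generate_pb2.Batch built by the router.
    batch = Batch.from_pb(pb_batch, model.tokenizer, model.device)

    finished: List[GeneratedText] = []
    # generate_token returns the requests that finished this step plus the
    # batch to run next, or None once every request has hit its stopping
    # criteria (max_new_tokens or the tokenizer's EOS token).
    while batch is not None:
        generated_texts, batch = model.generate_token(batch)
        finished.extend(generated_texts)
    return finished

The loop above is the same shape the router drives across calls: finished generations are returned to the client with their token counts (the new tokens field), while the surviving requests are re-batched, optionally concatenated with new arrivals via Batch.concatenate, and fed back into generate_token.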