diff --git a/README.md b/README.md index 0c4f6f71..c7733d05 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,11 @@ A Rust and gRPC server for text generation inference. - [Safetensors](https://github.com/huggingface/safetensors) weight loading - 45ms per token generation for BLOOM with 8xA100 80GB -## Officially supported models +## Supported models - BLOOM - BLOOM-560m -Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(, torch_dtype=torch.float16, device_map="auto")`. - ## Load Tests for BLOOM See `k6/load_test.js` @@ -82,6 +80,7 @@ make router-dev ## TODO: +- [ ] Support AutoModelForSeq2SeqLM - [ ] Add tests for the `server/model` logic - [ ] Backport custom CUDA kernels to Transformers - [ ] Install safetensors with pip \ No newline at end of file diff --git a/server/Makefile b/server/Makefile index 4fa966e2..99764028 100644 --- a/server/Makefile +++ b/server/Makefile @@ -9,11 +9,11 @@ gen-server: install-transformers: # Install specific version of transformers rm transformers || true - rm transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 || true - curl -L -O https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip - unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip - rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip - mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers + rm transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 || true + curl -L -O https://github.com/OlivierDehaene/transformers/archive/7302a24535e8dc5637ea5b4e4572fc971d404098.zip + unzip 7302a24535e8dc5637ea5b4e4572fc971d404098.zip + rm 7302a24535e8dc5637ea5b4e4572fc971d404098.zip + mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers cd transformers && python setup.py install install-safetensors: diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py index 1f141c3c..55315ec6 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation/models/__init__.py @@ -1,22 +1,16 @@ from text_generation.models.model import Model -from text_generation.models.bloom import BLOOMSharded +from text_generation.models.bloom import BLOOM, BLOOMSharded -__all__ = ["Model", "BLOOMSharded"] +__all__ = ["Model", "BLOOM", "BLOOMSharded"] def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: - if model_name.startswith("bigscience/bloom"): if sharded: return BLOOMSharded(model_name, quantize) else: if quantize: raise ValueError("quantization is not supported for non-sharded BLOOM") - return Model(model_name) + return BLOOM(model_name) else: - if sharded: - raise ValueError("sharded is only supported for BLOOM models") - if quantize: - raise ValueError("Quantization is only supported for BLOOM models") - - return Model(model_name) + raise ValueError(f"model {model_name} is not supported yet") diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 172ca38d..ac2c6169 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -1,7 +1,7 @@ import torch import torch.distributed -from typing import List, Optional +from typing import List, Optional, Tuple, Type from accelerate import init_empty_weights from safetensors import safe_open @@ -11,8 +11,10 @@ from transformers.models.bloom.parallel_layers import ( TensorParallelEmbedding, TensorParallelRowLinear, ) +from transformers.modeling_outputs import CausalLMOutputWithPast from text_generation.models import 
Model +from text_generation.models.types import Batch, GeneratedText from text_generation.utils import ( initialize_torch_distributed, weight_files, @@ -29,9 +31,306 @@ except Exception as e: torch.manual_seed(0) -class BLOOMSharded(Model): +class BloomBatch(Batch): + @classmethod + def concatenate(cls, batches: List["Batch"]) -> "BloomBatch": + # Used for padding + total_batch_size = sum(batch.size for batch in batches) + max_sequence_length = max(batch.max_sequence_length for batch in batches) + + # Batch attributes + input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} + requests = [] + all_input_lengths = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + all_input_lengths.extend(batch.all_input_lengths) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Slicing end index for this batch + end_index = start_index + batch.size + + # We only concatenate batches that did at least one step + if batch.input_ids["input_ids"].shape[1] > 1: + raise ValueError("Batch input_ids should be of shape (batch_size, 1)") + + # Initialize tensors + if i == 0: + input_ids["input_ids"] = torch.empty( + (total_batch_size, 1), + dtype=batch.input_ids["input_ids"].dtype, + device=batch.input_ids["input_ids"].device, + ) + input_ids["attention_mask"] = torch.zeros( + (total_batch_size, max_sequence_length), + dtype=batch.input_ids["attention_mask"].dtype, + device=batch.input_ids["attention_mask"].device, + ) + + # input_ids["input_ids"] is always of shape [batch_size, 1] + # We do not need to pad it + input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] + + # We need to slice the attention mask to remove padding from previous steps + input_ids["attention_mask"][ + start_index:end_index, -batch.max_sequence_length: + ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:] + + for j, past in enumerate(batch.input_ids["past_key_values"]): + past_keys = past[0] + past_values = past[1] + + _, head_dim, padded_sequence_length = past_keys.shape + + # Reshape the tensors to make slicing easier + past_keys = past_keys.view( + batch.size, -1, head_dim, padded_sequence_length + ) + past_values = past_values.view( + batch.size, -1, padded_sequence_length, head_dim + ) + num_heads = past_keys.shape[1] + + # Initialize tensors + # This will run only once per layer + if j == len(input_ids["past_key_values"]): + padded_past_keys = torch.zeros( + ( + total_batch_size, + num_heads, + head_dim, + max_sequence_length - 1, + ), + dtype=past_keys.dtype, + device=past_keys.device, + ) + padded_past_values = torch.zeros( + ( + total_batch_size, + num_heads, + max_sequence_length - 1, + head_dim, + ), + dtype=past_values.dtype, + device=past_values.device, + ) + input_ids["past_key_values"].append( + [padded_past_keys, padded_past_values] + ) + + # We slice the past keys and values to remove the padding from previous batches + input_ids["past_key_values"][j][0][ + start_index:end_index, :, :, -(batch.max_sequence_length - 1): + ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):] + + input_ids["past_key_values"][j][1][ + start_index:end_index, :, -(batch.max_sequence_length - 1):, : + ] = past_values[:, :, -(batch.max_sequence_length - 1):, :] + + # 
If we are on the last batch, we need to reshape the tensors + if (i + 1) == len(batches): + input_ids["past_key_values"][j][0] = input_ids["past_key_values"][ + j + ][0].view(total_batch_size * num_heads, head_dim, -1) + input_ids["past_key_values"][j][1] = input_ids["past_key_values"][ + j + ][1].view(total_batch_size * num_heads, -1, head_dim) + + start_index += batch.size + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + all_input_lengths=all_input_lengths, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=total_batch_size, + max_sequence_length=max_sequence_length, + ) + + +class BLOOM(Model): + def __init__(self, model_name: str): + if not model_name.startswith("bigscience/bloom"): + raise ValueError(f"Model {model_name} is not supported") + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + self.device = torch.device("cpu") + dtype = torch.float32 + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None + ).eval() + + self.num_heads = self.model.config.num_attention_heads + + @property + def batch_type(self) -> Type[BloomBatch]: + return BloomBatch + + def forward( + self, input_ids, attention_mask, past_key_values: Optional = None + ) -> CausalLMOutputWithPast: + # Model Forward + return self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + def generate_token( + self, batch: BloomBatch + ) -> Tuple[List[GeneratedText], Optional[BloomBatch]]: + # For some reason, inference_mode does not work well with GLOO which we use on CPU + context_manager = ( + torch.no_grad if self.device.type == "cpu" else torch.inference_mode + ) + with context_manager(): + outputs = self.forward(**batch.input_ids) + + # List of indices to cache + next_batch_keep_indices = [] + next_batch_past_keep_indices = [] + + # New input_ids for next forward + next_batch_input_ids = [] + next_batch_all_input_ids = [] + next_all_input_lengths = [] + + next_batch_size = 0 + next_batch_max_sequence_length = 0 + + # Finished requests + generated_texts: List[GeneratedText] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.all_input_lengths, + outputs.logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + logits, + next_token_chooser, + stopping_criteria, + all_tokens, + ) in enumerate(iterator): + # Select next token + next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + + # Append next token to all tokens + all_tokens = torch.cat([all_tokens, next_token]) + + # Evaluate stopping criteria + if stopping_criteria(all_tokens): + # Decode all tokens + output = self.tokenizer.decode( + all_tokens.squeeze(-1), skip_special_tokens=True + ) + # Add to the list of finished generations with the original request + generated_texts.append(GeneratedText(request, output)) + # add to the next batch + else: + next_batch_keep_indices.append(i) + # past_key_values is of shape [batch_size * num_heads, ...] 
+ # so we need to take into account the `num_heads` stride here + next_batch_past_keep_indices.extend( + [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)] + ) + next_batch_input_ids.append(next_token) + next_batch_all_input_ids.append(all_tokens) + next_batch_size += 1 + new_input_length = input_length + 1 + next_all_input_lengths.append(new_input_length) + next_batch_max_sequence_length = max( + next_batch_max_sequence_length, new_input_length + ) + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generated_texts, None + + # If we finished at least one generation + next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} + if generated_texts: + # Apply indices to attention mask, past key values and other items that need to be cached + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ + next_batch_keep_indices + ] + next_batch_input_ids["past_key_values"] = [ + ( + keys[next_batch_past_keep_indices], + values[next_batch_past_keep_indices], + ) + for keys, values in outputs["past_key_values"] + ] + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] + next_batch_input_ids["past_key_values"] = outputs["past_key_values"] + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Update attention_mask with padding as we added a new token to input_ids + next_batch_input_ids["attention_mask"] = torch.cat( + [ + next_batch_input_ids["attention_mask"], + torch.ones((next_batch_size, 1)).to(self.device), + ], + dim=1, + ) + + next_batch = BloomBatch( + batch_id=batch.batch_id, + requests=next_batch_requests, + all_input_lengths=next_all_input_lengths, + input_ids=next_batch_input_ids, + all_input_ids=next_batch_all_input_ids, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + size=next_batch_size, + max_sequence_length=next_batch_max_sequence_length, + ) + return generated_texts, next_batch + + +class BLOOMSharded(BLOOM): def __init__(self, model_name: str, quantize: bool = False): super(Model, self).__init__() + if not model_name.startswith("bigscience/bloom"): + raise ValueError(f"Model {model_name} is not supported") + self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.master = self.rank == 0 if torch.cuda.is_available(): @@ -80,17 +379,17 @@ class BLOOMSharded(Model): @staticmethod def load_weights( - model, - filenames: List[str], - quantize: bool, - device: torch.device, - rank: int, - world_size: int, + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, ): parameters = dict(model.named_parameters()) for file in filenames: with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" + file, framework="pt", device=str(device) if not quantize else "cpu" ) as f: for name in f.keys(): full_name = f"transformer.{name}" @@ -153,9 +452,9 @@ class BLOOMSharded(Model): ) if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" + type(module) + in 
[TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" ): tensor = Int8Params( tensor.transpose(1, 0), diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py index e585e476..ad094197 100644 --- a/server/text_generation/models/model.py +++ b/server/text_generation/models/model.py @@ -1,166 +1,19 @@ -import torch -import torch.distributed - -from typing import List, Tuple, Optional -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig -from transformers.modeling_outputs import CausalLMOutputWithPast +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional, TypeVar, Type from text_generation.models.types import Batch, GeneratedText +B = TypeVar("B", bound=Batch) -class Model: - def __init__(self, model_name: str): - if torch.cuda.is_available(): - self.device = torch.device("cuda") - dtype = torch.float16 - else: - self.device = torch.device("cpu") - dtype = torch.float32 - self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") - self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None - ).eval() - - self.num_heads = self.model.config.num_attention_heads - - def forward( - self, input_ids, attention_mask, past_key_values: Optional = None - ) -> CausalLMOutputWithPast: - # Model Forward - return self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) +class Model(ABC): + @property + @abstractmethod + def batch_type(self) -> Type[B]: + raise NotImplementedError + @abstractmethod def generate_token( - self, batch: Batch - ) -> Tuple[List[GeneratedText], Optional[Batch]]: - # For some reason, inference_mode does not work well with GLOO which we use on CPU - context_manager = ( - torch.no_grad if self.device.type == "cpu" else torch.inference_mode - ) - with context_manager(): - outputs = self.forward(**batch.input_ids) - - # List of indices to cache - next_batch_keep_indices = [] - next_batch_past_keep_indices = [] - - # New input_ids for next forward - next_batch_input_ids = [] - next_batch_all_input_ids = [] - next_all_input_lengths = [] - - next_batch_size = 0 - next_batch_max_sequence_length = 0 - - # Finished requests - generated_texts: List[GeneratedText] = [] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.all_input_lengths, - outputs.logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - logits, - next_token_chooser, - stopping_criteria, - all_tokens, - ) in enumerate(iterator): - # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) - - # Append next token to all tokens - all_tokens = torch.cat([all_tokens, next_token]) - - # Evaluate stopping criteria - if stopping_criteria(all_tokens): - # Decode all tokens - output = self.tokenizer.decode( - all_tokens.squeeze(-1), skip_special_tokens=True - ) - # Add to the list of finished generations with the original request - generated_texts.append(GeneratedText(request, output)) - # add to the next batch - else: - next_batch_keep_indices.append(i) - # past_key_values is of shape [batch_size * num_heads, ...] 
- # so we need to take into account the `num_heads` stride here - next_batch_past_keep_indices.extend( - [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)] - ) - next_batch_input_ids.append(next_token) - next_batch_all_input_ids.append(all_tokens) - next_batch_size += 1 - new_input_length = input_length + 1 - next_all_input_lengths.append(new_input_length) - next_batch_max_sequence_length = max( - next_batch_max_sequence_length, new_input_length - ) - - # We finished all generations in the batch; there is no next batch - if not next_batch_keep_indices: - return generated_texts, None - - # If we finished at least one generation - next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} - if generated_texts: - # Apply indices to attention mask, past key values and other items that need to be cached - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ - next_batch_keep_indices - ] - next_batch_input_ids["past_key_values"] = [ - ( - keys[next_batch_past_keep_indices], - values[next_batch_past_keep_indices], - ) - for keys, values in outputs["past_key_values"] - ] - next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] - next_batch_next_token_choosers = [ - batch.next_token_choosers[i] for i in next_batch_keep_indices - ] - next_batch_stopping_criterias = [ - batch.stopping_criterias[i] for i in next_batch_keep_indices - ] - else: - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] - next_batch_input_ids["past_key_values"] = outputs["past_key_values"] - next_batch_requests = batch.requests - next_batch_next_token_choosers = batch.next_token_choosers - next_batch_stopping_criterias = batch.stopping_criterias - - # Update attention_mask with padding as we added a new token to input_ids - next_batch_input_ids["attention_mask"] = torch.cat( - [ - next_batch_input_ids["attention_mask"], - torch.ones((next_batch_size, 1)).to(self.device), - ], - dim=1, - ) - - next_batch = Batch( - batch_id=batch.batch_id, - requests=next_batch_requests, - all_input_lengths=next_all_input_lengths, - input_ids=next_batch_input_ids, - all_input_ids=next_batch_all_input_ids, - next_token_choosers=next_batch_next_token_choosers, - stopping_criterias=next_batch_stopping_criterias, - size=next_batch_size, - max_sequence_length=next_batch_max_sequence_length, - ) - return generated_texts, next_batch + self, batch: B + ) -> Tuple[List[GeneratedText], Optional[B]]: + raise NotImplementedError diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 39c33ab7..c2647a8d 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -1,5 +1,6 @@ import torch +from abc import abstractmethod from dataclasses import dataclass from typing import List, Dict @@ -70,131 +71,9 @@ class Batch: ) @classmethod + @abstractmethod def concatenate(cls, batches: List["Batch"]) -> "Batch": - # Used for padding - total_batch_size = sum(batch.size for batch in batches) - max_sequence_length = max(batch.max_sequence_length for batch in batches) - - # Batch attributes - input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} - requests = [] - all_input_lengths = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - 
all_input_lengths.extend(batch.all_input_lengths) - all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - # Slicing end index for this batch - end_index = start_index + batch.size - - # We only concatenate batches that did at least one step - if batch.input_ids["input_ids"].shape[1] > 1: - raise ValueError("Batch input_ids should be of shape (batch_size, 1)") - - # Initialize tensors - if i == 0: - input_ids["input_ids"] = torch.empty( - (total_batch_size, 1), - dtype=batch.input_ids["input_ids"].dtype, - device=batch.input_ids["input_ids"].device, - ) - input_ids["attention_mask"] = torch.zeros( - (total_batch_size, max_sequence_length), - dtype=batch.input_ids["attention_mask"].dtype, - device=batch.input_ids["attention_mask"].device, - ) - - # input_ids["input_ids"] is always of shape [batch_size, 1] - # We do not need to pad it - input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] - - # We need to slice the attention mask to remove padding from previous steps - input_ids["attention_mask"][ - start_index:end_index, -batch.max_sequence_length : - ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] - - for j, past in enumerate(batch.input_ids["past_key_values"]): - past_keys = past[0] - past_values = past[1] - - _, head_dim, padded_sequence_length = past_keys.shape - - # Reshape the tensors to make slicing easier - past_keys = past_keys.view( - batch.size, -1, head_dim, padded_sequence_length - ) - past_values = past_values.view( - batch.size, -1, padded_sequence_length, head_dim - ) - num_heads = past_keys.shape[1] - - # Initialize tensors - # This will run only once per layer - if j == len(input_ids["past_key_values"]): - padded_past_keys = torch.zeros( - ( - total_batch_size, - num_heads, - head_dim, - max_sequence_length - 1, - ), - dtype=past_keys.dtype, - device=past_keys.device, - ) - padded_past_values = torch.zeros( - ( - total_batch_size, - num_heads, - max_sequence_length - 1, - head_dim, - ), - dtype=past_values.dtype, - device=past_values.device, - ) - input_ids["past_key_values"].append( - [padded_past_keys, padded_past_values] - ) - - # We slice the past keys and values to remove the padding from previous batches - input_ids["past_key_values"][j][0][ - start_index:end_index, :, :, -(batch.max_sequence_length - 1) : - ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] - - input_ids["past_key_values"][j][1][ - start_index:end_index, :, -(batch.max_sequence_length - 1) :, : - ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] - - # If we are on the last batch, we need to reshape the tensors - if (i + 1) == len(batches): - input_ids["past_key_values"][j][0] = input_ids["past_key_values"][ - j - ][0].view(total_batch_size * num_heads, head_dim, -1) - input_ids["past_key_values"][j][1] = input_ids["past_key_values"][ - j - ][1].view(total_batch_size * num_heads, -1, head_dim) - - start_index += batch.size - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - all_input_lengths=all_input_lengths, - input_ids=input_ids, - all_input_ids=all_input_ids, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - size=total_batch_size, - max_sequence_length=max_sequence_length, - ) + raise NotImplementedError @dataclass diff --git a/server/text_generation/server.py b/server/text_generation/server.py index fffeb0ba..b2b34cb1 100644 --- a/server/text_generation/server.py +++ 
b/server/text_generation/server.py @@ -27,7 +27,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.ClearCacheResponse() async def Generate(self, request, context): - batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device) + batch = self.model.batch_type.from_pb(request.batch, self.model.tokenizer, self.model.device) generated_texts, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) @@ -51,7 +51,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batches.append(batch) if len(batches) > 1: - batch = Batch.concatenate(batches) + batch = self.model.batch_type.concatenate(batches) else: batch = batches[0]
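
Note on the new abstraction (editorial sketch, not part of the diff): the change above turns `Model` into an abstract interface (`batch_type` plus `generate_token`), makes `Batch.concatenate` abstract, and moves the concrete causal-LM logic into `BLOOM`/`BloomBatch`, so `server.py` drives any model purely through `self.model.batch_type`. Below is a minimal, self-contained sketch of that contract. `ToyBatch`, `ToyModel`, the stand-in `GeneratedText`, and the fake token loop are illustrative assumptions only; the real implementations are `BLOOM`/`BloomBatch` and the protobuf-backed `Batch` in `types.py` shown in the diff.

```python
# Standalone sketch of the Model/Batch contract introduced by this PR.
# It mirrors the shapes of Model.batch_type / generate_token and Batch.concatenate,
# but uses toy stand-ins (ToyBatch, ToyModel, fake requests) purely for illustration.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type, TypeVar


@dataclass
class GeneratedText:
    # Stand-in for text_generation.models.types.GeneratedText(request, output)
    request: str
    output: str


class Batch(ABC):
    @classmethod
    @abstractmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        raise NotImplementedError


B = TypeVar("B", bound=Batch)


class Model(ABC):
    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        raise NotImplementedError


@dataclass
class ToyBatch(Batch):
    requests: List[str]
    tokens: List[List[str]]  # tokens generated so far, one list per request

    @classmethod
    def concatenate(cls, batches: List["ToyBatch"]) -> "ToyBatch":
        # The real BloomBatch.concatenate also pads attention masks and
        # past_key_values; here we only merge the per-request bookkeeping.
        return cls(
            requests=[r for b in batches for r in b.requests],
            tokens=[t for b in batches for t in b.tokens],
        )


class ToyModel(Model):
    """Appends one fake token per step and finishes a request after 3 tokens."""

    @property
    def batch_type(self) -> Type[ToyBatch]:
        return ToyBatch

    def generate_token(
        self, batch: ToyBatch
    ) -> Tuple[List[GeneratedText], Optional[ToyBatch]]:
        finished, keep_requests, keep_tokens = [], [], []
        for request, tokens in zip(batch.requests, batch.tokens):
            tokens = tokens + [f"tok{len(tokens)}"]
            if len(tokens) >= 3:  # stand-in for a StoppingCriteria
                finished.append(GeneratedText(request, " ".join(tokens)))
            else:
                keep_requests.append(request)
                keep_tokens.append(tokens)
        if not keep_requests:
            # Same convention as BLOOM.generate_token: no next batch to cache
            return finished, None
        return finished, ToyBatch(keep_requests, keep_tokens)


if __name__ == "__main__":
    # Same driving pattern as server.py after this diff: the server only ever
    # goes through model.batch_type, never a concrete Batch class.
    model: Model = ToyModel()
    batch: Optional[ToyBatch] = model.batch_type.concatenate(
        [ToyBatch(["req-a"], [[]]), ToyBatch(["req-b"], [[]])]
    )
    while batch is not None:
        done, batch = model.generate_token(batch)
        for text in done:
            print(text.request, "->", text.output)
```

Because `server.py` now resolves both `from_pb` and `concatenate` through `model.batch_type`, adding a new architecture (for example the `AutoModelForSeq2SeqLM` item in the README TODO) should only require a new `Model`/`Batch` pair plus a branch in `get_model`, with no server changes.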