fix(models): Revert buggy support for AutoModel

OlivierDehaene 2022-11-03 16:07:54 +01:00
parent b3b7ea0d74
commit 755fc0e403
7 changed files with 339 additions and 315 deletions

View File

@@ -15,13 +15,11 @@ A Rust and gRPC server for text generation inference.
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
 
-## Officially supported models
+## Supported models
 
 - BLOOM
 - BLOOM-560m
 
-Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(<model>, torch_dtype=torch.float16, device_map="auto")`.
-
 ## Load Tests for BLOOM
 
 See `k6/load_test.js`
@@ -82,6 +80,7 @@ make router-dev
 
 ## TODO:
 
+- [ ] Support AutoModelForSeq2SeqLM
 - [ ] Add tests for the `server/model` logic
 - [ ] Backport custom CUDA kernels to Transformers
 - [ ] Install safetensors with pip
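For context, the removed README line above points at the best-effort `AutoModel` path that this commit reverts. A minimal sketch of what that loading call looks like; the model id below is only an illustrative placeholder, not something prescribed by the diff:

```python
# Sketch of the best-effort AutoModel loading path removed by this revert.
# "bigscience/bloom-560m" is only an example id; the removed README line used <model>.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # dtype used in the removed README line
    device_map="auto",          # requires the accelerate package
).eval()
```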

View File

@@ -9,11 +9,11 @@ gen-server:
 install-transformers:
 	# Install specific version of transformers
 	rm transformers || true
-	rm transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 || true
-	curl -L -O https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
-	unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
-	rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
-	mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
+	rm transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 || true
+	curl -L -O https://github.com/OlivierDehaene/transformers/archive/7302a24535e8dc5637ea5b4e4572fc971d404098.zip
+	unzip 7302a24535e8dc5637ea5b4e4572fc971d404098.zip
+	rm 7302a24535e8dc5637ea5b4e4572fc971d404098.zip
+	mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers
 	cd transformers && python setup.py install
 
 install-safetensors:

View File

@@ -1,22 +1,16 @@
 from text_generation.models.model import Model
-from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.bloom import BLOOM, BLOOMSharded
 
-__all__ = ["Model", "BLOOMSharded"]
+__all__ = ["Model", "BLOOM", "BLOOMSharded"]
 
 
 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
     if model_name.startswith("bigscience/bloom"):
         if sharded:
             return BLOOMSharded(model_name, quantize)
         else:
             if quantize:
                 raise ValueError("quantization is not supported for non-sharded BLOOM")
-            return Model(model_name)
+            return BLOOM(model_name)
     else:
-        if sharded:
-            raise ValueError("sharded is only supported for BLOOM models")
-        if quantize:
-            raise ValueError("Quantization is only supported for BLOOM models")
-        return Model(model_name)
+        raise ValueError(f"model {model_name} is not supported yet")
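A quick usage sketch of the dispatch above after this revert; the calls are illustrative and not part of the diff:

```python
# Illustrative calls against the reverted get_model dispatch (downloads weights if run).
from text_generation.models import get_model

# BLOOM checkpoints still dispatch to the dedicated implementations:
model = get_model("bigscience/bloom-560m", sharded=False, quantize=False)  # -> BLOOM

# Anything else now fails fast instead of falling back to AutoModel:
# get_model("gpt2", sharded=False, quantize=False)
# raises ValueError("model gpt2 is not supported yet")
```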

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional
+from typing import List, Optional, Tuple, Type
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -11,8 +11,10 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
 )
+from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from text_generation.models import Model
+from text_generation.models.types import Batch, GeneratedText
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -29,9 +31,306 @@ except Exception as e:
 torch.manual_seed(0)
 
 
-class BLOOMSharded(Model):
+class BloomBatch(Batch):
+    @classmethod
+    def concatenate(cls, batches: List["Batch"]) -> "BloomBatch":
+        # Used for padding
+        total_batch_size = sum(batch.size for batch in batches)
+        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+
+        # Batch attributes
+        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
+        requests = []
+        all_input_lengths = []
+        all_input_ids = []
+        next_token_choosers = []
+        stopping_criterias = []
+
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            all_input_lengths.extend(batch.all_input_lengths)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+
+            # Slicing end index for this batch
+            end_index = start_index + batch.size
+
+            # We only concatenate batches that did at least one step
+            if batch.input_ids["input_ids"].shape[1] > 1:
+                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
+
+            # Initialize tensors
+            if i == 0:
+                input_ids["input_ids"] = torch.empty(
+                    (total_batch_size, 1),
+                    dtype=batch.input_ids["input_ids"].dtype,
+                    device=batch.input_ids["input_ids"].device,
+                )
+                input_ids["attention_mask"] = torch.zeros(
+                    (total_batch_size, max_sequence_length),
+                    dtype=batch.input_ids["attention_mask"].dtype,
+                    device=batch.input_ids["attention_mask"].device,
+                )
+
+            # input_ids["input_ids"] is always of shape [batch_size, 1]
+            # We do not need to pad it
+            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
+
+            # We need to slice the attention mask to remove padding from previous steps
+            input_ids["attention_mask"][
+                start_index:end_index, -batch.max_sequence_length:
+            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:]
+
+            for j, past in enumerate(batch.input_ids["past_key_values"]):
+                past_keys = past[0]
+                past_values = past[1]
+
+                _, head_dim, padded_sequence_length = past_keys.shape
+
+                # Reshape the tensors to make slicing easier
+                past_keys = past_keys.view(
+                    batch.size, -1, head_dim, padded_sequence_length
+                )
+                past_values = past_values.view(
+                    batch.size, -1, padded_sequence_length, head_dim
+                )
+                num_heads = past_keys.shape[1]
+
+                # Initialize tensors
+                # This will run only once per layer
+                if j == len(input_ids["past_key_values"]):
+                    padded_past_keys = torch.zeros(
+                        (
+                            total_batch_size,
+                            num_heads,
+                            head_dim,
+                            max_sequence_length - 1,
+                        ),
+                        dtype=past_keys.dtype,
+                        device=past_keys.device,
+                    )
+                    padded_past_values = torch.zeros(
+                        (
+                            total_batch_size,
+                            num_heads,
+                            max_sequence_length - 1,
+                            head_dim,
+                        ),
+                        dtype=past_values.dtype,
+                        device=past_values.device,
+                    )
+                    input_ids["past_key_values"].append(
+                        [padded_past_keys, padded_past_values]
+                    )
+
+                # We slice the past keys and values to remove the padding from previous batches
+                input_ids["past_key_values"][j][0][
+                    start_index:end_index, :, :, -(batch.max_sequence_length - 1):
+                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+
+                input_ids["past_key_values"][j][1][
+                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+
+                # If we are on the last batch, we need to reshape the tensors
+                if (i + 1) == len(batches):
+                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
+                        j
+                    ][0].view(total_batch_size * num_heads, head_dim, -1)
+                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
+                        j
+                    ][1].view(total_batch_size * num_heads, -1, head_dim)
+
+            start_index += batch.size
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            all_input_lengths=all_input_lengths,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            size=total_batch_size,
+            max_sequence_length=max_sequence_length,
+        )
+
+
+class BLOOM(Model):
+    def __init__(self, model_name: str):
+        if not model_name.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_name} is not supported")
+
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            self.device = torch.device("cpu")
+            dtype = torch.float32
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None
+        ).eval()
+        self.num_heads = self.model.config.num_attention_heads
+
+    @property
+    def batch_type(self) -> Type[BloomBatch]:
+        return BloomBatch
+
+    def forward(
+        self, input_ids, attention_mask, past_key_values: Optional = None
+    ) -> CausalLMOutputWithPast:
+        # Model Forward
+        return self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+
+    def generate_token(
+        self, batch: BloomBatch
+    ) -> Tuple[List[GeneratedText], Optional[BloomBatch]]:
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        context_manager = (
+            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        )
+        with context_manager():
+            outputs = self.forward(**batch.input_ids)
+
+        # List of indices to cache
+        next_batch_keep_indices = []
+        next_batch_past_keep_indices = []
+
+        # New input_ids for next forward
+        next_batch_input_ids = []
+        next_batch_all_input_ids = []
+        next_all_input_lengths = []
+
+        next_batch_size = 0
+        next_batch_max_sequence_length = 0
+
+        # Finished requests
+        generated_texts: List[GeneratedText] = []
+
+        # Zipped iterator
+        iterator = zip(
+            batch.requests,
+            batch.all_input_lengths,
+            outputs.logits,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+        )
+
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_tokens,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+
+            # Append next token to all tokens
+            all_tokens = torch.cat([all_tokens, next_token])
+
+            # Evaluate stopping criteria
+            if stopping_criteria(all_tokens):
+                # Decode all tokens
+                output = self.tokenizer.decode(
+                    all_tokens.squeeze(-1), skip_special_tokens=True
+                )
+                # Add to the list of finished generations with the original request
+                generated_texts.append(GeneratedText(request, output))
+            # add to the next batch
+            else:
+                next_batch_keep_indices.append(i)
+                # past_key_values is of shape [batch_size * num_heads, ...]
+                # so we need to take into account the `num_heads` stride here
+                next_batch_past_keep_indices.extend(
+                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
+                )
+                next_batch_input_ids.append(next_token)
+                next_batch_all_input_ids.append(all_tokens)
+                next_batch_size += 1
+                new_input_length = input_length + 1
+                next_all_input_lengths.append(new_input_length)
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, new_input_length
+                )
+
+        # We finished all generations in the batch; there is no next batch
+        if not next_batch_keep_indices:
+            return generated_texts, None
+
+        # If we finished at least one generation
+        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
+        if generated_texts:
+            # Apply indices to attention mask, past key values and other items that need to be cached
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
+                next_batch_keep_indices
+            ]
+            next_batch_input_ids["past_key_values"] = [
+                (
+                    keys[next_batch_past_keep_indices],
+                    values[next_batch_past_keep_indices],
+                )
+                for keys, values in outputs["past_key_values"]
+            ]
+            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
+            next_batch_next_token_choosers = [
+                batch.next_token_choosers[i] for i in next_batch_keep_indices
+            ]
+            next_batch_stopping_criterias = [
+                batch.stopping_criterias[i] for i in next_batch_keep_indices
+            ]
+        else:
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
+            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
+            next_batch_requests = batch.requests
+            next_batch_next_token_choosers = batch.next_token_choosers
+            next_batch_stopping_criterias = batch.stopping_criterias
+
+        # Update attention_mask with padding as we added a new token to input_ids
+        next_batch_input_ids["attention_mask"] = torch.cat(
+            [
+                next_batch_input_ids["attention_mask"],
+                torch.ones((next_batch_size, 1)).to(self.device),
+            ],
+            dim=1,
+        )
+
+        next_batch = BloomBatch(
+            batch_id=batch.batch_id,
+            requests=next_batch_requests,
+            all_input_lengths=next_all_input_lengths,
+            input_ids=next_batch_input_ids,
+            all_input_ids=next_batch_all_input_ids,
+            next_token_choosers=next_batch_next_token_choosers,
+            stopping_criterias=next_batch_stopping_criterias,
+            size=next_batch_size,
+            max_sequence_length=next_batch_max_sequence_length,
+        )
+        return generated_texts, next_batch
+
+
+class BLOOMSharded(BLOOM):
     def __init__(self, model_name: str, quantize: bool = False):
         super(Model, self).__init__()
+
+        if not model_name.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_name} is not supported")
+
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
@@ -80,17 +379,17 @@ class BLOOMSharded(Model):
     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     full_name = f"transformer.{name}"
@@ -153,9 +452,9 @@ class BLOOMSharded(Model):
                         )
 
                         if (
                             type(module)
                             in [TensorParallelRowLinear, TensorParallelColumnLinear]
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor.transpose(1, 0),
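The `BloomBatch.concatenate` logic above leans on the layout BLOOM uses for its past-key-value cache. A toy shape walk-through of the `view()` calls it performs, with made-up sizes rather than values taken from the diff:

```python
# Toy walk-through of the reshaping done in BloomBatch.concatenate (made-up sizes).
import torch

batch_size, num_heads, head_dim, seq_len = 2, 4, 8, 5

# BLOOM's cache stores past keys as [batch_size * num_heads, head_dim, seq_len]
past_keys = torch.zeros(batch_size * num_heads, head_dim, seq_len)

# Unfold the batch dimension so rows can be sliced and padded per request
past_keys = past_keys.view(batch_size, -1, head_dim, seq_len)
assert past_keys.shape == (2, 4, 8, 5)

# Fold it back before the tensors are fed to the next forward pass
past_keys = past_keys.view(batch_size * num_heads, head_dim, -1)
assert past_keys.shape == (8, 8, 5)
```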

View File

@@ -1,166 +1,19 @@
-import torch
-import torch.distributed
+from abc import ABC, abstractmethod
 
-from typing import List, Tuple, Optional
-
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from typing import List, Tuple, Optional, TypeVar, Type
 
 from text_generation.models.types import Batch, GeneratedText
 
+B = TypeVar("B", bound=Batch)
+
 
-class Model:
-    def __init__(self, model_name: str):
-        if torch.cuda.is_available():
-            self.device = torch.device("cuda")
-            dtype = torch.float16
-        else:
-            self.device = torch.device("cpu")
-            dtype = torch.float32
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None
-        ).eval()
-        self.num_heads = self.model.config.num_attention_heads
-
-    def forward(
-        self, input_ids, attention_mask, past_key_values: Optional = None
-    ) -> CausalLMOutputWithPast:
-        # Model Forward
-        return self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-    def generate_token(
-        self, batch: Batch
-    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            outputs = self.forward(**batch.input_ids)
-
-        # List of indices to cache
-        next_batch_keep_indices = []
-        next_batch_past_keep_indices = []
-
-        # New input_ids for next forward
-        next_batch_input_ids = []
-        next_batch_all_input_ids = []
-        next_all_input_lengths = []
-
-        next_batch_size = 0
-        next_batch_max_sequence_length = 0
-
-        # Finished requests
-        generated_texts: List[GeneratedText] = []
-
-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.all_input_lengths,
-            outputs.logits,
-            batch.next_token_choosers,
-            batch.stopping_criterias,
-            batch.all_input_ids,
-        )
-
-        # For each member of the batch
-        for i, (
-            request,
-            input_length,
-            logits,
-            next_token_chooser,
-            stopping_criteria,
-            all_tokens,
-        ) in enumerate(iterator):
-            # Select next token
-            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
-
-            # Append next token to all tokens
-            all_tokens = torch.cat([all_tokens, next_token])
-
-            # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
-                # Decode all tokens
-                output = self.tokenizer.decode(
-                    all_tokens.squeeze(-1), skip_special_tokens=True
-                )
-                # Add to the list of finished generations with the original request
-                generated_texts.append(GeneratedText(request, output))
-            # add to the next batch
-            else:
-                next_batch_keep_indices.append(i)
-                # past_key_values is of shape [batch_size * num_heads, ...]
-                # so we need to take into account the `num_heads` stride here
-                next_batch_past_keep_indices.extend(
-                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
-                )
-                next_batch_input_ids.append(next_token)
-                next_batch_all_input_ids.append(all_tokens)
-                next_batch_size += 1
-                new_input_length = input_length + 1
-                next_all_input_lengths.append(new_input_length)
-                next_batch_max_sequence_length = max(
-                    next_batch_max_sequence_length, new_input_length
-                )
-
-        # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
-            return generated_texts, None
-
-        # If we finished at least one generation
-        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
-        if generated_texts:
-            # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
-                next_batch_keep_indices
-            ]
-            next_batch_input_ids["past_key_values"] = [
-                (
-                    keys[next_batch_past_keep_indices],
-                    values[next_batch_past_keep_indices],
-                )
-                for keys, values in outputs["past_key_values"]
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
-            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
-
-        # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_input_ids["attention_mask"] = torch.cat(
-            [
-                next_batch_input_ids["attention_mask"],
-                torch.ones((next_batch_size, 1)).to(self.device),
-            ],
-            dim=1,
-        )
-
-        next_batch = Batch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            all_input_lengths=next_all_input_lengths,
-            input_ids=next_batch_input_ids,
-            all_input_ids=next_batch_all_input_ids,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_sequence_length=next_batch_max_sequence_length,
-        )
-        return generated_texts, next_batch
+class Model(ABC):
+    @property
+    @abstractmethod
+    def batch_type(self) -> Type[B]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def generate_token(
+        self, batch: B
+    ) -> Tuple[List[GeneratedText], Optional[B]]:
+        raise NotImplementedError
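The new `Model` base class reduces to two abstract hooks; a concrete backend supplies a batch type and a `generate_token` implementation, and the `BLOOM`/`BloomBatch` pair in `bloom.py` above is the real instance of this pattern. A schematic sketch of a hypothetical subclass, only to illustrate the interface; the names `MyBatch` and `MyModel` are not in the commit:

```python
# Hypothetical subclass sketch; MyBatch and MyModel are illustrative names only.
from typing import List, Optional, Tuple, Type

from text_generation.models.model import Model
from text_generation.models.types import Batch, GeneratedText


class MyBatch(Batch):
    @classmethod
    def concatenate(cls, batches: List["Batch"]) -> "MyBatch":
        ...  # model-specific padding and cache concatenation


class MyModel(Model):
    @property
    def batch_type(self) -> Type[MyBatch]:
        return MyBatch

    def generate_token(
        self, batch: MyBatch
    ) -> Tuple[List[GeneratedText], Optional[MyBatch]]:
        ...  # one forward pass, then per-request token selection
```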

View File

@@ -1,5 +1,6 @@
 import torch
 
+from abc import abstractmethod
 from dataclasses import dataclass
 from typing import List, Dict
 
@@ -70,131 +71,9 @@ class Batch:
         )
 
     @classmethod
+    @abstractmethod
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
-        # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_sequence_length = max(batch.max_sequence_length for batch in batches)
-
-        # Batch attributes
-        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
-        requests = []
-        all_input_lengths = []
-        all_input_ids = []
-        next_token_choosers = []
-        stopping_criterias = []
-
-        # Used for slicing correctly inside the tensors
-        # Equivalent to a cumsum on batch sizes
-        start_index = 0
-        for i, batch in enumerate(batches):
-            requests.extend(batch.requests)
-            all_input_lengths.extend(batch.all_input_lengths)
-            all_input_ids.extend(batch.all_input_ids)
-            next_token_choosers.extend(batch.next_token_choosers)
-            stopping_criterias.extend(batch.stopping_criterias)
-
-            # Slicing end index for this batch
-            end_index = start_index + batch.size
-
-            # We only concatenate batches that did at least one step
-            if batch.input_ids["input_ids"].shape[1] > 1:
-                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
-
-            # Initialize tensors
-            if i == 0:
-                input_ids["input_ids"] = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids["input_ids"].dtype,
-                    device=batch.input_ids["input_ids"].device,
-                )
-                input_ids["attention_mask"] = torch.zeros(
-                    (total_batch_size, max_sequence_length),
-                    dtype=batch.input_ids["attention_mask"].dtype,
-                    device=batch.input_ids["attention_mask"].device,
-                )
-
-            # input_ids["input_ids"] is always of shape [batch_size, 1]
-            # We do not need to pad it
-            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
-
-            # We need to slice the attention mask to remove padding from previous steps
-            input_ids["attention_mask"][
-                start_index:end_index, -batch.max_sequence_length :
-            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
-
-            for j, past in enumerate(batch.input_ids["past_key_values"]):
-                past_keys = past[0]
-                past_values = past[1]
-
-                _, head_dim, padded_sequence_length = past_keys.shape
-
-                # Reshape the tensors to make slicing easier
-                past_keys = past_keys.view(
-                    batch.size, -1, head_dim, padded_sequence_length
-                )
-                past_values = past_values.view(
-                    batch.size, -1, padded_sequence_length, head_dim
-                )
-                num_heads = past_keys.shape[1]
-
-                # Initialize tensors
-                # This will run only once per layer
-                if j == len(input_ids["past_key_values"]):
-                    padded_past_keys = torch.zeros(
-                        (
-                            total_batch_size,
-                            num_heads,
-                            head_dim,
-                            max_sequence_length - 1,
-                        ),
-                        dtype=past_keys.dtype,
-                        device=past_keys.device,
-                    )
-                    padded_past_values = torch.zeros(
-                        (
-                            total_batch_size,
-                            num_heads,
-                            max_sequence_length - 1,
-                            head_dim,
-                        ),
-                        dtype=past_values.dtype,
-                        device=past_values.device,
-                    )
-                    input_ids["past_key_values"].append(
-                        [padded_past_keys, padded_past_values]
-                    )
-
-                # We slice the past keys and values to remove the padding from previous batches
-                input_ids["past_key_values"][j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-
-                input_ids["past_key_values"][j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
-
-                # If we are on the last batch, we need to reshape the tensors
-                if (i + 1) == len(batches):
-                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
-                        j
-                    ][0].view(total_batch_size * num_heads, head_dim, -1)
-                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
-                        j
-                    ][1].view(total_batch_size * num_heads, -1, head_dim)
-
-            start_index += batch.size
-
-        return cls(
-            batch_id=batches[0].batch_id,
-            requests=requests,
-            all_input_lengths=all_input_lengths,
-            input_ids=input_ids,
-            all_input_ids=all_input_ids,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            size=total_batch_size,
-            max_sequence_length=max_sequence_length,
-        )
+        raise NotImplementedError
 
 
 @dataclass

View File

@@ -27,7 +27,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Generate(self, request, context):
-        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
+        batch = self.model.batch_type.from_pb(request.batch, self.model.tokenizer, self.model.device)
         generated_texts, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
@@ -51,7 +51,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batches.append(batch)
 
         if len(batches) > 1:
-            batch = Batch.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches)
         else:
            batch = batches[0]