feat(server): Support generic AutoModelForCausalLM

OlivierDehaene 2022-11-04 14:22:47 +01:00
parent 755fc0e403
commit c5665f5c8b
11 changed files with 373 additions and 333 deletions

View File

@@ -18,6 +18,7 @@ A Rust and gRPC server for text generation inference.
## Supported models
- BLOOM
+- BLOOMZ
- BLOOM-560m
## Load Tests for BLOOM

View File

@@ -63,6 +63,8 @@ message GeneratedText {
Request request = 1;
/// Output
string output = 2;
+/// Number of generated tokens
+uint32 tokens = 3;
}
message GenerateRequest {
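Downstream callers can use the new field to reason about per-token cost without re-tokenizing the output. A hedged sketch from the Python side, assuming the compiled bindings are importable as text_generation.pb.generate_pb2 (the exact import path is an assumption):

# Assumed import path for the compiled protobuf bindings; adjust if it differs.
from text_generation.pb import generate_pb2

generated = generate_pb2.GeneratedText(output="Hello world", tokens=2)
print(generated.tokens)  # 2 generated tokens, usable for per-token timing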

View File

@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
.expect("ID not found in db. This is a bug.");
let response = InferResponse {
output: output.output,
+tokens: output.tokens,
queued: entry.time,
start: entry.batch_time.unwrap(), // unwrap is always valid
end: Instant::now(),
@@ -202,6 +203,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
#[derive(Debug)]
pub(crate) struct InferResponse {
pub(crate) output: String,
+pub(crate) tokens: u32,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) end: Instant,

View File

@@ -116,7 +116,7 @@ async fn generate(
let validation_time = response.queued - start_time;
let queue_time = response.start - response.queued;
let inference_time = response.end - response.start;
-let time_per_token = inference_time / req.parameters.max_new_tokens;
+let time_per_token = inference_time / response.tokens;
// Headers
let mut headers = HeaderMap::new();
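Dividing by the number of tokens actually generated matters whenever generation stops early: for example, 300 ms of inference that produced 15 tokens before hitting an end-of-sequence token now reports 20 ms per token, whereas dividing by a max_new_tokens of 100 would have reported 3 ms.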

View File

@@ -1,7 +1,8 @@
from text_generation.models.model import Model
-from text_generation.models.bloom import BLOOM, BLOOMSharded
+from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.causal_lm import CausalLM
-__all__ = ["Model", "BLOOM", "BLOOMSharded"]
+__all__ = ["Model", "BLOOMSharded", "CausalLM"]
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@@ -11,6 +12,10 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
else:
if quantize:
raise ValueError("quantization is not supported for non-sharded BLOOM")
-return BLOOM(model_name)
+return CausalLM(model_name)
else:
-raise ValueError(f"model {model_name} is not supported yet")
+if sharded:
+raise ValueError("sharded is not supported for AutoModel")
+if quantize:
+raise ValueError("quantize is not supported for AutoModel")
+return CausalLM(model_name)

View File

@@ -1,7 +1,7 @@
import torch
import torch.distributed
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional
from accelerate import init_empty_weights
from safetensors import safe_open
@@ -11,10 +11,8 @@ from transformers.models.bloom.parallel_layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
)
-from transformers.modeling_outputs import CausalLMOutputWithPast
from text_generation.models import Model
-from text_generation.models.types import Batch, GeneratedText
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
@@ -31,322 +29,26 @@ except Exception as e:
torch.manual_seed(0)
class BloomBatch(Batch):
@classmethod
def concatenate(cls, batches: List["Batch"]) -> "BloomBatch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length:
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:]
for j, past in enumerate(batch.input_ids["past_key_values"]):
past_keys = past[0]
past_values = past[1]
_, head_dim, padded_sequence_length = past_keys.shape
# Reshape the tensors to make slicing easier
past_keys = past_keys.view(
batch.size, -1, head_dim, padded_sequence_length
)
past_values = past_values.view(
batch.size, -1, padded_sequence_length, head_dim
)
num_heads = past_keys.shape[1]
# Initialize tensors
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
padded_past_keys = torch.zeros(
(
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
),
dtype=past_keys.dtype,
device=past_keys.device,
)
padded_past_values = torch.zeros(
(
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
),
dtype=past_values.dtype,
device=past_values.device,
)
input_ids["past_key_values"].append(
[padded_past_keys, padded_past_values]
)
# We slice the past keys and values to remove the padding from previous batches
input_ids["past_key_values"][j][0][
start_index:end_index, :, :, -(batch.max_sequence_length - 1):
] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
input_ids["past_key_values"][j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1):, :
] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
# If we are on the last batch, we need to reshape the tensors
if (i + 1) == len(batches):
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
j
][0].view(total_batch_size * num_heads, head_dim, -1)
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
j
][1].view(total_batch_size * num_heads, -1, head_dim)
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
)
class BLOOM(Model):
def __init__(self, model_name: str):
if not model_name.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_name} is not supported")
if torch.cuda.is_available():
self.device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None
).eval()
self.num_heads = self.model.config.num_attention_heads
@property
def batch_type(self) -> Type[BloomBatch]:
return BloomBatch
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
) -> CausalLMOutputWithPast:
# Model Forward
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
def generate_token(
self, batch: BloomBatch
) -> Tuple[List[GeneratedText], Optional[BloomBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
outputs = self.forward(**batch.input_ids)
# List of indices to cache
next_batch_keep_indices = []
next_batch_past_keep_indices = []
# New input_ids for next forward
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0
next_batch_max_sequence_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
outputs.logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(GeneratedText(request, output))
# add to the next batch
else:
next_batch_keep_indices.append(i)
# past_key_values is of shape [batch_size * num_heads, ...]
# so we need to take into account the `num_heads` stride here
next_batch_past_keep_indices.extend(
[j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
next_batch_input_ids["past_key_values"] = [
(
keys[next_batch_past_keep_indices],
values[next_batch_past_keep_indices],
)
for keys, values in outputs["past_key_values"]
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
[
next_batch_input_ids["attention_mask"],
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = BloomBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
)
return generated_texts, next_batch
-class BLOOMSharded(BLOOM):
+class BLOOMSharded(Model):
def __init__(self, model_name: str, quantize: bool = False):
-super(Model, self).__init__()
if not model_name.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_name} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
-self.device = torch.device(f"cuda:{self.rank}")
+device = torch.device(f"cuda:{self.rank}")
dtype = torch.float16
else:
-self.device = torch.device("cpu")
+device = torch.device("cpu")
dtype = torch.float32
-self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
config = AutoConfig.from_pretrained(
model_name, slow_but_exact=False, tp_parallel=True
)
config.pad_token_id = 3
-self.num_heads = config.n_head // self.process_group.size()
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -370,12 +72,14 @@ class BLOOMSharded(BLOOM):
model,
filenames,
quantize=quantize,
-device=self.device,
+device=device,
rank=self.rank,
world_size=self.world_size,
)
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
+super(BLOOMSharded, self).__init__(tokenizer=tokenizer, num_heads=config.n_head // self.process_group.size(),
+device=device)
@staticmethod
def load_weights(
@@ -526,5 +230,4 @@ class BLOOMSharded(BLOOM):
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
-outputs.logits = logits
-return outputs
+return logits, outputs.past_key_values

View File

@ -0,0 +1,38 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional, Tuple, List
from text_generation.models import Model
class CausalLM(Model):
def __init__(self, model_name: str):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
).eval()
super(CausalLM, self).__init__(tokenizer=tokenizer, num_heads=self.model.config.num_attention_heads, device=device)
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
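A minimal usage sketch of the generic model, assuming a small checkpoint such as "gpt2" (only an illustrative example, not part of this commit) can be loaded:

import torch
from text_generation.models import CausalLM

model = CausalLM("gpt2")

# Tokenize a prompt; Batch.from_pb prepares inputs with the same left-padding tokenizer.
inputs = model.tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

logits, past_key_values = model.forward(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)

# Greedy pick of the next token from the last position's logits.
next_token_id = int(torch.argmax(logits[:, -1], dim=-1))
print(model.tokenizer.decode([next_token_id]))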

View File

@@ -1,19 +1,139 @@
+import torch
from abc import ABC, abstractmethod
-from typing import List, Tuple, Optional, TypeVar, Type
+from typing import List, Tuple, Optional
+from tokenizers import Tokenizer
from text_generation.models.types import Batch, GeneratedText
-B = TypeVar("B", bound=Batch)
class Model(ABC):
-@property
-@abstractmethod
-def batch_type(self) -> Type[B]:
-raise NotImplementedError
+def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+self.tokenizer = tokenizer
+self.num_heads = num_heads
+self.device = device
@abstractmethod
-def generate_token(
-self, batch: B
-) -> Tuple[List[GeneratedText], Optional[B]]:
+def forward(self, input_ids, attention_mask, past_key_values: Optional = None) -> Tuple[torch.Tensor, List[Tuple]]:
raise NotImplementedError
def generate_token(
self, batch: Batch
) -> Tuple[List[GeneratedText], Optional[Batch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
logits, past = self.forward(**batch.input_ids)
# List of indices to cache
next_batch_keep_indices = []
# New input_ids for next forward
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0
next_batch_max_sequence_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(GeneratedText(request, output, stopping_criteria.current_tokens))
# add to the next batch
else:
next_batch_keep_indices.append(i)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_input_ids["past_key_values"] = [
[t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices] for t in layer]
for layer in past
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
[
next_batch_input_ids["attention_mask"],
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = Batch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
)
return generated_texts, next_batch
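One detail worth calling out from generate_token: the past_key_values filtering forces every cached tensor into a [batch_size, num_heads, ...] view before indexing, because BLOOM caches keys and values as [batch_size * num_heads, ...] while most AutoModelForCausalLM architectures already use [batch_size, num_heads, ...]. A standalone sketch of that normalisation on dummy tensors (all shapes and values below are illustrative):

import torch

batch_size, num_heads, head_dim, seq_len = 3, 2, 4, 5
keep = [0, 2]  # requests 0 and 2 are still generating, request 1 finished

# BLOOM-style layout: keys are [batch_size * num_heads, head_dim, seq_len]
bloom_keys = torch.randn(batch_size * num_heads, head_dim, seq_len)
# Standard layout: keys are [batch_size, num_heads, head_dim, seq_len]
std_keys = torch.randn(batch_size, num_heads, head_dim, seq_len)

# view(-1, num_heads, *t.shape[-2:]) maps both layouts to
# [batch_size, num_heads, head_dim, seq_len], so indexing by batch position works uniformly.
for keys in (bloom_keys, std_keys):
    filtered = keys.view(-1, num_heads, *keys.shape[-2:])[keep]
    print(filtered.shape)  # torch.Size([2, 2, 4, 5])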

View File

@@ -1,6 +1,5 @@
import torch
-from abc import abstractmethod
from dataclasses import dataclass
from typing import List, Dict
@@ -51,7 +50,11 @@ class Batch:
do_sample=r.parameters.do_sample,
)
)
-stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
+stopping_criterias.append(
+StoppingCriteria(
+eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
+)
+)
input_ids = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
@@ -71,15 +74,171 @@ class Batch:
)
@classmethod
-@abstractmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":
-raise NotImplementedError
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Only needed for Seq2SeqLM
max_encoded_sequence_length = None
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length:
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length:]
for j, past in enumerate(batch.input_ids["past_key_values"]):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
head_dim, padded_sequence_length = past[0].shape[-2:]
num_heads = (
past[0]
.view(batch.size, -1, head_dim, padded_sequence_length)
.shape[1]
)
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
input_ids["past_key_values"].append([])
# Decoder past
for k, t in enumerate(past[:2]):
# Needed because BLOOM past shapes are not the same for keys and values
# Keys: [batch_size * num_heads, head_dim, seq_length]
# Values: [batch_size * num_heads, seq_length, head_dim]
head_dim_last = False
if t.shape[-2] == head_dim:
t = t.view(
batch.size, num_heads, head_dim, padded_sequence_length
)
padded_t_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
)
elif t.shape[-1] == head_dim:
head_dim_last = True
t = t.view(
batch.size, num_heads, padded_sequence_length, head_dim
)
padded_t_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
)
else:
raise ValueError(f"shape {t.shape} is not valid")
# Initialize tensors
# This will run only once per layer and per past tensor
if k == len(input_ids["past_key_values"][j]):
input_ids["past_key_values"][j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
# We slice the past keys and values to remove the padding from previous batches
if not head_dim_last:
input_ids["past_key_values"][j][k][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1):,
] = t[:, :, :, -(batch.max_sequence_length - 1):]
else:
input_ids["past_key_values"][j][k][
start_index:end_index,
:,
-(batch.max_sequence_length - 1):,
:,
] = t[:, :, -(batch.max_sequence_length - 1):, :]
# Seq2SeqLM specific past (encoder past)
for k, t in enumerate(past[2:]):
if max_encoded_sequence_length is None:
max_encoded_sequence_length = max(max(batch.all_input_lengths) for batch in batches)
batch_max_encoded_sequence_length = max(batch.all_input_lengths)
padded_t_shape = (total_batch_size, num_heads, max_encoded_sequence_length, head_dim)
idx = k + 2
# Initialize tensors
# This will run only once per layer and per past tensor
if idx == len(input_ids["past_key_values"][j]):
input_ids["past_key_values"][j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
input_ids["past_key_values"][j][idx][
start_index:end_index,
:,
-batch_max_encoded_sequence_length:,
:
] = t[:, :, -batch_max_encoded_sequence_length:, :]
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
)
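Batch.concatenate cannot rely on a fixed key/value layout either: for each cached tensor it infers whether it is a key ([..., head_dim, padded_sequence_length]) or a value ([..., padded_sequence_length, head_dim]) by checking which trailing dimension matches head_dim. A small sketch of that detection on dummy BLOOM-shaped tensors (shapes are illustrative):

import torch

batch_size, num_heads, head_dim, seq_len = 2, 4, 8, 3

keys = torch.randn(batch_size * num_heads, head_dim, seq_len)    # keys: head_dim before seq
values = torch.randn(batch_size * num_heads, seq_len, head_dim)  # values: seq before head_dim

for t in (keys, values):
    if t.shape[-2] == head_dim:
        layout = "keys"
        t = t.view(batch_size, num_heads, head_dim, -1)
    elif t.shape[-1] == head_dim:
        layout = "values"
        t = t.view(batch_size, num_heads, -1, head_dim)
    else:
        raise ValueError(f"shape {t.shape} is not valid")
    print(layout, tuple(t.shape))
# keys (2, 4, 8, 3)
# values (2, 4, 3, 8)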
@dataclass
class GeneratedText:
request: generate_pb2.Request
output: str
+tokens: int
def to_pb(self) -> generate_pb2.GeneratedText:
-return generate_pb2.GeneratedText(request=self.request, output=self.output)
+return generate_pb2.GeneratedText(request=self.request, output=self.output, tokens=self.tokens)

View File

@@ -27,7 +27,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.ClearCacheResponse()
async def Generate(self, request, context):
-batch = self.model.batch_type.from_pb(request.batch, self.model.tokenizer, self.model.device)
+batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
generated_texts, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch)
@@ -51,7 +51,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batches.append(batch)
if len(batches) > 1:
-batch = self.model.batch_type.concatenate(batches)
+batch = Batch.concatenate(batches)
else:
batch = batches[0]

View File

@@ -58,7 +58,8 @@ class NextTokenChooser:
class StoppingCriteria:
-def __init__(self, max_new_tokens=20):
+def __init__(self, eos_token_id, max_new_tokens=20):
+self.eos_token_id = eos_token_id
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
@@ -66,6 +67,8 @@ class StoppingCriteria:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True
+if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id:
+return True
return False
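With the eos_token_id argument, a request now stops either when its token budget is exhausted or when the model emits the end-of-sequence token. A brief usage sketch, assuming StoppingCriteria is imported from text_generation.utils and that the integer ids below are made up:

from text_generation.utils import StoppingCriteria

# One instance per request, called once per newly generated token.
criteria = StoppingCriteria(eos_token_id=2, max_new_tokens=20)

print(criteria([5]))        # False: budget not reached, no EOS yet
print(criteria([5, 7]))     # False
print(criteria([5, 7, 2]))  # True: the last generated id is the EOS token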
@@ -124,11 +127,18 @@ def download_weights(model_name, extension=".safetensors"):
filenames = weight_hub_files(model_name, extension)
download_function = partial(
-hf_hub_download, repo_id=model_name, local_files_only=False
+hf_hub_download,
+repo_id=model_name,
+local_files_only=False,
)
executor = ThreadPoolExecutor(max_workers=5)
-futures = [executor.submit(download_function, filename=filename) for filename in filenames]
-files = [file for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))]
+futures = [
+executor.submit(download_function, filename=filename) for filename in filenames
+]
+files = [
+file
+for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+]
return files
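For completeness, a hedged usage sketch of the reformatted download helper; the checkpoint name is only an example and the call reaches out to the Hugging Face Hub:

from text_generation.utils import download_weights

# Fetches every matching .safetensors file with up to 5 concurrent workers.
files = download_weights("bigscience/bloom-560m", extension=".safetensors")
print(len(files))  # number of weight files fetched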