add missing files

This commit is contained in:
Felix Marty 2024-06-27 13:30:40 +00:00
parent 770975fa81
commit 3760102077
2 changed files with 1146 additions and 0 deletions

View File

@ -0,0 +1,787 @@
import torch
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
@dataclass
class CausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
past_key_values: Optional[List[Tuple]]
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
# Past metadata
keys_head_dim_last: bool = True
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(concat_text_chunks(r.input_chunks.chunks))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(input_len - 5)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
new_padding_right_offset = max(
new_padding_right_offset, remaining_decode_tokens
)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
input_ids = self.input_ids[keep_indices]
position_ids = self.position_ids[keep_indices]
self.attention_mask = self.attention_mask[
keep_indices,
-(self.padding_right_offset + max_input_length) : (
self.attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) == tuple:
self.past_key_values = [list(layer) for layer in self.past_key_values]
# Update tensors in-place to allow incremental garbage collection
past_kv_length = max_input_length - 1
for layer in self.past_key_values:
past_keys, past_values = layer
if len(past_keys.shape) == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
if self.keys_head_dim_last:
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
else:
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
del past_keys
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = input_ids
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
input_ids = None
attention_mask = None
position_ids = None
past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
if input_ids is None:
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
if position_ids is None:
position_ids = batch.position_ids.new_empty((total_batch_size, 1))
position_ids[start_index:end_index] = batch.position_ids
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
if type(batch.past_key_values[0]) == tuple:
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif len(batch.past_key_values[0][0].shape) == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
start_index = end_index
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
padded_past_values_shape = (
total_batch_size,
num_heads,
max_input_length - 1,
head_dim,
)
if batches[0].keys_head_dim_last:
padded_past_keys_shape = padded_past_values_shape
else:
# seq_length is last for BLOOM
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_input_length - 1,
)
# Iterate over attention layers
# Concatenate past key values layer by layer to allow incremental garbage collection
for j in range(len(first_past_kvs)):
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
start_index = 0
for batch in batches:
past_keys = batch.past_key_values[j][0]
# Clear reference to the original tensor
batch.past_key_values[j][0] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last:
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :]
)
else:
# BLOOM case
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
past_keys[:, :, :, -past_seq_len:]
)
del past_keys
start_index = end_index
padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape
)
start_index = 0
for batch in batches:
past_values = batch.past_key_values[j][1]
# Clear reference to the original tensor
batch.past_key_values[j][1] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :]
)
del past_values
# Update values
start_index = end_index
past_key_values.append([padded_past_keys, padded_past_values])
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens,
)
def __len__(self):
return len(self.requests)
class TransformersCausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
and quantize != "bitsandbytes"
):
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[
torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
]:
# Model Forward
kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
"return_dict": True,
}
if self.has_position_ids:
kwargs["position_ids"] = position_ids
outputs = self.model.forward(**kwargs)
if isinstance(outputs, tuple):
outputs, speculative_logits = outputs
else:
speculative_logits = None
return outputs.logits, speculative_logits, outputs.past_key_values
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
logits, speculative_logits, past = self.forward(
batch.input_ids,
attention_mask,
batch.position_ids,
batch.past_key_values,
)
# Results
generations: List[Generation] = []
stopped = True
# Speculation is not active for causal
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.log_softmax(logits[:, -1], -1),
accepted_ids,
)
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
for i, (
request,
input_length,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits[-1:, :]
)
# Append next token to all tokens
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_squeezed,
next_token_text,
)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
generated_text = None
# Prefill
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + torch.log_softmax(
logits, -1
).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
[next_token_id_squeezed],
[next_token_logprob],
[next_token_text],
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Decrease right offset
batch.padding_right_offset -= 1
# Update position_ids
batch.position_ids = batch.position_ids[:, -1:] + 1
# Update past key values
batch.past_key_values = past
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -0,0 +1,359 @@
import torch
import time
import sys
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict, Any
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.utils.import_utils import (
empty_cache,
synchronize,
get_free_memory,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__)
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.layers.attention import reshape_and_cache
from transformers.cache_utils import Cache
from transformers.flash_attention_utils import _flash_supports_window_size
from flash_attn import flash_attn_varlen_func
from text_generation_server.layers.attention import paged_attention
from loguru import logger
# Why define it here?
BLOCK_SIZE: int = 16
def patch_everywhere(
attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
):
"""
Finds all occurences of `attribute_name` in the loaded modules and patches them with `patch`.
Args:
attribute_name (`str`):
The name of attribute to patch.
patch (`Any`):
The patch for the attribute.
module_name_prefix (`Optional[str]`, defaults to `None`):
If set, only module names starting with this prefix will be considered for patching.
"""
# sys.modules may be updated while being iterated over, hence the list copy.
for name in list(sys.modules):
module = sys.modules[name]
if module_name_prefix is not None and not name.startswith(module_name_prefix):
continue
if hasattr(module, attribute_name):
setattr(module, attribute_name, patch)
def _flash_attention_forward_patched(
query_states,
key_states,
value_states,
attention_mask,
query_length,
layer_idx: int,
dropout=0.0,
softmax_scale=None,
is_causal=False,
_flash_attn_uses_top_left_mask=False,
sliding_window=None,
cache_position=0,
**kwargs, #: Unpack[ExtraKwargs],
):
_flash_attn_uses_top_left_mask = True # TODO felix: fix rocm
if not _flash_attn_uses_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = is_causal and query_length != 1
print(f"causal: {causal}")
use_sliding_windows = (
_flash_supports_window_size
and sliding_window is not None
and cache_position > sliding_window
)
flash_kwargs = (
{"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
)
print(f"kwargs {kwargs.keys()}")
cu_seqlen_prefill = kwargs.get("cu_seqlen_prefill")
max_seq_lens = kwargs.get("max_seq_lens")
if cu_seqlen_prefill is not None:
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlen_prefill,
cu_seqlens_k=cu_seqlen_prefill,
max_seqlen_q=kwargs["max_s"],
max_seqlen_k=kwargs["max_s"],
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
# **kwargs,
**flash_kwargs,
)
else:
attn_output = torch.empty_like(query_states)
paged_attention(
attn_output,
query_states,
kwargs["kv_cache"][layer_idx][0],
kwargs["kv_cache"][layer_idx][1],
kwargs["kv_head_mapping"],
softmax_scale,
kwargs["block_tables"],
kwargs["input_lengths"],
kwargs["max_s"],
)
attn_output = attn_output.view(attn_output.shape[0], -1)
return attn_output
class PagedCache(Cache):
def __init__(self) -> None:
pass
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
kv_cache = cache_kwargs["kv_cache"]
reshape_and_cache(
key_states,
value_states,
kv_cache[layer_idx][0],
kv_cache[layer_idx][1],
cache_kwargs["slots"],
)
if cache_kwargs["cu_seqlen_prefill"] is not None:
return key_states, value_states
else:
return None, None
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
raise ValueError(
"PagedCache.get_seq_length should never be called, please open an issue."
)
def get_max_length(self) -> Optional[int]:
raise ValueError(
"PagedCache.get_max_length should never be called, please open an issue."
)
class TransformersFlashCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda:0") # TODO felix: fix support for accelerate
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
attn_implementation="flash_attention_2",
)
if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
and quantize != "bitsandbytes"
):
model = model.cuda()
self.kv_cache = []
# TODO felix: make this more general.
self.num_layers = len(model.model.layers)
self.num_kv_heads = model.config.num_key_value_heads
self.head_size = model.config.hidden_size // model.config.num_attention_heads
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# Skip FlashCausalLM init.
super(FlashCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
empty_cache()
patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
try:
self.init_kv_cache(
batch.num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
max_bt = batch.max_blocks
max_s = max_bt * BLOCK_SIZE
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
synchronize(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks
)
del batch
self.init_kv_cache(
num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
return int(num_blocks * BLOCK_SIZE)
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# NOTE: adapter_data: not supported
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
# TODO felix: support window attention
# if cu_seqlen_prefill is None and self.max_past() is not None:
# # In decode, not prefill, we're actually overwriting the KV-cache
# # in a circular buffer mode.
# # This makes sure the max_s for the decode pass is correct.
# max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
past_key_values=PagedCache(),
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cache_position=False,
return_dict=False,
)[0]
if lm_head_indices is not None:
logits = logits[lm_head_indices]
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
speculative_logits = None
return logits, speculative_logits