Working flash + paged attention through transformers

Felix Marty 2024-06-27 12:39:36 +00:00
parent be2d38032a
commit cb37c551ab
4 changed files with 679 additions and 6 deletions

View File

@@ -12,6 +12,7 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.causal_lm_ragged import CausalLMRagged
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
@@ -588,7 +589,7 @@ def get_model(
)
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and False:
return FlashLlama(
model_id,
revision,
@@ -601,7 +602,8 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else:
-return CausalLM(
+logger.info("Loading CausalLMRagged")
+return CausalLMRagged(
model_id,
revision,
quantize=quantize,

View File

@@ -0,0 +1,630 @@
import torch
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.utils.import_utils import (
empty_cache,
synchronize,
get_free_memory,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils.dist import MEMORY_FRACTION, RANK  # RANK is needed by the speculative-decoding debug log below
tracer = trace.get_tracer(__name__)
from transformers.cache_utils import PagedCache
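# NOTE: PagedCache is not part of a released transformers version; it comes from the experimental paged-attention branch this commit is developed against.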
from loguru import logger
# Block size of the paged KV cache; must match the block size assumed by the attention kernels. TODO: import it from a shared location instead of redefining it here.
BLOCK_SIZE: int = 16
class CausalLMRagged(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if speculator:
raise RuntimeError("Speculative decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda:0") # TODO felix: fix support for accelerate
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
attn_implementation="flash_attention_2",
)
if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
and quantize != "bitsandbytes"
):
model = model.cuda()
self.kv_cache = []
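# Per-layer (key, value) cache block tensors; allocated in init_kv_cache() during warmup once the block budget is known.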
self.num_layers = len(model.model.layers)
self.num_kv_heads = model.config.num_key_value_heads
self.head_size = model.config.hidden_size // model.config.num_attention_heads
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
empty_cache()
try:
self.init_kv_cache(
batch.num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
max_bt = batch.max_blocks
max_s = max_bt * BLOCK_SIZE
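# max_bt / max_s are not used below; presumably kept for parity with FlashCausalLM.warmup, where they size CUDA graph warmup.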
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
synchronize(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
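# Bytes needed per block across all layers: 2 tensors (key and value) of BLOCK_SIZE tokens x num_kv_heads x head_size elements, times the dtype size.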
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks
)
del batch
self.init_kv_cache(
num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
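# Report the total number of KV cache token slots that fit in memory; the router uses this to bound the total tokens per batch.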
return int(num_blocks * BLOCK_SIZE)
def init_kv_cache(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.kv_cache = []
empty_cache()
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
raise ValueError("Untested. Please open an issue")
else:
x = BLOCK_SIZE // element_size
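# vLLM-style key cache layout: the trailing dimension packs x consecutive key elements so the paged attention kernel can load them with vectorized accesses.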
if SYSTEM == "ipex" and device == torch.device("cpu"):
raise ValueError("Untested. Please open an issue")
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward(
self, batch: FlashCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# NOTE: adapter_data: not supported
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
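# During prefill, prefill_head_indices selects only the positions whose logits are actually needed (typically each sequence's last prompt token).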
# TODO felix: support window attention
# if cu_seqlen_prefill is None and self.max_past() is not None:
# # In decode, not prefill, we're actually overwriting the KV-cache
# # in a circular buffer mode.
# # This makes sure the max_s for the decode pass is correct.
# max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
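# The kwargs below (cu_seqlen_prefill, kv_cache, block_tables, slots, input_lengths, max_s, ...) are only understood by the patched model classes on the paged-attention transformers branch; a stock transformers model would reject them.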
logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
past_key_values=PagedCache(),
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cache_position=False,
return_dict=False,
)[0]
if lm_head_indices is not None:
logits = logits[lm_head_indices]
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
speculative_logits = None
return logits, speculative_logits
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
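# prefill is True only on the first generate_token() call for a batch; subsequent calls decode a single new token per sequence.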
# Update adapter indices for speculative tokens (if present)
# adapter_meta = batch.adapter_meta
# if batch.speculative_ids is not None:
# B, speculative_length = batch.speculative_ids.shape
# new_length = speculative_length + 1
# adapter_indices = (
# adapter_meta.adapter_indices.unsqueeze(-1)
# .expand(B, new_length)
# .reshape(-1)
# )
# adapter_segments = adapter_meta.adapter_segments * new_length
# adapter_meta = AdapterBatchMetadata(
# adapter_indices=adapter_indices,
# adapter_set=adapter_meta.adapter_set,
# adapter_segments=adapter_segments,
# segment_indices=adapter_meta.segment_indices,
# )
# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
# adapter_data = AdapterBatchData.from_meta(
# adapter_meta,
# self.layer_to_adapter_weights,
# prefill,
# batch.prefill_head_indices,
# )
logger.info(f"batch.input_ids {batch.input_ids}")
out, speculative_logits = self.forward(batch)
logger.info(f"out {out.shape}")
logger.info(f"speculative_logits {speculative_logits}")
if prefill:
next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
if speculative_logits is not None:
speculative_logits = (
speculative_logits[batch.prefill_next_token_indices]
if prefill_logprobs
else speculative_logits
)
# next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
# len(batch)
# )
else:
next_token_logits = out
# next_adapter_indices = batch.adapter_meta.adapter_indices
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
logprobs,
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen],
next_token_logits,
speculate,
batch.speculative_ids,
speculative_logits,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
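# Keep only the slot index of each sequence's last prefill token: decode writes exactly one new token per sequence.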
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
else:
prefill_logprobs = None
next_position_ids = batch.position_ids
# Cumulative length
cumulative_length = 0
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
# one, we need to first do a GPU <-> CPU sync
# It is faster if we delay this sync for the maximum amount of time
# For each member of the batch
index = 0
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill:
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
# next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
# end_index - 1
# ]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
if len(batch) > 1:
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
batch.input_ids[start_index + 1 : start_index + out_length]
)
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : start_index + out_length
]
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
cumulative_length += input_length
logger.info(f"batch.input_lengths_tensor {batch.input_lengths_tensor}")
logger.info(f"accepted_ids {accepted_ids}")
logger.info(f"batch.all_input_ids {batch.all_input_ids}")
# Update values
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
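# Keep the last accepted token of each request as its next decode input.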
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
# batch.adapter_meta.adapter_indices = None
# if prefill:
# # adjust segment lengths to account for all request lengths being 1 during decoding
# adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
# batch.adapter_meta.adapter_segments = torch.tensor(
# adapter_segments,
# dtype=torch.int32,
# device=batch.adapter_meta.adapter_segments.device,
# )
if prefill and prefill_logprobs:
# Get prefill logprobs
prefill_logprobs_tensor = torch.log_softmax(out, -1)
prefill_logprobs = torch.gather(
prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
)
# GPU <-> CPU sync
prefill_logprobs = prefill_logprobs.view(-1).tolist()
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
batch.stopping_criterias,
batch.all_input_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
index = 0
for i, (
request,
input_length,
prefix_offset,
read_offset,
stopping_criteria,
all_input_ids,
do_sample,
seed,
top_n_tokens,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
next_token_texts = []
left = 0
if n_accepted_ids > 1:
if RANK == 0:
logger.debug(f"Speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
index += n_accepted_ids
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
# Prefill
if prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
out_start_index : out_end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
)
# Update values
batch.input_lengths[i] = input_length + n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
if stopped:
# No need to return a batch if we know that all requests stopped
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@@ -111,6 +111,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix: str,
config,
weights,
+layer_idx,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -143,6 +144,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
+self.layer_idx = layer_idx
o_proj = TensorParallelRowLinear.load(
config,
@@ -163,6 +165,8 @@ class FlashLlamaAttention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
+self.step = 0
def forward(
self,
hidden_states,
@@ -194,6 +198,18 @@ class FlashLlamaAttention(torch.nn.Module):
# output tensor
attn_output = torch.empty_like(query)
+if self.layer_idx < 4:
+torch.save(query, f"query_states_step{self.step}_layer{self.layer_idx}.pt")
+if cu_seqlen_prefill is not None:
+torch.save(
+torch.select(kv, dim=1, index=0),
+f"key_states_step{self.step}_layer{self.layer_idx}.pt",
+)
+torch.save(
+torch.select(kv, dim=1, index=1),
+f"value_states_step{self.step}_layer{self.layer_idx}.pt",
+)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
@@ -220,9 +236,14 @@
max_s,
)
-return self.o_proj(
-attn_output.view(-1, self.num_heads * self.head_size), adapter_data
-)
+attn_output = attn_output.view(-1, self.num_heads * self.head_size)
+if self.layer_idx < 4:
+torch.save(
+attn_output, f"attn_output_step{self.step}_layer{self.layer_idx}.pt"
+)
+self.step += 1
+return self.o_proj(attn_output, adapter_data)
class LlamaMLP(nn.Module):
@@ -299,6 +320,7 @@ class LlamaMLP(nn.Module):
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
+and False
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
and not self.quantize
@@ -320,13 +342,14 @@
class FlashLlamaLayer(nn.Module):
-def __init__(self, index, prefix, config, weights):
+def __init__(self, index, prefix, config, weights, layer_idx):
super().__init__()
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_idx=layer_idx,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -399,6 +422,7 @@ class FlashLlamaModel(torch.nn.Module):
),
config=config,
weights=weights,
+layer_idx=layer_id,
)
for layer_id in range(config.num_hidden_layers)
]

View File

@@ -1149,6 +1149,23 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
+logger.info(f"input_ids {input_ids} {input_ids.shape}")
+logger.info(f"position_ids {position_ids} {position_ids.shape}")
+logger.info(
+f"cu_seqlen_prefill {cu_seqlen_prefill} {cu_seqlen_prefill.shape if cu_seqlen_prefill is not None else 'NONE'}"
+)
+logger.info(
+f"kv_cache {type(kv_cache)}, len={len(kv_cache)}, {len(kv_cache[0])}, shape={kv_cache[0][0].shape}"
+)
+logger.info(
+f"block_tables {type(block_tables)} {block_tables.shape} {block_tables}"
+)
+logger.info(f"slots {type(slots)} {slots.shape} {slots}")
+logger.info(f"input_lengths {input_lengths}")
+logger.info(f"max_s {max_s}")
+logger.info(f"prefill_cache_indices {batch.prefill_cache_indices}")
+logger.info(f"lm_head_indices {lm_head_indices}")
+logger.info(f"adapter_data {adapter_data}")
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,