[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding

Conditional flashdecoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase..

Less intrusive.

REvert changes in modeling.

Speedup flashdecoding.

HHachweew
Hack to make other models work.

Fixing non flash decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.

* Making it work on non flash decoding.

* Fix Cohere.

* Fix non decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk for better abstracting over every model.

* Fixing non flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).

* Fixup mistral clamping (had issues with cuda graphs).

* No need to recreate anything actually.
This commit is contained in:
Nicolas Patry 2024-07-01 23:28:00 +02:00 committed by GitHub
parent 4f55f15840
commit 4327210e6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 223 additions and 75 deletions

View File

@ -39,7 +39,14 @@ impl SchedulerV2 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let queue = Queue::new(requires_padding, 16, window_size, speculate);
// Infer shared state
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic

View File

@ -39,9 +39,15 @@ impl SchedulerV3 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
16,
block_size,
window_size,
speculate,
max_batch_total_tokens,

View File

@ -1,6 +1,8 @@
from text_generation_server.utils.import_utils import SYSTEM
import os
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":

View File

@ -0,0 +1,44 @@
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional
if FLASH_DECODING:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
def __init__(self, input_lengths):
self.input_lengths = input_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
def clamp(self, max):
return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@ -1,5 +1,7 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
@ -21,7 +23,14 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
@ -32,7 +41,7 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -53,7 +62,8 @@ def paged_attention(
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
@ -62,9 +72,45 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
out2 = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
block_tables,
None,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
False, # return softmax
None, # generator
)
return out2[0]
else:
input_lengths = seqlen.input_lengths
from vllm._C import ops
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
@ -114,6 +160,7 @@ def paged_attention(
"auto",
1.0,
)
return out
try:

View File

@ -55,7 +55,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
@ -66,7 +67,7 @@ def paged_attention(
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
cu_seqlen_q,
BLOCK_SIZE,
max_s,
None,

View File

@ -1,6 +1,7 @@
import os
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from loguru import logger
major, minor = torch.cuda.get_device_capability()
@ -26,7 +27,14 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
@ -37,7 +45,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -61,6 +70,7 @@ def paged_attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
@ -119,6 +129,7 @@ def paged_attention(
"auto",
1.0,
)
return out
if ENGINE != "triton":

View File

@ -12,7 +12,6 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded
@ -92,6 +92,7 @@ except ImportError as e:
FLASH_ATTENTION = False
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)

View File

@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module):
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
slots,
max_s,
):
qkv = self.query_key_value(hidden_states)
@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module):
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,

View File

@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
Seqlen,
paged_attention,
attention,
reshape_and_cache,
@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
hidden_states = self.model(
input_ids,

View File

@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
qkv[:, 0],
kv_cache[0],

View File

@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -30,10 +30,13 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
MEM_POOL,
FLASH_DECODING,
BLOCK_SIZE,
CUDA_GRAPHS,
get_adapter_to_index,
MODEL_ID,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@ -46,7 +49,6 @@ from text_generation_server.utils.import_utils import (
tracer = trace.get_tracer(__name__)
BLOCK_SIZE: int = 16
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
@ -856,7 +858,23 @@ class FlashCausalLM(Model):
else:
x = BLOCK_SIZE // element_size
if SYSTEM == "ipex" and device == torch.device("cpu"):
if FLASH_DECODING:
self.kv_cache = [
(
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
elif SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = [
(
torch.empty(
@ -908,6 +926,7 @@ class FlashCausalLM(Model):
"slots": slots,
"input_lengths": input_lengths,
}
input_lengths = Seqlen(input_lengths=input_lengths)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
@ -1067,6 +1086,7 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
input_lengths = Seqlen(input_lengths=input_lengths)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@ -1153,6 +1173,7 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,

View File

@ -5,6 +5,12 @@ from typing import Dict
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
logger.info("Using FLASH_DECODING")
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
try:
@ -15,8 +21,6 @@ if cuda_graphs is not None:
)
else:
cuda_graphs = None
# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None: