[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding

Conditional flashdecoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase.

Less intrusive.

Revert changes in modeling.

Speedup flashdecoding.

Hack to make other models work.

Fixing non flash decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.

* Making it work on non flash decoding.

* Fix Cohere.

* Fix non decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk for better abstracting over every model.

* Fixing non flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non-functional in Flash Decoding mode though).

* Fixup mistral clamping (had issues with cuda graphs).

* No need to recreate anything actually.
Nicolas Patry 2024-07-01 23:28:00 +02:00 committed by GitHub
parent 4f55f15840
commit 4327210e6b
24 changed files with 223 additions and 75 deletions

View File

@@ -39,7 +39,14 @@ impl SchedulerV2 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
+        // Infer shared state
+        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
+            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding { 256 } else { 16 };
+        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
 
         // Spawn batching background task that contains all the inference logic

View File

@@ -39,9 +39,15 @@ impl SchedulerV3 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
+        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
+            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding { 256 } else { 16 };
         let queue = Queue::new(
             requires_padding,
-            16,
+            block_size,
             window_size,
             speculate,
             max_batch_total_tokens,

View File

@@ -1,6 +1,8 @@
 from text_generation_server.utils.import_utils import SYSTEM
 import os
 
+from .common import Seqlen
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":

View File

@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+from text_generation_server.models.globals import FLASH_DECODING
+import torch
+from typing import Optional
+
+
+if FLASH_DECODING:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+        cu_seqlen_q: Optional[torch.Tensor]
+        cu_seqlen_k: Optional[torch.Tensor]
+
+        def __init__(self, input_lengths):
+            self.input_lengths = input_lengths
+            device = self.input_lengths.device
+            shape = self.input_lengths.shape
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cuda graphs don't like this and this is necessary to clamp within mistral
+            # Although FA2 might not want the clamping
+            # cu_seqlen_k[0] = 0
+            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+
+            self.cu_seqlen_q = cu_seqlen_q
+            self.cu_seqlen_k = cu_seqlen_k
+
+        def clamp(self, max):
+            # Flash decoding doesn't need to clamp
+            return self
+
+else:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
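To make the two Seqlen variants above concrete, here is a minimal sketch of what the FLASH_DECODING constructor computes, using toy lengths (only torch is needed; the values are illustrative, not taken from the PR):

import torch

# Hypothetical decode batch: three sequences with 5, 3 and 7 cached tokens.
input_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)

# One query token per sequence during decode, so cu_seqlen_q is just 0..batch_size.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# cu_seqlen_k is the running total of key lengths, with a leading zero.
cu_seqlen_k = torch.zeros(input_lengths.shape[-1] + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 5, 8, 15]

These are the cumulative-offset arrays that the flash_attn_2_cuda.varlen_fwd decode path shown further down consumes.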

View File

@@ -1,5 +1,7 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.layers.attention import Seqlen
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -21,7 +23,14 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    if FLASH_DECODING:
+        shape = key_cache.shape
+        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+    else:
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
 
 
 def paged_attention(
@@ -32,7 +41,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    seqlen: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -53,7 +62,8 @@ def paged_attention(
     #
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = value_cache.shape[3]
+    # block_size = value_cache.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
@@ -62,58 +72,95 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
-
-    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
-    if use_v1:
-        ops.paged_attention_v1(
-            out,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
-        )
-    else:
-        # Run PagedAttention V2.
-        assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions, head_size),
-            dtype=out.dtype,
-            device=out.device,
-        )
-        exp_sums = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions),
-            dtype=torch.float32,
-            device=out.device,
-        )
-        max_logits = torch.empty_like(exp_sums)
-
-        ops.paged_attention_v2(
-            out,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
-        )
+    if FLASH_DECODING:
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        # TODO fixme when flash contains the fix.
+        # Number of splits is not correctly handled
+        # by the current path
+        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
+        # This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
+        out2 = flash_attn_2_cuda.varlen_fwd(
+            query,
+            key_cache,
+            value_cache,
+            None,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            False,  # return softmax
+            None,  # generator
+        )
+        return out2[0]
+    else:
+        input_lengths = seqlen.input_lengths
+        from vllm._C import ops
+
+        use_v1 = max_s <= 8192 and (
+            max_num_partitions == 1 or num_seqs * num_heads > 512
+        )
+        if use_v1:
+            ops.paged_attention_v1(
+                out,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=out.dtype,
+                device=out.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=out.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+
+            ops.paged_attention_v2(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        return out
 
 
 try:
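The FLASH_DECODING branch of reshape_and_cache above relies on the flatter cache layout [num_blocks, BLOCK_SIZE, num_heads, head_size] that FlashCausalLM allocates later in this diff: viewing the cache as a flat slot table means slot s lands in block s // BLOCK_SIZE at offset s % BLOCK_SIZE. A minimal CPU sketch with made-up sizes (not the real defaults):

import torch

num_blocks, block_size, num_heads, head_size = 4, 8, 2, 16
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

key = torch.randn(3, num_heads, head_size)  # three new decode tokens
slots = torch.tensor([0, 9, 31])            # flat slot indices assigned upstream

shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key

assert torch.equal(key_cache[0, 0], key[0])  # slot 0  -> block 0, offset 0
assert torch.equal(key_cache[1, 1], key[1])  # slot 9  -> block 1, offset 1
assert torch.equal(key_cache[3, 7], key[2])  # slot 31 -> block 3, offset 7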

View File

@@ -55,7 +55,8 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    cu_seqlen_q: torch.Tensor,
+    cu_seqlen_k: torch.Tensor,
     max_s: int,
 ):
     return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
@@ -66,7 +67,7 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        input_lengths,
+        cu_seqlen_q,
         BLOCK_SIZE,
         max_s,
         None,

View File

@@ -1,6 +1,7 @@
 import os
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import FLASH_DECODING
 from loguru import logger
 
 major, minor = torch.cuda.get_device_capability()
@@ -26,7 +27,14 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    if FLASH_DECODING:
+        shape = key_cache.shape
+        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+    else:
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
 
 
 def paged_attention(
@@ -37,7 +45,8 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: torch.Tensor,
+    cu_seqlen_q: torch.Tensor,
+    cu_seqlen_k: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -61,6 +70,7 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    input_lengths = cu_seqlen_k
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -119,6 +129,7 @@ def paged_attention(
            "auto",
            1.0,
        )
+    return out
 
 
 if ENGINE != "triton":

View File

@@ -12,7 +12,6 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = True
 
 try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -92,6 +92,7 @@ except ImportError as e:
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
+    __all__.append(FlashCausalLM)
     __all__.append(FlashGPT2)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)

View File

@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
+        slots,
         input_lengths,
-        slots,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module):
         )
 
         residual = None
+
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,

View File

@@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
+from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
+    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
@@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
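The one-line switch from torch.clamp(input_lengths, ...) to input_lengths.clamp(...) matters because, with this change, input_lengths reaches the Mi{s,x}tral forward wrapped in the Seqlen class introduced earlier, whose clamp is a no-op under FLASH_DECODING and a real clamp otherwise. A small sketch with simplified stand-in classes and toy values (the class names here are invented for illustration):

import torch

class SeqlenFlash:  # stand-in for the FLASH_DECODING Seqlen: clamp is a no-op
    def __init__(self, input_lengths):
        self.input_lengths = input_lengths

    def clamp(self, max):
        return self

class SeqlenPaged:  # stand-in for the paged-attention Seqlen: real clamp
    def __init__(self, input_lengths):
        self.input_lengths = input_lengths

    def clamp(self, max):
        return SeqlenPaged(torch.clamp(self.input_lengths, max=max))

lengths = torch.tensor([10000, 3])
print(SeqlenFlash(lengths).clamp(max=4096).input_lengths.tolist())  # [10000, 3]
print(SeqlenPaged(lengths).clamp(max=4096).input_lengths.tolist())  # [4096, 3]

Dispatching through the object rather than calling torch.clamp on a raw tensor appears to be what the "Fixup mistral clamping (had issues with cuda graphs)" commit above refers to.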

View File

@@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,

View File

@@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],

View File

@@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

View File

@@ -30,10 +30,13 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import (
     MEM_POOL,
+    FLASH_DECODING,
+    BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
     MODEL_ID,
 )
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -46,7 +49,6 @@ from text_generation_server.utils.import_utils import (
 
 tracer = trace.get_tracer(__name__)
 
-BLOCK_SIZE: int = 16
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 
@@ -856,7 +858,23 @@ class FlashCausalLM(Model):
        else:
            x = BLOCK_SIZE // element_size
 
-        if SYSTEM == "ipex" and device == torch.device("cpu"):
+        if FLASH_DECODING:
+            self.kv_cache = [
+                (
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                    torch.empty(
+                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
+                        dtype=dtype,
+                        device=device,
+                    ),
+                )
+                for _ in range(num_layers)
+            ]
+        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            self.kv_cache = [
                (
                    torch.empty(
@@ -908,6 +926,7 @@ class FlashCausalLM(Model):
            "slots": slots,
            "input_lengths": input_lengths,
        }
+        input_lengths = Seqlen(input_lengths=input_lengths)
 
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph
@@ -1067,6 +1086,7 @@ class FlashCausalLM(Model):
            # Dummy value, some models (starcoder2) don't accept `None`.
            input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+            input_lengths = Seqlen(input_lengths=input_lengths)
 
            # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
            self.model.forward(
@@ -1153,6 +1173,7 @@ class FlashCausalLM(Model):
            cuda_graph = None
 
        if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = Seqlen(input_lengths=input_lengths)
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,

View File

@@ -5,6 +5,12 @@ from typing import Dict
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
+FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+if FLASH_DECODING:
+    logger.info("Using FLASH_DECODING")
+
+
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
@@ -15,8 +21,6 @@ if cuda_graphs is not None:
         )
 else:
     cuda_graphs = None
-
-
 # sorting the cuda graphs in descending order helps reduce the
 # memory impact and results in less memory usage
 if cuda_graphs is not None:
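For reference, the environment handling above (and its Rust counterpart in the two scheduler hunks at the top of this diff) boils down to the following; resolve_block_size is an invented helper for illustration, while the real module reads os.environ at import time:

import os

def resolve_block_size(env) -> int:
    # Mirrors FLASH_DECODING / BLOCK_SIZE above: the Python side accepts the
    # literal values "1", "true" or "True"; the Rust router lowercases first,
    # so it matches "1" or "true" in any casing.
    flash_decoding = env.get("FLASH_DECODING") in {"1", "true", "True"}
    return 256 if flash_decoding else 16

assert resolve_block_size({}) == 16
assert resolve_block_size({"FLASH_DECODING": "1"}) == 256
assert resolve_block_size(os.environ) in {16, 256}

Both sides need to agree: the router hands this block size to Queue::new for scheduling, and the server uses BLOCK_SIZE to shape the KV cache, so FLASH_DECODING has to be set (or left unset) consistently for the router and the shards.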