[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)
* Using flash decoding Conditional flashdecoding. Fix max_q. Working kvcache Working version with flash decoding. Make it work for mistral. Fix after rebase.. Less intrusive. REvert changes in modeling. Speedup flashdecoding. HHachweew Hack to make other models work. Fixing non flash decoding llama path. Router logic knows about page size. Missing 2 models. Missing cohere. Fixing cohere flash decoding. Revamped all this architecture. Fix cohere. Fixing falcon. Enabling custom block size schedule. Update router/src/infer.rs Not sending preallocated output. * Making it work on non flash decoding. * Fix Cohere. * Fix non decoding paths. * Rebased. * No need for cache_manager anymore. * Update? * "ipex" -> "cpu" * These do not belong. * Factoring cu_seqlen_qk for better abstracting over every model. * Fixing non flash tests/imports. * Changing return everywhere. * Update mistral past. * Fixing Mi{s,x}tral (non functional in Flash Decoding mode though). * Fixup mistral clamping (had issues with cuda graphs). * No need to recreate anything actually.
This commit is contained in:
parent
4f55f15840
commit
4327210e6b
|
@ -39,7 +39,14 @@ impl SchedulerV2 {
|
|||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||
// Infer shared state
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
|
|
|
@ -39,9 +39,15 @@ impl SchedulerV3 {
|
|||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
import os
|
||||
|
||||
from .common import Seqlen
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
from dataclasses import dataclass
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
if FLASH_DECODING:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
cu_seqlen_q: Optional[torch.Tensor]
|
||||
cu_seqlen_k: Optional[torch.Tensor]
|
||||
|
||||
def __init__(self, input_lengths):
|
||||
self.input_lengths = input_lengths
|
||||
device = self.input_lengths.device
|
||||
shape = self.input_lengths.shape
|
||||
cu_seqlen_q = torch.arange(
|
||||
shape[0] + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
|
||||
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||
# Although FA2 might not want the clamping
|
||||
# cu_seqlen_k[0] = 0
|
||||
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
|
||||
|
||||
self.cu_seqlen_q = cu_seqlen_q
|
||||
self.cu_seqlen_k = cu_seqlen_k
|
||||
|
||||
def clamp(self, max):
|
||||
# Flash decoding doesn't need to clamp
|
||||
return self
|
||||
|
||||
else:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
|
||||
def clamp(self, max):
|
||||
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
|
@ -1,5 +1,7 @@
|
|||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
|
@ -21,7 +23,14 @@ def reshape_and_cache(
|
|||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||
if FLASH_DECODING:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
|
@ -32,7 +41,7 @@ def paged_attention(
|
|||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
|
@ -53,7 +62,8 @@ def paged_attention(
|
|||
#
|
||||
|
||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# block_size = value_cache.shape[3]
|
||||
block_size = BLOCK_SIZE
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
|
||||
|
@ -62,9 +72,45 @@ def paged_attention(
|
|||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
if FLASH_DECODING:
|
||||
max_q = 1
|
||||
max_k = max_s
|
||||
import flash_attn_2_cuda
|
||||
|
||||
# TODO fixme when flash contains the fix.
|
||||
# Number of splits is not correctly handled
|
||||
# by the current path
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
|
||||
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
||||
out2 = flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
None,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
None,
|
||||
block_tables,
|
||||
None,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0, # dropout
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
True, # causal
|
||||
-1, # Window_left
|
||||
-1, # Window right
|
||||
False, # return softmax
|
||||
None, # generator
|
||||
)
|
||||
return out2[0]
|
||||
else:
|
||||
input_lengths = seqlen.input_lengths
|
||||
from vllm._C import ops
|
||||
|
||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
use_v1 = max_s <= 8192 and (
|
||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||
)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
|
@ -114,6 +160,7 @@ def paged_attention(
|
|||
"auto",
|
||||
1.0,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
try:
|
||||
|
|
|
@ -55,7 +55,8 @@ def paged_attention(
|
|||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
cu_seqlen_q: torch.Tensor,
|
||||
cu_seqlen_k: torch.Tensor,
|
||||
max_s: int,
|
||||
):
|
||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
|
@ -66,7 +67,7 @@ def paged_attention(
|
|||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
cu_seqlen_q,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from loguru import logger
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
|
@ -26,7 +27,14 @@ def reshape_and_cache(
|
|||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||
if FLASH_DECODING:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
|
@ -37,7 +45,8 @@ def paged_attention(
|
|||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
cu_seqlen_q: torch.Tensor,
|
||||
cu_seqlen_k: torch.Tensor,
|
||||
max_s: int,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
|
@ -61,6 +70,7 @@ def paged_attention(
|
|||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
input_lengths = cu_seqlen_k
|
||||
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
|
@ -119,6 +129,7 @@ def paged_attention(
|
|||
"auto",
|
||||
1.0,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
if ENGINE != "triton":
|
||||
|
|
|
@ -12,7 +12,6 @@ from pathlib import Path
|
|||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
from text_generation_server.models.model import Model
|
||||
from text_generation_server.models.causal_lm import CausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.bloom import BLOOMSharded
|
||||
from text_generation_server.models.mpt import MPTSharded
|
||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||
|
@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
|||
FLASH_ATTENTION = True
|
||||
|
||||
try:
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
|
@ -92,6 +92,7 @@ except ImportError as e:
|
|||
FLASH_ATTENTION = False
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashCausalLM)
|
||||
__all__.append(FlashGPT2)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRWSharded)
|
||||
|
|
|
@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
|
|||
attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
|
@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
slots,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module):
|
|||
)
|
||||
|
||||
residual = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
|
|
|
@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
|
|||
attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
|
@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
|
|||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
|
@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
|
|
|
@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
|||
elif self.max_past is not None:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
input_lengths = input_lengths.clamp(max=self.max_past_tensor)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
|
|
@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
qkv[:, 0],
|
||||
kv_cache[0],
|
||||
|
|
|
@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention(
|
||||
attn_output = paged_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
|
|
|
@ -30,10 +30,13 @@ from text_generation_server.models.types import (
|
|||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.globals import (
|
||||
MEM_POOL,
|
||||
FLASH_DECODING,
|
||||
BLOCK_SIZE,
|
||||
CUDA_GRAPHS,
|
||||
get_adapter_to_index,
|
||||
MODEL_ID,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
|
||||
|
@ -46,7 +49,6 @@ from text_generation_server.utils.import_utils import (
|
|||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
|
@ -856,7 +858,23 @@ class FlashCausalLM(Model):
|
|||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
if FLASH_DECODING:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
|
@ -908,6 +926,7 @@ class FlashCausalLM(Model):
|
|||
"slots": slots,
|
||||
"input_lengths": input_lengths,
|
||||
}
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs]["graph"] = graph
|
||||
|
||||
|
@ -1067,6 +1086,7 @@ class FlashCausalLM(Model):
|
|||
|
||||
# Dummy value, some models (starcoder2) don't accept `None`.
|
||||
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
|
@ -1153,6 +1173,7 @@ class FlashCausalLM(Model):
|
|||
cuda_graph = None
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
|
|
@ -5,6 +5,12 @@ from typing import Dict
|
|||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
if FLASH_DECODING:
|
||||
logger.info("Using FLASH_DECODING")
|
||||
|
||||
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None:
|
||||
try:
|
||||
|
@ -15,8 +21,6 @@ if cuda_graphs is not None:
|
|||
)
|
||||
else:
|
||||
cuda_graphs = None
|
||||
|
||||
|
||||
# sorting the cuda graphs in descending order helps reduce the
|
||||
# memory impact and results in less memory usage
|
||||
if cuda_graphs is not None:
|
||||
|
|
Loading…
Reference in New Issue