Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385)

* Using an enum for flash backens (paged/flashdecoding/flashinfer)

* Early exit on server too.

* Clippy.

* Fix clippy and fmt.
This commit is contained in:
Nicolas Patry 2024-08-09 16:41:17 +02:00 committed by GitHub
parent 6e127dcc96
commit 7a48a84784
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 78 additions and 38 deletions

1
.gitignore vendored
View File

@ -18,3 +18,4 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
data/ data/
load_tests/*.json load_tests/*.json
server/fbgemmm

View File

@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest; use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token}; use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify}; use tokio::sync::{mpsc, Notify};
use tokio::time::Instant; use tokio::time::Instant;
@ -35,12 +35,18 @@ impl BackendV3 {
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
) -> Self { ) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { let attention = if let Ok(attention) = std::env::var("ATTENTION") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else { } else {
false Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
}; };
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,

View File

@ -1461,7 +1461,7 @@ fn main() -> Result<(), LauncherError> {
if config.model_type == Some("gemma2".to_string()) { if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage"); tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("FLASH_DECODING", "1"); std::env::set_var("ATTENTION", "flashdecoding");
} }
let config: Config = config.into(); let config: Config = config.into();

View File

@ -1,10 +1,10 @@
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::{
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
}; };
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token}; use crate::{Attention, FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
@ -40,12 +40,18 @@ impl BackendV2 {
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { let attention = if let Ok(attention) = std::env::var("ATTENTION") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") attention
.parse()
.expect(&format!("Invalid attention was specified :`{attention}`"))
} else { } else {
false Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
}; };
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate); let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new()); let batching_task_notifier = Arc::new(Notify::new());

View File

@ -15,6 +15,35 @@ use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
#[derive(PartialEq)]
pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
}
#[derive(Debug)]
pub struct ParseError;
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value")
}
}
impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError),
}
}
}
#[derive(Clone, Deserialize, ToSchema)] #[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct VertexInstance { pub(crate) struct VertexInstance {
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]

View File

@ -1,10 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER from text_generation_server.models.globals import ATTENTION
import torch import torch
from typing import Optional from typing import Optional
if FLASH_DECODING or FLASH_INFER: if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass @dataclass
class Seqlen: class Seqlen:

View File

@ -1,9 +1,8 @@
import torch import torch
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ( from text_generation_server.models.globals import (
FLASH_DECODING, ATTENTION,
BLOCK_SIZE, BLOCK_SIZE,
FLASH_INFER,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
@ -27,7 +26,7 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
): ):
if FLASH_DECODING or FLASH_INFER: if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value value_cache.view(-1, shape[-2], shape[-1])[slots] = value
@ -76,7 +75,7 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work # sequences or heads is large, we use V1 since there is enough work
# to parallelize. # to parallelize.
if FLASH_INFER: if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flash_infer import decode_state from text_generation_server.layers.attention.flash_infer import decode_state
return decode_state.get().forward( return decode_state.get().forward(
@ -85,7 +84,7 @@ def paged_attention(
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
) )
elif FLASH_DECODING: elif ATTENTION == "flashdecoding":
max_q = 1 max_q = 1
max_k = max_s max_k = max_s
import flash_attn_2_cuda import flash_attn_2_cuda
@ -219,7 +218,7 @@ except ImportError:
SUPPORTS_WINDOWING = V2 SUPPORTS_WINDOWING = V2
if FLASH_INFER: if ATTENTION == "flashinfer":
def attention( def attention(
q, q,

View File

@ -1,7 +1,7 @@
import os import os
import torch import torch
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from loguru import logger from loguru import logger
@ -28,7 +28,7 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
): ):
if FLASH_DECODING: if ATTENTION == "flashdecoding":
shape = key_cache.shape shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value value_cache.view(-1, shape[-2], shape[-1])[slots] = value

View File

@ -40,8 +40,7 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import ( from text_generation_server.models.globals import (
MEM_POOL, MEM_POOL,
FLASH_DECODING, ATTENTION,
FLASH_INFER,
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
get_adapter_to_index, get_adapter_to_index,
@ -938,7 +937,7 @@ class FlashCausalLM(Model):
self.cuda_graphs = {} self.cuda_graphs = {}
self.kv_cache = [] self.kv_cache = []
if FLASH_INFER: if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flash_infer import ( from text_generation_server.layers.attention.flash_infer import (
create_prefill_state, create_prefill_state,
create_decode_state, create_decode_state,
@ -990,7 +989,7 @@ class FlashCausalLM(Model):
else: else:
x = BLOCK_SIZE // element_size x = BLOCK_SIZE // element_size
if FLASH_DECODING or FLASH_INFER: if ATTENTION in {"flashdecoding", "flashinfer"}:
self.kv_cache = [ self.kv_cache = [
( (
torch.empty( torch.empty(
@ -1062,7 +1061,7 @@ class FlashCausalLM(Model):
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph self.cuda_graphs[bs]["graph"] = graph
if FLASH_INFER: if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flash_infer import ( from text_generation_server.layers.attention.flash_infer import (
create_decode_state_cuda_graphs, create_decode_state_cuda_graphs,
) )
@ -1766,7 +1765,7 @@ class FlashCausalLM(Model):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> ContextManager: ) -> ContextManager:
if not FLASH_INFER: if ATTENTION != "flashinfer":
return nullcontext() return nullcontext()
from text_generation_server.layers.attention.flash_infer import ( from text_generation_server.layers.attention.flash_infer import (

View File

@ -5,16 +5,16 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} ATTENTION = os.getenv("ATTENTION", "paged")
if FLASH_INFER: _expected = {"paged", "flashdecoding", "flashinfer"}
log_master(logger.info, "Using FLASH_INFER") assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli # This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
log_master(logger.info, "Using FLASH_DECODING")
cuda_graphs = os.getenv("CUDA_GRAPHS") cuda_graphs = os.getenv("CUDA_GRAPHS")