Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385)
* Using an enum for flash backens (paged/flashdecoding/flashinfer) * Early exit on server too. * Clippy. * Fix clippy and fmt.
This commit is contained in:
parent
6e127dcc96
commit
7a48a84784
|
@ -18,3 +18,4 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
|
|||
|
||||
data/
|
||||
load_tests/*.json
|
||||
server/fbgemmm
|
||||
|
|
|
@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
|
|||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
|
@ -35,12 +35,18 @@ impl BackendV3 {
|
|||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else {
|
||||
false
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
|
|
|
@ -2080,4 +2080,4 @@
|
|||
"description": "Hugging Face Text Generation Inference API"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ We recommend using the official quantization scripts for creating your quants:
|
|||
2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)
|
||||
3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)
|
||||
|
||||
For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
|
||||
For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
|
||||
|
||||
## Quantization with bitsandbytes, EETQ & fp8
|
||||
|
||||
|
@ -69,4 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num-
|
|||
You can learn more about the quantization options by running `text-generation-server quantize --help`.
|
||||
|
||||
If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).
|
||||
You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
|
||||
You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
|
||||
|
|
|
@ -1461,7 +1461,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
|
||||
if config.model_type == Some("gemma2".to_string()) {
|
||||
tracing::info!("Forcing flash decoding because of softcap usage");
|
||||
std::env::set_var("FLASH_DECODING", "1");
|
||||
std::env::set_var("ATTENTION", "flashdecoding");
|
||||
}
|
||||
let config: Config = config.into();
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::v2::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||
Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use crate::{Attention, FinishReason, PrefillToken, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
|
@ -40,12 +40,18 @@ impl BackendV2 {
|
|||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.expect(&format!("Invalid attention was specified :`{attention}`"))
|
||||
} else {
|
||||
false
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
|
|
|
@ -15,6 +15,35 @@ use tracing::warn;
|
|||
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum Attention {
|
||||
Paged,
|
||||
FlashDecoding,
|
||||
FlashInfer,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ParseError;
|
||||
|
||||
impl std::fmt::Display for ParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Cannot parse attention value")
|
||||
}
|
||||
}
|
||||
impl std::error::Error for ParseError {}
|
||||
|
||||
impl std::str::FromStr for Attention {
|
||||
type Err = ParseError;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"paged" => Ok(Attention::Paged),
|
||||
"flashdecoding" => Ok(Attention::FlashDecoding),
|
||||
"flashinfer" => Ok(Attention::FlashInfer),
|
||||
_ => Err(ParseError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexInstance {
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from dataclasses import dataclass
|
||||
from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
if FLASH_DECODING or FLASH_INFER:
|
||||
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import (
|
||||
FLASH_DECODING,
|
||||
ATTENTION,
|
||||
BLOCK_SIZE,
|
||||
FLASH_INFER,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
|
@ -27,7 +26,7 @@ def reshape_and_cache(
|
|||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
if FLASH_DECODING or FLASH_INFER:
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
|
@ -76,7 +75,7 @@ def paged_attention(
|
|||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
if FLASH_INFER:
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flash_infer import decode_state
|
||||
|
||||
return decode_state.get().forward(
|
||||
|
@ -85,7 +84,7 @@ def paged_attention(
|
|||
logits_soft_cap=softcap,
|
||||
sm_scale=softmax_scale,
|
||||
)
|
||||
elif FLASH_DECODING:
|
||||
elif ATTENTION == "flashdecoding":
|
||||
max_q = 1
|
||||
max_k = max_s
|
||||
import flash_attn_2_cuda
|
||||
|
@ -219,7 +218,7 @@ except ImportError:
|
|||
|
||||
SUPPORTS_WINDOWING = V2
|
||||
|
||||
if FLASH_INFER:
|
||||
if ATTENTION == "flashinfer":
|
||||
|
||||
def attention(
|
||||
q,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.utils.log import log_master
|
||||
from loguru import logger
|
||||
|
@ -28,7 +28,7 @@ def reshape_and_cache(
|
|||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
if FLASH_DECODING:
|
||||
if ATTENTION == "flashdecoding":
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
|
|
|
@ -40,8 +40,7 @@ from text_generation_server.models.types import (
|
|||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.globals import (
|
||||
MEM_POOL,
|
||||
FLASH_DECODING,
|
||||
FLASH_INFER,
|
||||
ATTENTION,
|
||||
BLOCK_SIZE,
|
||||
CUDA_GRAPHS,
|
||||
get_adapter_to_index,
|
||||
|
@ -938,7 +937,7 @@ class FlashCausalLM(Model):
|
|||
self.cuda_graphs = {}
|
||||
self.kv_cache = []
|
||||
|
||||
if FLASH_INFER:
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flash_infer import (
|
||||
create_prefill_state,
|
||||
create_decode_state,
|
||||
|
@ -990,7 +989,7 @@ class FlashCausalLM(Model):
|
|||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if FLASH_DECODING or FLASH_INFER:
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
|
@ -1062,7 +1061,7 @@ class FlashCausalLM(Model):
|
|||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs]["graph"] = graph
|
||||
|
||||
if FLASH_INFER:
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flash_infer import (
|
||||
create_decode_state_cuda_graphs,
|
||||
)
|
||||
|
@ -1766,7 +1765,7 @@ class FlashCausalLM(Model):
|
|||
input_lengths: torch.Tensor,
|
||||
state: Optional[Any] = None,
|
||||
) -> ContextManager:
|
||||
if not FLASH_INFER:
|
||||
if ATTENTION != "flashinfer":
|
||||
return nullcontext()
|
||||
|
||||
from text_generation_server.layers.attention.flash_infer import (
|
||||
|
|
|
@ -5,16 +5,16 @@ from typing import Dict, Optional
|
|||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
|
||||
if FLASH_INFER:
|
||||
log_master(logger.info, "Using FLASH_INFER")
|
||||
ATTENTION = os.getenv("ATTENTION", "paged")
|
||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||
assert (
|
||||
ATTENTION in _expected
|
||||
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
||||
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
if FLASH_DECODING:
|
||||
log_master(logger.info, "Using FLASH_DECODING")
|
||||
BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16
|
||||
|
||||
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
|
|
Loading…
Reference in New Issue