From 7a48a84784b74f9ea12cf04a3c4572c027cec2e4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 9 Aug 2024 16:41:17 +0200 Subject: [PATCH] Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385) * Using an enum for flash backens (paged/flashdecoding/flashinfer) * Early exit on server too. * Clippy. * Fix clippy and fmt. --- .gitignore | 1 + backends/v3/src/backend.rs | 16 ++++++---- docs/openapi.json | 2 +- docs/source/conceptual/quantization.md | 4 +-- launcher/src/main.rs | 2 +- router/src/infer/v2/scheduler.rs | 18 ++++++++---- router/src/lib.rs | 29 +++++++++++++++++++ .../layers/attention/common.py | 4 +-- .../layers/attention/cuda.py | 11 ++++--- .../layers/attention/rocm.py | 4 +-- .../models/flash_causal_lm.py | 11 ++++--- .../text_generation_server/models/globals.py | 14 ++++----- 12 files changed, 78 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 0de8b848..bd9d9125 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp data/ load_tests/*.json +server/fbgemmm diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 6b3e0526..68ddf00b 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -6,7 +6,7 @@ use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{FinishReason, PrefillToken, Token}; +use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -35,12 +35,18 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged + }; + let block_size = if attention == Attention::FlashDecoding { + 256 + } else { + 16 }; - let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, diff --git a/docs/openapi.json b/docs/openapi.json index 9d281a48..ed9b0b96 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2080,4 +2080,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} \ No newline at end of file +} diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index a1ebe7e7..b7672a9f 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -9,7 +9,7 @@ We recommend using the official quantization scripts for creating your quants: 2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py) 3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md) -For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. +For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. ## Quantization with bitsandbytes, EETQ & fp8 @@ -69,4 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num- You can learn more about the quantization options by running `text-generation-server quantize --help`. If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration). -You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). \ No newline at end of file +You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8acfda0c..a64b1d71 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1461,7 +1461,7 @@ fn main() -> Result<(), LauncherError> { if config.model_type == Some("gemma2".to_string()) { tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("FLASH_DECODING", "1"); + std::env::set_var("ATTENTION", "flashdecoding"); } let config: Config = config.into(); diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index cc333674..0e5fc8a3 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,10 +1,10 @@ /// Batching and inference logic use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ - Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, + Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; +use crate::{Attention, FinishReason, PrefillToken, Token}; use nohash_hasher::IntMap; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -40,12 +40,18 @@ impl BackendV2 { generation_health: Arc, ) -> Self { // Infer shared state - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention + .parse() + .expect(&format!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged + }; + let block_size = if attention == Attention::FlashDecoding { + 256 + } else { + 16 }; - let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); diff --git a/router/src/lib.rs b/router/src/lib.rs index a956b058..66738706 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -15,6 +15,35 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; +#[derive(PartialEq)] +pub enum Attention { + Paged, + FlashDecoding, + FlashInfer, +} + +#[derive(Debug)] +pub struct ParseError; + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cannot parse attention value") + } +} +impl std::error::Error for ParseError {} + +impl std::str::FromStr for Attention { + type Err = ParseError; + fn from_str(s: &str) -> Result { + match s { + "paged" => Ok(Attention::Paged), + "flashdecoding" => Ok(Attention::FlashDecoding), + "flashinfer" => Ok(Attention::FlashInfer), + _ => Err(ParseError), + } + } +} + #[derive(Clone, Deserialize, ToSchema)] pub(crate) struct VertexInstance { #[schema(example = "What is Deep Learning?")] diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index b986a082..f162230c 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER +from text_generation_server.models.globals import ATTENTION import torch from typing import Optional -if FLASH_DECODING or FLASH_INFER: +if ATTENTION in {"flashinfer", "flashdecoding"}: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 1b8e9209..d039e1e7 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,9 +1,8 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ( - FLASH_DECODING, + ATTENTION, BLOCK_SIZE, - FLASH_INFER, ) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -27,7 +26,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -76,7 +75,7 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import decode_state return decode_state.get().forward( @@ -85,7 +84,7 @@ def paged_attention( logits_soft_cap=softcap, sm_scale=softmax_scale, ) - elif FLASH_DECODING: + elif ATTENTION == "flashdecoding": max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -219,7 +218,7 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if FLASH_INFER: +if ATTENTION == "flashinfer": def attention( q, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 69e64162..16ce8d2b 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,7 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import ATTENTION from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master from loguru import logger @@ -28,7 +28,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if ATTENTION == "flashdecoding": shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 12aa7dcd..21b66a68 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -40,8 +40,7 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, - FLASH_DECODING, - FLASH_INFER, + ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -938,7 +937,7 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_prefill_state, create_decode_state, @@ -990,7 +989,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: self.kv_cache = [ ( torch.empty( @@ -1062,7 +1061,7 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_decode_state_cuda_graphs, ) @@ -1766,7 +1765,7 @@ class FlashCausalLM(Model): input_lengths: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: - if not FLASH_INFER: + if ATTENTION != "flashinfer": return nullcontext() from text_generation_server.layers.attention.flash_infer import ( diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 42b43c87..b58a5b80 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,16 +5,16 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} -if FLASH_INFER: - log_master(logger.info, "Using FLASH_INFER") +ATTENTION = os.getenv("ATTENTION", "paged") +_expected = {"paged", "flashdecoding", "flashinfer"} +assert ( + ATTENTION in _expected +), f"Attention is not valid {ATTENTION}, expected {_expected}" +log_master(logger.info, f"Using Attention = {ATTENTION}") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli -FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} -BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 -if FLASH_DECODING: - log_master(logger.info, "Using FLASH_DECODING") +BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 cuda_graphs = os.getenv("CUDA_GRAPHS")