From 4562c16048ca8b84a3c3ec0658f068d6f51c1498 Mon Sep 17 00:00:00 2001
From: Daniël de Kok
Date: Thu, 1 Aug 2024 11:20:42 +0000
Subject: [PATCH] Use a block size of 1 for FlashInfer

---
 backends/v3/src/backend.rs                       | 13 ++++++++++++-
 router/src/infer/v2/scheduler.rs                 | 14 +++++++++++++-
 server/text_generation_server/models/globals.py  | 16 +++++++++++-----
 3 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index d82355de..cdc3c314 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -40,7 +40,18 @@ impl BackendV3 {
         } else {
             false
         };
-        let block_size = if flashdecoding { 256 } else { 16 };
+        let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
+            matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding {
+            256
+        } else if flashinfer {
+            1
+        } else {
+            16
+        };
 
         let queue = Queue::new(
             requires_padding,
diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs
index 3d6c36cf..99b3d986 100644
--- a/router/src/infer/v2/scheduler.rs
+++ b/router/src/infer/v2/scheduler.rs
@@ -45,7 +45,19 @@ impl BackendV2 {
         } else {
             false
         };
-        let block_size = if flashdecoding { 256 } else { 16 };
+        let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
+            matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
+        } else {
+            false
+        };
+        let block_size = if flashdecoding {
+            256
+        } else if flashinfer {
+            1
+        } else {
+            16
+        };
+
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 42b43c87..3dc54145 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -5,16 +5,22 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+# This is overridden by the cli
+FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
+if FLASH_DECODING:
+    log_master(logger.info, "Using FLASH_DECODING")
+
 FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
 if FLASH_INFER:
     log_master(logger.info, "Using FLASH_INFER")
 
-MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-# This is overridden by the cli
-FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
 if FLASH_DECODING:
-    log_master(logger.info, "Using FLASH_DECODING")
+    BLOCK_SIZE = 256
+elif FLASH_INFER:
+    BLOCK_SIZE = 1
+else:
+    BLOCK_SIZE = 16
 
 cuda_graphs = os.getenv("CUDA_GRAPHS")