Use a block size of 1 for FlashInfer

This commit is contained in:
Daniël de Kok 2024-08-01 11:20:42 +00:00
parent 8fb8e1da78
commit 4562c16048
3 changed files with 36 additions and 7 deletions

View File

@@ -40,7 +40,18 @@ impl BackendV3 {
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding {
256
} else if flashinfer {
1
} else {
16
};
let queue = Queue::new(
requires_padding,

View File

@@ -45,7 +45,19 @@ impl BackendV2 {
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding {
256
} else if flashinfer {
1
} else {
16
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());

View File

@@ -5,16 +5,22 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
if FLASH_DECODING:
log_master(logger.info, "Using FLASH_DECODING")
FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
if FLASH_INFER:
log_master(logger.info, "Using FLASH_INFER")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
log_master(logger.info, "Using FLASH_DECODING")
BLOCK_SIZE = 256
elif FLASH_INFER:
BLOCK_SIZE = 1
else:
BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS")