Use a block size of 1 for FlashInfer
This commit is contained in:
parent
8fb8e1da78
commit
4562c16048
|
@ -40,7 +40,18 @@ impl BackendV3 {
|
|||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
|
||||
matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding {
|
||||
256
|
||||
} else if flashinfer {
|
||||
1
|
||||
} else {
|
||||
16
|
||||
};
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
|
|
|
@ -45,7 +45,19 @@ impl BackendV2 {
|
|||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let flashinfer = if let Ok(flashinfer) = std::env::var("FLASH_INFER") {
|
||||
matches!(flashinfer.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding {
|
||||
256
|
||||
} else if flashinfer {
|
||||
1
|
||||
} else {
|
||||
16
|
||||
};
|
||||
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
|
|
|
@ -5,16 +5,22 @@ from typing import Dict, Optional
|
|||
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
if FLASH_DECODING:
|
||||
log_master(logger.info, "Using FLASH_DECODING")
|
||||
|
||||
FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"}
|
||||
if FLASH_INFER:
|
||||
log_master(logger.info, "Using FLASH_INFER")
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
|
||||
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
|
||||
if FLASH_DECODING:
|
||||
log_master(logger.info, "Using FLASH_DECODING")
|
||||
BLOCK_SIZE = 256
|
||||
elif FLASH_INFER:
|
||||
BLOCK_SIZE = 1
|
||||
else:
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
|
|
Loading…
Reference in New Issue