ROCm and sliding windows fixes (#2033)

* update vllm commit & fix models using sliding window

* update

* update commit

* fix bug where tunableop is bound to cuda graphs even when cuda graphs are disabled

* enable tunableop by default

* fix sliding window

* address review

* dead code

* make comment more precise

* is it flaky?
fxmarty 2024-06-10 09:09:50 +02:00 committed by GitHub
parent bf3c813782
commit 9b3674d903
8 changed files with 36 additions and 24 deletions

View File

@@ -481,6 +481,7 @@ fn shard_manager(
rope_factor: Option<f32>,
max_total_tokens: usize,
max_batch_size: Option<usize>,
max_input_tokens: usize,
otlp_endpoint: Option<String>,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>,
@@ -553,6 +554,10 @@ fn shard_manager(
shard_args.push(otlp_endpoint);
}
// In case we use sliding window, we may ignore the sliding window in flash attention for some backends, depending on this parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@@ -1009,6 +1014,7 @@ fn spawn_shards(
args: &Args,
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
max_input_tokens: usize,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
@@ -1066,6 +1072,7 @@ fn spawn_shards(
rope_factor,
max_total_tokens,
max_batch_size,
max_input_tokens,
otlp_endpoint,
max_log_level,
status_sender,
@@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
&args,
cuda_graphs,
max_total_tokens,
max_input_tokens,
max_log_level,
shutdown.clone(),
&shutdown_receiver,
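
With the launcher changes above, each shard's command line now carries `--max-input-tokens <value>`, which is what allows the Python server further down to relax the sliding-window restriction at load time instead of hard-failing inside the ROCm/XPU attention kernels. An illustrative launch for a sliding-window model on such a backend (flag values are examples only; 4096 stands in for the model's `sliding_window`):

text-generation-launcher --model-id <model> --max-input-tokens 4096 --max-total-tokens 4097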

View File

@@ -1,5 +1,5 @@
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \

View File

@@ -42,6 +42,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
max_input_tokens: Optional[int] = None,
):
if sharded:
assert (
@@ -98,6 +99,7 @@ def serve(
dtype,
trust_remote_code,
uds_path,
max_input_tokens,
)
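
Aside: because the server CLI is built on typer, the new `max_input_tokens: Optional[int] = None` parameter automatically becomes a `--max-input-tokens` option, which is how the value appended by the launcher above reaches the shard. A minimal standalone sketch of that mapping (the app below is illustrative, not TGI's actual cli module):

from typing import Optional

import typer

app = typer.Typer()

@app.command()
def serve(model_id: str, max_input_tokens: Optional[int] = None):
    # typer turns the keyword argument into a --max-input-tokens option, so
    # `serve some-model --max-input-tokens 4096` yields max_input_tokens=4096.
    typer.echo(f"{model_id}: max_input_tokens={max_input_tokens}")

if __name__ == "__main__":
    app()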

View File

@@ -169,10 +169,8 @@ if ENGINE == "ck":
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# We do not need to check window_size_left here (windowing is not supported by this backend), as it is already validated ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
window_size_left=-1,
causal=True,
):
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# We do not need to check window_size_left here (windowing is not supported by this backend), as it is already validated ahead of time at model load.
output, _ = triton_attention(
q,
k,

View File

@@ -14,10 +14,7 @@ def attention(
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# We do not need to check window_size_left here (windowing is not supported by this backend), as it is already validated ahead of time at model load.
return ipex.llm.functional.varlen_attention(
q,
k,
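
Both the ROCm (CK and Triton) wrappers and the XPU (ipex) wrapper above now rely on the load-time check instead of raising on every call. A rough sketch of the resulting contract, with simplified signatures; the flag name mirrors TGI's `SUPPORTS_WINDOWING`, and its value here is an assumption about what a non-windowing backend would advertise:

SUPPORTS_WINDOWING = False  # what a CK/Triton/ipex-style backend reports

def attention(*args, window_size_left=-1, **kwargs):
    # Former behaviour: raise whenever window_size_left != -1.
    # New behaviour: trust the check done at model load. If loading succeeded,
    # either the backend supports windowing or max_input_tokens <= sliding_window,
    # in which case no sequence can be long enough for the window to matter, so
    # ignoring window_size_left is safe.
    ...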

View File

@@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi
from text_generation_server.utils.import_utils import SYSTEM
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
@@ -257,6 +259,7 @@ def get_model(
speculate: Optional[int],
dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
global FLASH_ATTENTION
if dtype is None:
@@ -410,11 +413,15 @@ def get_model(
"Sharding is currently not supported with `exl2` quantization"
)
sliding_window = config_dict.get("sliding_window", -1)
if sliding_window != -1 and not SUPPORTS_WINDOWING:
logger.warning(
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
if (
(sliding_window is not None and sliding_window != -1)
and not SUPPORTS_WINDOWING
and max_input_tokens > sliding_window
):
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
FLASH_ATTENTION = False
if model_type == MAMBA:
return Mamba(
@@ -701,7 +708,6 @@ def get_model(
)
if model_type == MISTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMistral(
model_id,
@@ -724,7 +730,6 @@ def get_model(
)
if model_type == MIXTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMixtral(
model_id,
@@ -747,7 +752,6 @@ def get_model(
)
if model_type == STARCODER2:
sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashStarcoder2(
model_id,
@@ -771,8 +775,7 @@ def get_model(
)
if model_type == QWEN2:
sliding_window = config_dict.get("sliding_window", -1)
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
if FLASH_ATTENTION:
return FlashQwen2(
model_id,
revision,
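
Consolidated, the new load-time gate in get_model amounts to the following sketch (using the diff's own names; not a drop-in replacement for the surrounding code):

def check_sliding_window(sliding_window, supports_windowing, max_input_tokens):
    # A model uses a sliding window when the config sets it to a real value.
    uses_window = sliding_window is not None and sliding_window != -1
    # Refuse to load only when the window could actually be exceeded on a
    # backend that cannot enforce it; previously this case only logged a
    # warning and disabled flash attention.
    if uses_window and not supports_windowing and max_input_tokens > sliding_window:
        raise ValueError(
            "This backend does not support sliding window attention; "
            f"relaunch with `--max-input-tokens` <= {sliding_window}."
        )

In practice, a sliding-window model (Mistral, Mixtral, Starcoder2, Qwen2 above) still loads on a backend without windowing support as long as `--max-input-tokens` does not exceed the window, which is exactly what the raised error message asks for.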

View File

@@ -902,6 +902,8 @@ class FlashCausalLM(Model):
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
):
torch.cuda.tunable.enable()
if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
torch.cuda.tunable.tuning_enable(True)
@@ -910,8 +912,11 @@ class FlashCausalLM(Model):
int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
else:
elif CUDA_GRAPHS is not None:
tuning_sequences = CUDA_GRAPHS
else:
# For seqlen = 1, we dispatch to the LLMM1 kernel, so that length does not need tuning.
tuning_sequences = [2, 3, 4, 5, 6, 7]
tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
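
Taken together, the two TunableOp hunks make two changes: TunableOp is now on by default (PYTORCH_TUNABLEOP_ENABLED unset or "1"), and the tuning sequence lengths no longer assume CUDA graphs are in use. A condensed sketch of the selection order, where cuda_graphs stands for TGI's parsed CUDA_GRAPHS list (None when graphs are disabled):

import os

def pick_tuning_sequences(cuda_graphs):
    # 1. An explicit PYTORCH_TUNABLEOP_SEQLENS list always wins.
    env = os.environ.get("PYTORCH_TUNABLEOP_SEQLENS")
    if env is not None:
        return [int(val) for val in env.split(",")]
    # 2. Otherwise reuse the CUDA graph sizes when graphs are enabled.
    if cuda_graphs is not None:
        return cuda_graphs
    # 3. With graphs disabled, fall back to a small default range; seqlen 1 is
    #    skipped because it dispatches to the LLMM1 kernel and needs no tuning.
    return [2, 3, 4, 5, 6, 7]

This is what the commit message means by TunableOp no longer being bound to CUDA graphs: before, the CUDA_GRAPHS branch was the only fallback, so the tuning lengths stayed tied to the graph configuration even when graphs were disabled.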

View File

@@ -199,6 +199,7 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
max_input_tokens: int,
):
async def serve_inner(
model_id: str,
@@ -229,6 +230,7 @@ def serve(
speculate,
dtype,
trust_remote_code,
max_input_tokens,
)
except Exception:
logger.exception("Error when initializing model")