ROCm and sliding windows fixes (#2033)
* update vllm commit & fix models using sliding window * update * update commit * fix bug where tunableop is bound to cuda graph even when cuda graph are disabled * enable tunableop by default * fix sliding window * address review * dead code * precise comment * is it flaky?
This commit is contained in:
parent
bf3c813782
commit
9b3674d903
|
@ -481,6 +481,7 @@ fn shard_manager(
|
|||
rope_factor: Option<f32>,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
max_input_tokens: usize,
|
||||
otlp_endpoint: Option<String>,
|
||||
log_level: LevelFilter,
|
||||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
|
@ -553,6 +554,10 @@ fn shard_manager(
|
|||
shard_args.push(otlp_endpoint);
|
||||
}
|
||||
|
||||
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
|
||||
shard_args.push("--max-input-tokens".to_string());
|
||||
shard_args.push(max_input_tokens.to_string());
|
||||
|
||||
// Copy current process env
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
|
@ -1009,6 +1014,7 @@ fn spawn_shards(
|
|||
args: &Args,
|
||||
cuda_graphs: Vec<usize>,
|
||||
max_total_tokens: usize,
|
||||
max_input_tokens: usize,
|
||||
max_log_level: LevelFilter,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
|
@ -1066,6 +1072,7 @@ fn spawn_shards(
|
|||
rope_factor,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_input_tokens,
|
||||
otlp_endpoint,
|
||||
max_log_level,
|
||||
status_sender,
|
||||
|
@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
&args,
|
||||
cuda_graphs,
|
||||
max_total_tokens,
|
||||
max_input_tokens,
|
||||
max_log_level,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
|
||||
commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
|
||||
commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
|
||||
build-vllm-cuda:
|
||||
if [ ! -d 'vllm' ]; then \
|
||||
pip install -U ninja packaging --no-cache-dir && \
|
||||
|
@ -19,5 +19,5 @@ build-vllm-rocm:
|
|||
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||
|
||||
install-vllm-rocm: build-vllm-rocm
|
||||
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||
PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
|
||||
|
|
|
@ -42,6 +42,7 @@ def serve(
|
|||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
otlp_endpoint: Optional[str] = None,
|
||||
max_input_tokens: Optional[int] = None,
|
||||
):
|
||||
if sharded:
|
||||
assert (
|
||||
|
@ -98,6 +99,7 @@ def serve(
|
|||
dtype,
|
||||
trust_remote_code,
|
||||
uds_path,
|
||||
max_input_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -169,10 +169,8 @@ if ENGINE == "ck":
|
|||
):
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
if window_size_left != -1:
|
||||
raise ValueError(
|
||||
f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||
)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
|
@ -204,10 +202,7 @@ elif ENGINE == "triton":
|
|||
window_size_left=-1,
|
||||
causal=True,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
raise ValueError(
|
||||
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||
)
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
output, _ = triton_attention(
|
||||
q,
|
||||
k,
|
||||
|
|
|
@ -14,10 +14,7 @@ def attention(
|
|||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
raise ValueError(
|
||||
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||
)
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return ipex.llm.functional.varlen_attention(
|
||||
q,
|
||||
k,
|
||||
|
|
|
@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
|
|||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||
from text_generation_server.models.phi import Phi
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||
# in PyTorch 1.12 and later.
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
@ -257,6 +259,7 @@ def get_model(
|
|||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
) -> Model:
|
||||
global FLASH_ATTENTION
|
||||
if dtype is None:
|
||||
|
@ -410,11 +413,15 @@ def get_model(
|
|||
"Sharding is currently not supported with `exl2` quantization"
|
||||
)
|
||||
sliding_window = config_dict.get("sliding_window", -1)
|
||||
if sliding_window != -1 and not SUPPORTS_WINDOWING:
|
||||
logger.warning(
|
||||
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
|
||||
|
||||
if (
|
||||
(sliding_window is not None and sliding_window != -1)
|
||||
and not SUPPORTS_WINDOWING
|
||||
and max_input_tokens > sliding_window
|
||||
):
|
||||
raise ValueError(
|
||||
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
|
||||
)
|
||||
FLASH_ATTENTION = False
|
||||
|
||||
if model_type == MAMBA:
|
||||
return Mamba(
|
||||
|
@ -701,7 +708,6 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == MISTRAL:
|
||||
sliding_window = config_dict.get("sliding_window", -1)
|
||||
if FLASH_ATTENTION:
|
||||
return FlashMistral(
|
||||
model_id,
|
||||
|
@ -724,7 +730,6 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == MIXTRAL:
|
||||
sliding_window = config_dict.get("sliding_window", -1)
|
||||
if FLASH_ATTENTION:
|
||||
return FlashMixtral(
|
||||
model_id,
|
||||
|
@ -747,7 +752,6 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == STARCODER2:
|
||||
sliding_window = config_dict.get("sliding_window", -1)
|
||||
if FLASH_ATTENTION:
|
||||
return FlashStarcoder2(
|
||||
model_id,
|
||||
|
@ -771,8 +775,7 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == QWEN2:
|
||||
sliding_window = config_dict.get("sliding_window", -1)
|
||||
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashQwen2(
|
||||
model_id,
|
||||
revision,
|
||||
|
|
|
@ -902,6 +902,8 @@ class FlashCausalLM(Model):
|
|||
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
|
||||
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
|
||||
):
|
||||
torch.cuda.tunable.enable()
|
||||
|
||||
if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
|
||||
torch.cuda.tunable.tuning_enable(True)
|
||||
|
||||
|
@ -910,8 +912,11 @@ class FlashCausalLM(Model):
|
|||
int(val)
|
||||
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
|
||||
]
|
||||
else:
|
||||
elif CUDA_GRAPHS is not None:
|
||||
tuning_sequences = CUDA_GRAPHS
|
||||
else:
|
||||
# For seqlen = 1, we dispatch to LLMM1 kernel.
|
||||
tuning_sequences = [2, 3, 4, 5, 6, 7]
|
||||
|
||||
tunableop_filepath = os.path.join(
|
||||
HUGGINGFACE_HUB_CACHE,
|
||||
|
|
|
@ -199,6 +199,7 @@ def serve(
|
|||
dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
max_input_tokens: int,
|
||||
):
|
||||
async def serve_inner(
|
||||
model_id: str,
|
||||
|
@ -229,6 +230,7 @@ def serve(
|
|||
speculate,
|
||||
dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error when initializing model")
|
||||
|
|
Loading…
Reference in New Issue