diff --git a/server/tests/models/test_flash_causal_lm.py b/server/tests/models/test_flash_causal_lm.py index 1e27a859..c5c4bc5b 100644 --- a/server/tests/models/test_flash_causal_lm.py +++ b/server/tests/models/test_flash_causal_lm.py @@ -23,9 +23,11 @@ set_adapter_to_index({}) if SYSTEM == "cuda": + def test_flash_causal_lm_warmup(): flash_causal_lm_warmup() + def flash_causal_lm_warmup(): revision = "main" quantize = None