fix: only run test when cuda is available

This commit is contained in:
David Holtz 2024-10-09 18:43:53 +00:00
parent a8108bc0da
commit 1ddde382bd
1 changed files with 12 additions and 7 deletions

View File

@ -1,14 +1,12 @@
import torch
import pytest
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.layers.attention import KVCache
from text_generation_server.utils.import_utils import SYSTEM
def test_kvcache_memory():
if SYSTEM == "cuda":
kvcache_memory()
else:
pytest.skip("Test only runs on CUDA")
# only include this import when CUDA is available
if SYSTEM == "cuda":
from text_generation_server.layers.attention import KVCache
def kvcache_memory():
num_blocks = 8188
@ -40,5 +38,12 @@ def kvcache_memory():
assert kv_cache_memory_mb < 1025
# only include this test when CUDA is available
if SYSTEM == "cuda":
def test_kvcache_memory():
kvcache_memory()
if __name__ == "__main__":
test_kvcache_memory()