From 301a18c2e5364bf9da1f7a38b4d5f1d75207a154 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Wed, 9 Oct 2024 18:35:29 +0000 Subject: [PATCH] fix: limit some tests to only run when cuda available --- server/tests/utils/test_rotary_emb.py | 36 ++++++++++++++++----------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/server/tests/utils/test_rotary_emb.py b/server/tests/utils/test_rotary_emb.py index 7a3877cc..1ade3c8b 100644 --- a/server/tests/utils/test_rotary_emb.py +++ b/server/tests/utils/test_rotary_emb.py @@ -6,12 +6,13 @@ from text_generation_server.layers.rotary import ( DynamicPositionRotaryEmbedding, YarnPositionRotaryEmbedding, ) +from text_generation_server.utils.import_utils import SYSTEM def test_position_rotary_embedding_static_basic(): config = Mock( - rope_theta=10000, - max_position_embeddings=2048, + rope_theta=10000, + max_position_embeddings=2048, rope_scaling=None ) weights = Mock(device=torch.device("cpu")) @@ -30,7 +31,7 @@ def test_position_rotary_embedding_static_basic(): def test_position_rotary_embedding_static_linear_scaling(): config = Mock( - rope_theta=10000, + rope_theta=10000, max_position_embeddings=2048 ) # scaling is not applied if type is linear (TODO: maybe revisit this) @@ -50,8 +51,8 @@ def test_position_rotary_embedding_static_linear_scaling(): def test_position_rotary_embedding_static_dynamic_scaling(): config = Mock( - rope_theta=10000, - max_position_embeddings=2048, + rope_theta=10000, + max_position_embeddings=2048, rope_scaling = {"type": "dynamic", "factor": 2.0} ) weights = Mock(device=torch.device("cpu")) @@ -70,7 +71,7 @@ def test_position_rotary_embedding_static_dynamic_scaling(): def test_position_rotary_embedding_static_yarn_scaling(): config = Mock( - rope_theta=10000, + rope_theta=10000, max_position_embeddings=2048, rope_scaling = { "type": "yarn", @@ -94,8 +95,8 @@ def test_position_rotary_embedding_static_yarn_scaling(): def test_position_rotary_embedding_static_invalid_scaling(): config = Mock( - rope_theta=10000, - max_position_embeddings=2048, + rope_theta=10000, + max_position_embeddings=2048, rope_scaling = {"type": "invalid", "factor": 2.0} ) weights = Mock(device=torch.device("cpu")) @@ -111,7 +112,7 @@ def test_position_rotary_embedding_static_invalid_scaling(): def test_position_rotary_embedding_static_llama3_scaling(): config = Mock( - rope_theta=10000, + rope_theta=10000, max_position_embeddings=2048, rope_scaling = { "rope_type": "llama3", @@ -159,7 +160,7 @@ def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings(): # Test the application of the rotary embedding -def test_position_rotary_embedding_no_rope_config(): +def position_rotary_embedding_no_rope_config(): head_dim = 64 base = 10000 max_position_embeddings = 2048 @@ -171,7 +172,7 @@ def test_position_rotary_embedding_no_rope_config(): dtype = torch.float16 config = Mock( - rope_theta=base, + rope_theta=base, max_position_embeddings=max_position_embeddings, rope_scaling=None ) @@ -217,7 +218,7 @@ def test_position_rotary_embedding_no_rope_config(): assert not torch.allclose(k_rotated, key), "key should be modified by rotation" -def test_position_rotary_embedding_with_dynamic_scaling(): +def position_rotary_embedding_with_dynamic_scaling(): head_dim = 64 base = 10000 max_position_embeddings = 2048 @@ -229,8 +230,8 @@ def test_position_rotary_embedding_with_dynamic_scaling(): dtype = torch.float16 config = Mock( - rope_theta=base, - max_position_embeddings=max_position_embeddings, + rope_theta=base, + max_position_embeddings=max_position_embeddings, rope_scaling={"type": "dynamic", "factor": 1.0} ) @@ -273,3 +274,10 @@ def test_position_rotary_embedding_with_dynamic_scaling(): assert k_rotated.shape == key.shape, "key shape should not change after rotation" assert not torch.allclose(q_rotated, query), "query should be modified by rotation" assert not torch.allclose(k_rotated, key), "key should be modified by rotation" + +if SYSTEM == "cuda": + def test_position_rotary_embedding_with_dynamic_scaling(): + position_rotary_embedding_no_rope_config() + + def test_position_rotary_embedding_no_rope_config(): + position_rotary_embedding_no_rope_config()