fix: limit some tests to only run when cuda available
This commit is contained in:
parent
040c5b5970
commit
301a18c2e5
|
@ -6,6 +6,7 @@ from text_generation_server.layers.rotary import (
|
||||||
DynamicPositionRotaryEmbedding,
|
DynamicPositionRotaryEmbedding,
|
||||||
YarnPositionRotaryEmbedding,
|
YarnPositionRotaryEmbedding,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
def test_position_rotary_embedding_static_basic():
|
def test_position_rotary_embedding_static_basic():
|
||||||
|
@ -159,7 +160,7 @@ def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
|
||||||
|
|
||||||
# Test the application of the rotary embedding
|
# Test the application of the rotary embedding
|
||||||
|
|
||||||
def test_position_rotary_embedding_no_rope_config():
|
def position_rotary_embedding_no_rope_config():
|
||||||
head_dim = 64
|
head_dim = 64
|
||||||
base = 10000
|
base = 10000
|
||||||
max_position_embeddings = 2048
|
max_position_embeddings = 2048
|
||||||
|
@ -217,7 +218,7 @@ def test_position_rotary_embedding_no_rope_config():
|
||||||
assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
|
assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
|
||||||
|
|
||||||
|
|
||||||
def test_position_rotary_embedding_with_dynamic_scaling():
|
def position_rotary_embedding_with_dynamic_scaling():
|
||||||
head_dim = 64
|
head_dim = 64
|
||||||
base = 10000
|
base = 10000
|
||||||
max_position_embeddings = 2048
|
max_position_embeddings = 2048
|
||||||
|
@ -273,3 +274,10 @@ def test_position_rotary_embedding_with_dynamic_scaling():
|
||||||
assert k_rotated.shape == key.shape, "key shape should not change after rotation"
|
assert k_rotated.shape == key.shape, "key shape should not change after rotation"
|
||||||
assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
|
assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
|
||||||
assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
|
assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
|
||||||
|
|
||||||
|
if SYSTEM == "cuda":
|
||||||
|
def test_position_rotary_embedding_with_dynamic_scaling():
|
||||||
|
position_rotary_embedding_no_rope_config()
|
||||||
|
|
||||||
|
def test_position_rotary_embedding_no_rope_config():
|
||||||
|
position_rotary_embedding_no_rope_config()
|
||||||
|
|
Loading…
Reference in New Issue