fix: rerun black lint

David Holtz committed 2024-10-09 18:44:41 +00:00
parent 301a18c2e5
commit 130f9d16b5
1 changed file with 15 additions and 17 deletions


@@ -10,11 +10,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 def test_position_rotary_embedding_static_basic():
-    config = Mock(
-        rope_theta=10000,
-        max_position_embeddings=2048,
-        rope_scaling=None
-    )
+    config = Mock(rope_theta=10000, max_position_embeddings=2048, rope_scaling=None)
     weights = Mock(device=torch.device("cpu"))
     result = PositionRotaryEmbedding.static(
@@ -30,10 +26,7 @@ def test_position_rotary_embedding_static_basic():
 def test_position_rotary_embedding_static_linear_scaling():
-    config = Mock(
-        rope_theta=10000,
-        max_position_embeddings=2048
-    )
+    config = Mock(rope_theta=10000, max_position_embeddings=2048)
     # scaling is not applied if type is linear (TODO: maybe revisit this)
     config.rope_scaling = {"type": "linear", "factor": 2.0}
     weights = Mock(device=torch.device("cpu"))
@@ -53,7 +46,7 @@ def test_position_rotary_embedding_static_dynamic_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {"type": "dynamic", "factor": 2.0}
+        rope_scaling={"type": "dynamic", "factor": 2.0},
     )
     weights = Mock(device=torch.device("cpu"))
@@ -73,11 +66,11 @@ def test_position_rotary_embedding_static_yarn_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {
+        rope_scaling={
             "type": "yarn",
             "factor": 1.5,
             "original_max_position_embeddings": 2048,
-        }
+        },
     )
     weights = Mock(device=torch.device("cpu"))
@@ -97,7 +90,7 @@ def test_position_rotary_embedding_static_invalid_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {"type": "invalid", "factor": 2.0}
+        rope_scaling={"type": "invalid", "factor": 2.0},
     )
     weights = Mock(device=torch.device("cpu"))
@@ -114,13 +107,14 @@ def test_position_rotary_embedding_static_llama3_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {
+        rope_scaling={
             "rope_type": "llama3",
             "factor": 2.0,
             "low_freq_factor": 4,
             "high_freq_factor": 32,
             "original_max_position_embeddings": 2048,
-        })
+        },
+    )
     weights = Mock(device=torch.device("cpu"))
     result = PositionRotaryEmbedding.static(
@@ -158,8 +152,10 @@ def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
     assert result.scaling_factor == 2.0
     assert result.max_position_embeddings == 4096
 # Test the application of the rotary embedding
 def position_rotary_embedding_no_rope_config():
     head_dim = 64
     base = 10000
@@ -174,7 +170,7 @@ def position_rotary_embedding_no_rope_config():
     config = Mock(
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        rope_scaling=None
+        rope_scaling=None,
     )
     # create PositionRotaryEmbedding instance
@@ -232,7 +228,7 @@ def position_rotary_embedding_with_dynamic_scaling():
     config = Mock(
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        rope_scaling={"type": "dynamic", "factor": 1.0}
+        rope_scaling={"type": "dynamic", "factor": 1.0},
     )
     # create PositionRotaryEmbedding instance
@@ -275,7 +271,9 @@ def position_rotary_embedding_with_dynamic_scaling():
     assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
     assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
 if SYSTEM == "cuda":
     def test_position_rotary_embedding_with_dynamic_scaling():
         position_rotary_embedding_no_rope_config()
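
For reference, the changes above match black's default style: a call whose arguments fit within the 88-character line limit is collapsed onto one line, and arguments that stay spread over multiple lines get a trailing comma. Below is a minimal sketch of checking a snippet against that style through black's Python API; the snippet is illustrative, and the changed file's path is not shown in this view, so it is left out.

# Minimal sketch: format a code snippet with black's API (black assumed installed).
import black

src = (
    "config = Mock(\n"
    "    rope_theta=10000,\n"
    "    max_position_embeddings=2048,\n"
    "    rope_scaling=None\n"
    ")\n"
)
# format_str applies the given Mode (defaults: 88-character lines, magic trailing comma).
print(black.format_str(src, mode=black.Mode()))
# Prints the collapsed form seen in the first hunk:
# config = Mock(rope_theta=10000, max_position_embeddings=2048, rope_scaling=None)

On the command line, the same result comes from running black over the changed file, or black --check to only report files that would be reformatted.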