fix: rerun black lint
commit 130f9d16b5
parent 301a18c2e5
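Note: the changes below are whitespace- and formatting-only, exactly what black produces by default. As a minimal sketch (assuming black is installed; the snippet is illustrative and not part of this repository), the collapse of the first Mock(...) call can be reproduced through black's Python API:

import black  # assumes the black formatter is installed; for illustration only

# The multi-line call as it looked before this commit (Mock is just a name
# here; black formats text and never executes it).
src = """\
config = Mock(
    rope_theta=10000,
    max_position_embeddings=2048,
    rope_scaling=None
)
"""

# With the default 88-character line length and no trailing comma in the
# source, black collapses the call onto a single line.
print(black.format_str(src, mode=black.Mode()))
# config = Mock(rope_theta=10000, max_position_embeddings=2048, rope_scaling=None)

Where a call has to stay multi-line, black instead adds a trailing comma after the last argument, which is why the rope_scaling={...} lines in the hunks below gain one.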
@@ -10,11 +10,7 @@ from text_generation_server.utils.import_utils import SYSTEM


 def test_position_rotary_embedding_static_basic():
-    config = Mock(
-        rope_theta=10000,
-        max_position_embeddings=2048,
-        rope_scaling=None
-    )
+    config = Mock(rope_theta=10000, max_position_embeddings=2048, rope_scaling=None)
     weights = Mock(device=torch.device("cpu"))

     result = PositionRotaryEmbedding.static(
@@ -30,10 +26,7 @@ def test_position_rotary_embedding_static_basic():


 def test_position_rotary_embedding_static_linear_scaling():
-    config = Mock(
-        rope_theta=10000,
-        max_position_embeddings=2048
-    )
+    config = Mock(rope_theta=10000, max_position_embeddings=2048)
     # scaling is not applied if type is linear (TODO: maybe revisit this)
     config.rope_scaling = {"type": "linear", "factor": 2.0}
     weights = Mock(device=torch.device("cpu"))
@@ -53,7 +46,7 @@ def test_position_rotary_embedding_static_dynamic_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {"type": "dynamic", "factor": 2.0}
+        rope_scaling={"type": "dynamic", "factor": 2.0},
     )
     weights = Mock(device=torch.device("cpu"))

@@ -73,11 +66,11 @@ def test_position_rotary_embedding_static_yarn_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {
+        rope_scaling={
             "type": "yarn",
             "factor": 1.5,
             "original_max_position_embeddings": 2048,
-        }
+        },
     )
     weights = Mock(device=torch.device("cpu"))

@@ -97,7 +90,7 @@ def test_position_rotary_embedding_static_invalid_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {"type": "invalid", "factor": 2.0}
+        rope_scaling={"type": "invalid", "factor": 2.0},
     )
     weights = Mock(device=torch.device("cpu"))

@@ -114,13 +107,14 @@ def test_position_rotary_embedding_static_llama3_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {
+        rope_scaling={
             "rope_type": "llama3",
             "factor": 2.0,
             "low_freq_factor": 4,
             "high_freq_factor": 32,
             "original_max_position_embeddings": 2048,
-        })
+        },
+    )
     weights = Mock(device=torch.device("cpu"))

     result = PositionRotaryEmbedding.static(
@@ -158,8 +152,10 @@ def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
     assert result.scaling_factor == 2.0
     assert result.max_position_embeddings == 4096

+
 # Test the application of the rotary embedding

+
 def position_rotary_embedding_no_rope_config():
     head_dim = 64
     base = 10000
@@ -174,7 +170,7 @@ def position_rotary_embedding_no_rope_config():
     config = Mock(
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        rope_scaling=None
+        rope_scaling=None,
     )

     # create PositionRotaryEmbedding instance
@@ -232,7 +228,7 @@ def position_rotary_embedding_with_dynamic_scaling():
     config = Mock(
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        rope_scaling={"type": "dynamic", "factor": 1.0}
+        rope_scaling={"type": "dynamic", "factor": 1.0},
     )

     # create PositionRotaryEmbedding instance
@@ -275,7 +271,9 @@ def position_rotary_embedding_with_dynamic_scaling():
     assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
     assert not torch.allclose(k_rotated, key), "key should be modified by rotation"

+
 if SYSTEM == "cuda":
+
     def test_position_rotary_embedding_with_dynamic_scaling():
         position_rotary_embedding_no_rope_config()
