feat: add test for positional rotary embeddings
parent 8ad20daf33
commit 040c5b5970

@@ -0,0 +1,275 @@

import pytest
import torch
from unittest.mock import Mock, patch

from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
    DynamicPositionRotaryEmbedding,
    YarnPositionRotaryEmbedding,
)


def test_position_rotary_embedding_static_basic():
    config = Mock(
        rope_theta=10000,
        max_position_embeddings=2048,
        rope_scaling=None,
    )
    weights = Mock(device=torch.device("cpu"))

    result = PositionRotaryEmbedding.static(
        config=config,
        dim=64,
        base=config.rope_theta,
        device=weights.device,
    )

    assert isinstance(result, PositionRotaryEmbedding)
    assert result.inv_freq.shape == (32,)  # dim // 2
    assert result.scaling_factor is None


def test_position_rotary_embedding_static_linear_scaling():
    config = Mock(
        rope_theta=10000,
        max_position_embeddings=2048,
    )
    # scaling is not applied when the type is "linear", so no scaling factor is
    # expected (TODO: maybe revisit this)
    config.rope_scaling = {"type": "linear", "factor": 2.0}
    weights = Mock(device=torch.device("cpu"))

    result = PositionRotaryEmbedding.static(
        config=config,
        dim=64,
        base=config.rope_theta,
        device=weights.device,
    )

    assert isinstance(result, PositionRotaryEmbedding)
    assert result.scaling_factor is None


def test_position_rotary_embedding_static_dynamic_scaling():
    config = Mock(
        rope_theta=10000,
        max_position_embeddings=2048,
        rope_scaling={"type": "dynamic", "factor": 2.0},
    )
    weights = Mock(device=torch.device("cpu"))

    result = PositionRotaryEmbedding.static(
        config=config,
        dim=64,
        base=config.rope_theta,
        device=weights.device,
    )

    assert isinstance(result, DynamicPositionRotaryEmbedding)
    assert result.scaling_factor == 2.0
    assert result.max_position_embeddings == 2048


def test_position_rotary_embedding_static_yarn_scaling():
    config = Mock(
        rope_theta=10000,
        max_position_embeddings=2048,
        rope_scaling={
            "type": "yarn",
            "factor": 1.5,
            "original_max_position_embeddings": 2048,
        },
    )
    weights = Mock(device=torch.device("cpu"))

    result = PositionRotaryEmbedding.static(
        config=config,
        dim=64,
        base=config.rope_theta,
        device=weights.device,
    )

    assert isinstance(result, YarnPositionRotaryEmbedding)
    assert result.scaling_factor == 1.5
    assert result.max_position_embeddings == 2048


def test_position_rotary_embedding_static_invalid_scaling():
    config = Mock(
        rope_theta=10000,
        max_position_embeddings=2048,
        rope_scaling={"type": "invalid", "factor": 2.0},
    )
    weights = Mock(device=torch.device("cpu"))

    with pytest.raises(NotImplementedError):
        PositionRotaryEmbedding.static(
            config=config,
            dim=64,
            base=config.rope_theta,
            device=weights.device,
        )


def test_position_rotary_embedding_static_llama3_scaling():
    config = Mock(
        rope_theta=10000,
        max_position_embeddings=2048,
        rope_scaling={
            "rope_type": "llama3",
            "factor": 2.0,
            "low_freq_factor": 4,
            "high_freq_factor": 32,
            "original_max_position_embeddings": 2048,
        },
    )
    weights = Mock(device=torch.device("cpu"))

    result = PositionRotaryEmbedding.static(
        config=config,
        dim=64,
        base=config.rope_theta,
        device=weights.device,
    )

    assert isinstance(result, PositionRotaryEmbedding)
    assert result.scaling_factor is None


def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
    config = Mock(
        rope_theta=10000,
        max_position_embeddings=4096,
        rope_scaling=None,
    )
    weights = Mock(device=torch.device("cpu"))

    with patch(
        "text_generation_server.layers.rotary._get_rope_config"
    ) as mock_get_rope_config:
        mock_get_rope_config.return_value = {"type": "dynamic", "factor": 2.0}

        result = PositionRotaryEmbedding.static(
            config=config,
            dim=64,
            base=config.rope_theta,
            device=weights.device,
        )

        assert isinstance(result, DynamicPositionRotaryEmbedding)
        assert result.scaling_factor == 2.0
        assert result.max_position_embeddings == 4096


# Test the application of the rotary embedding

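# For reference only (not used by the tests below): a minimal sketch of the
# standard "rotate_half" RoPE formulation that the application tests exercise
# conceptually. The non-interleaved halved layout and the assumption that
# cos/sin broadcast over the full head dimension are illustrative assumptions;
# the actual kernel behind PositionRotaryEmbedding and the cos/sin layout
# returned by get_cos_sin may differ.
def _reference_rotate_half_rope(x, cos, sin):
    # split the head dimension into two halves and rotate them into each other
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    # elementwise combination of the original and rotated components
    return x * cos + rotated * sin

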
def test_position_rotary_embedding_no_rope_config():
    head_dim = 64
    base = 10000
    max_position_embeddings = 2048
    num_heads = 16
    batch_size = 2
    seq_len = 128

    # the application tests assume a CUDA device
    device = "cuda"
    dtype = torch.float16

    config = Mock(
        rope_theta=base,
        max_position_embeddings=max_position_embeddings,
        rope_scaling=None,
    )

    # create PositionRotaryEmbedding instance
    rotary_emb = PositionRotaryEmbedding.static(
        config=config, dim=head_dim, base=base, device=device
    )

    # generate position IDs
    position_ids = torch.arange(seq_len).unsqueeze(0)
    position_ids = position_ids.to(device).to(torch.int32).view(-1)

    # get cos and sin values for the position IDs
    cos, sin = rotary_emb.get_cos_sin(
        position_ids=position_ids,
        max_s=seq_len,
        dtype=dtype,
    )

    # create query and key tensors
    query = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
    key = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)

    # clone to compare against after rotation
    original_query = query.clone()
    original_key = key.clone()

    # apply rotary embedding (modifies query and key in place)
    rotary_emb(query, key, cos, sin)

    assert (
        query.shape == original_query.shape
    ), "query shape should not change after rotation"
    assert key.shape == original_key.shape, "key shape should not change after rotation"
    assert not torch.allclose(query, original_query), "query should be modified by rotation"
    assert not torch.allclose(key, original_key), "key should be modified by rotation"


def test_position_rotary_embedding_with_dynamic_scaling():
    head_dim = 64
    base = 10000
    max_position_embeddings = 2048
    num_heads = 16
    batch_size = 2
    seq_len = 128

    # the application tests assume a CUDA device
    device = "cuda"
    dtype = torch.float16

    config = Mock(
        rope_theta=base,
        max_position_embeddings=max_position_embeddings,
        rope_scaling={"type": "dynamic", "factor": 1.0},
    )

    # create PositionRotaryEmbedding instance (dynamic scaling variant)
    rotary_emb = PositionRotaryEmbedding.static(
        config=config, dim=head_dim, base=base, device=device
    )

    # generate position IDs
    position_ids = torch.arange(seq_len).unsqueeze(0)
    position_ids = position_ids.to(device).to(torch.int32).view(-1)

    # get cos and sin values for the position IDs
    cos, sin = rotary_emb.get_cos_sin(
        position_ids=position_ids,
        max_s=seq_len,
        dtype=dtype,
    )

    # create query and key tensors
    query = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
    key = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)

    # clone to compare against after rotation
    original_query = query.clone()
    original_key = key.clone()

    # apply rotary embedding (modifies query and key in place)
    rotary_emb(query, key, cos, sin)

    assert (
        query.shape == original_query.shape
    ), "query shape should not change after rotation"
    assert key.shape == original_key.shape, "key shape should not change after rotation"
    assert not torch.allclose(query, original_query), "query should be modified by rotation"
    assert not torch.allclose(key, original_key), "key should be modified by rotation"
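

# A possible follow-up, not part of the original tests: the two application
# tests above hardcode device = "cuda" and will fail on CPU-only machines.
# One hedged option is a skipif marker like the sketch below (the name
# `requires_cuda` is hypothetical), applied to those tests.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="test requires a CUDA device"
)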