248 lines
7.6 KiB
Python
248 lines
7.6 KiB
Python
import pytest
|
|
from unittest.mock import Mock
|
|
from text_generation_server.utils.adapter import (
|
|
get_attn_weights,
|
|
get_mlp_weights,
|
|
parse_lora_adapters,
|
|
AdapterInfo,
|
|
)
|
|
|
|
|
|
def test_parse_lora_adapters_empty():
    """Both ``None`` and the empty string are expected to parse to ``[]``."""
    for empty_spec in (None, ""):
        assert parse_lora_adapters(empty_spec) == []
|
|
|
|
|
|
def test_parse_lora_adapters_single():
    """A bare adapter id is expected to yield one entry with no path/revision."""
    expected = [AdapterInfo(id="adapter1", path=None, revision=None)]
    assert parse_lora_adapters("adapter1") == expected
|
|
|
|
|
|
def test_parse_lora_adapters_with_path():
    """An ``id=path`` entry is expected to populate ``path`` and leave ``revision`` unset."""
    expected_adapter = AdapterInfo(
        id="adapter1", path="path/to/adapter1", revision=None
    )
    parsed = parse_lora_adapters("adapter1=path/to/adapter1")
    assert parsed == [expected_adapter]
|
|
|
|
|
|
def test_parse_lora_adapters_with_path_and_revision():
    """An ``id=path@revision`` entry is expected to populate all three fields."""
    expected_adapter = AdapterInfo(
        id="adapter1", path="path/to/adapter1", revision="main"
    )
    parsed = parse_lora_adapters("adapter1=path/to/adapter1@main")
    assert parsed == [expected_adapter]
|
|
|
|
|
|
def test_parse_lora_adapters_multiple():
    """A comma-separated spec is expected to yield one entry per item, in order."""
    # Mixes the three entry shapes: bare id, id=path, and id=path@revision.
    spec = "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
    expected = [
        AdapterInfo(id="adapter1", path=None, revision=None),
        AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
        AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
    ]
    assert parse_lora_adapters(spec) == expected
|
|
|
|
|
|
def test_parse_lora_adapters_invalid_format():
    """The malformed entry ``invalid=format=test`` must raise ``ValueError``.

    The error message is expected to name exactly the offending entry, not
    the whole comma-separated spec.
    """
    # pytest.raises fails the test by itself when no ValueError is raised,
    # replacing the manual try / ``assert False`` / except pattern.
    with pytest.raises(ValueError) as excinfo:
        parse_lora_adapters("adapter1,invalid=format=test,adapter3")

    assert str(excinfo.value) == "Invalid LoRA adapter format: invalid=format=test"
|
|
|
|
|
|
def test_get_attn_weights():
    """Check the attention weight map returned for layer index 2."""
    # Mock layer exposing a query_key_value module and an o_proj module.
    layer = Mock()
    layer.self_attn.query_key_value = Mock()
    layer.self_attn.o_proj = Mock()

    weights = get_attn_weights(2, layer)

    # q/k/qkv/v entries are all expected to point at the same
    # query_key_value module; o_proj points at its own module.
    qkv_module = layer.self_attn.query_key_value
    expected = {
        (2, proj): (f"model.layers.2.self_attn.{proj}", qkv_module)
        for proj in ("q_proj", "k_proj", "qkv_proj", "v_proj")
    }
    expected[(2, "o_proj")] = (
        "model.layers.2.self_attn.o_proj",
        layer.self_attn.o_proj,
    )
    assert weights == expected
|
|
|
|
|
|
def test_get_mlp_weights_with_gate_up_proj():
    """Check the MLP weight map for layer 3 when ``gate_up_proj`` is present."""
    # Mock layer whose MLP exposes gate_up_proj and down_proj modules.
    layer = Mock()
    layer.mlp.gate_up_proj = Mock()
    layer.mlp.down_proj = Mock()

    weights = get_mlp_weights(3, layer)

    # gate_proj and up_proj are both expected to map to the gate_up_proj module.
    gate_up_module = layer.mlp.gate_up_proj
    down_module = layer.mlp.down_proj
    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", gate_up_module),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", gate_up_module),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", down_module),
    }
    assert weights == expected
|
|
|
|
|
|
def test_get_mlp_weights_without_gate_up_proj():
    """An MLP exposing no projection attributes is expected to yield ``{}``."""
    layer = Mock()
    # spec=[] restricts the mock to an empty attribute list, so it does not
    # auto-create gate_up_proj (or any other attribute) on access.
    layer.mlp = Mock(spec=[])

    assert get_mlp_weights(1, layer) == {}
|
|
|
|
|
|
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_attn_weights_different_layers(layer_index):
    """Attention weight keys and names are expected to embed the layer index."""
    layer = Mock()
    layer.self_attn.query_key_value = Mock()
    layer.self_attn.o_proj = Mock()

    weights = get_attn_weights(layer_index, layer)

    # Only the key's presence and the name (element 0 of the value tuple) are
    # checked here; q, k, v first, then o, matching the original check order.
    for short_name in ("q", "k", "v", "o"):
        key = (layer_index, f"{short_name}_proj")
        assert key in weights
        expected_name = f"model.layers.{layer_index}.self_attn.{short_name}_proj"
        assert weights[key][0] == expected_name
|
|
|
|
|
|
@pytest.mark.parametrize("layer_index", [0, 1, 5])
def test_get_mlp_weights_different_layers(layer_index):
    """MLP weight keys and names are expected to embed the layer index."""
    layer = Mock()
    layer.mlp.gate_up_proj = Mock()
    layer.mlp.down_proj = Mock()

    weights = get_mlp_weights(layer_index, layer)

    # Only the key's presence and the name (element 0 of the value tuple)
    # are checked for each projection.
    for short_name in ("gate", "up", "down"):
        key = (layer_index, f"{short_name}_proj")
        assert key in weights
        expected_name = f"model.layers.{layer_index}.mlp.{short_name}_proj"
        assert weights[key][0] == expected_name
|
|
|
|
|
|
def test_get_attn_weights_llama_compatibility():
    """Llama-compatibility check (per the test name) of the attention map for layer 2."""
    layer = Mock()
    layer.self_attn.query_key_value = Mock()
    layer.self_attn.o_proj = Mock()

    weights = get_attn_weights(2, layer)

    qkv_module = layer.self_attn.query_key_value
    out_module = layer.self_attn.o_proj
    # Every q/k/qkv/v entry is expected to reference the query_key_value module.
    expected = {
        (2, "q_proj"): ("model.layers.2.self_attn.q_proj", qkv_module),
        (2, "k_proj"): ("model.layers.2.self_attn.k_proj", qkv_module),
        (2, "qkv_proj"): ("model.layers.2.self_attn.qkv_proj", qkv_module),
        (2, "v_proj"): ("model.layers.2.self_attn.v_proj", qkv_module),
        (2, "o_proj"): ("model.layers.2.self_attn.o_proj", out_module),
    }
    assert weights == expected
|
|
|
|
|
|
def test_get_mlp_weights_llama_compatibility():
    """Llama-compatibility check (per the test name) of the MLP map for layer 3."""
    layer = Mock()
    layer.mlp.gate_up_proj = Mock()
    layer.mlp.down_proj = Mock()

    weights = get_mlp_weights(3, layer)

    # gate_proj and up_proj are expected to share the gate_up_proj module.
    expected = dict(
        [
            ((3, "gate_proj"), ("model.layers.3.mlp.gate_proj", layer.mlp.gate_up_proj)),
            ((3, "up_proj"), ("model.layers.3.mlp.up_proj", layer.mlp.gate_up_proj)),
            ((3, "down_proj"), ("model.layers.3.mlp.down_proj", layer.mlp.down_proj)),
        ]
    )
    assert weights == expected
|
|
|
|
|
|
def test_get_attn_weights_gemma_compatibility():
    """Gemma-compatibility check (per the test name) of the attention map for layer 2."""
    layer = Mock()
    layer.self_attn.query_key_value = Mock()
    layer.self_attn.o_proj = Mock()

    weights = get_attn_weights(2, layer)

    attn = layer.self_attn
    # Build the expectation incrementally: the four query/key/value style
    # entries first, then the output projection.
    expected = {}
    expected[(2, "q_proj")] = ("model.layers.2.self_attn.q_proj", attn.query_key_value)
    expected[(2, "k_proj")] = ("model.layers.2.self_attn.k_proj", attn.query_key_value)
    expected[(2, "qkv_proj")] = (
        "model.layers.2.self_attn.qkv_proj",
        attn.query_key_value,
    )
    expected[(2, "v_proj")] = ("model.layers.2.self_attn.v_proj", attn.query_key_value)
    expected[(2, "o_proj")] = ("model.layers.2.self_attn.o_proj", attn.o_proj)

    assert weights == expected
|
|
|
|
|
|
def test_get_mlp_weights_gemma_compatibility():
    """Gemma-compatibility check (per the test name): separate gate/up projections."""
    layer = Mock()
    layer.mlp.gate_proj = Mock()
    layer.mlp.up_proj = Mock()
    layer.mlp.down_proj = Mock()

    # A plain Mock auto-creates any attribute on first access, so without this
    # deletion `layer.mlp.gate_up_proj` would appear to exist and
    # `get_mlp_weights` might take the gate_up_proj path instead of the
    # separate gate_proj / up_proj path this test is meant to exercise.
    del layer.mlp.gate_up_proj

    weights = get_mlp_weights(3, layer)

    mlp = layer.mlp
    # Each projection is expected to map to its own, distinct module.
    expected = {
        (3, "gate_proj"): ("model.layers.3.mlp.gate_proj", mlp.gate_proj),
        (3, "up_proj"): ("model.layers.3.mlp.up_proj", mlp.up_proj),
        (3, "down_proj"): ("model.layers.3.mlp.down_proj", mlp.down_proj),
    }
    assert weights == expected
|