import pytest
import torch
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    Weights,
    WeightsLoader,
)
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
from text_generation_server.layers.marlin.marlin import (
    MarlinWeight,
    MarlinWeightsLoader,
)

from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path


# ---------------------------------------------------------------------------
# Loader fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def gptq_weights_loader():
    """GPTQWeightsLoader with quant_method="gptq" and quantize="gptq"."""
    return GPTQWeightsLoader(
        bits=4,
        groupsize=-1,
        desc_act=False,
        quant_method="gptq",
        quantize="gptq",
        sym=True,
    )


@pytest.fixture
def gptq_weights_loader_awq():
    """GPTQWeightsLoader with quant_method="awq" and quantize="awq"."""
    return GPTQWeightsLoader(
        bits=4,
        groupsize=-1,
        desc_act=False,
        quant_method="awq",
        quantize="awq",
        sym=True,
    )


@pytest.fixture
def marlin_weights_loader():
    """MarlinWeightsLoader with bits=4 and is_marlin_24=False."""
    return MarlinWeightsLoader(bits=4, is_marlin_24=False)


# ---------------------------------------------------------------------------
# In-memory "checkpoint files"
# ---------------------------------------------------------------------------


def _tensor_4x2(dtype):
    """Return a fresh ``[[1, 2], [3, 4], [5, 6], [7, 8]]`` tensor of ``dtype``."""
    return torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=dtype)


def _gptq_checkpoint(qweight_dtype=torch.int32):
    """Return fresh tensors for a GPTQ/AWQ-style checkpoint under prefix "weight".

    Args:
        qweight_dtype: dtype of ``weight.qweight`` (int32 unless overridden).
    """
    return {
        "weight.qweight": _tensor_4x2(qweight_dtype),
        "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        "weight.qzeros": torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        "weight.scales": torch.tensor(
            [[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16
        ),
        "gptq_bits": torch.tensor([8], dtype=torch.float32),
        "gptq_groupsize": torch.tensor([2], dtype=torch.float32),
    }


def _exl2_checkpoint():
    """Return fresh tensors for an exl2 checkpoint under prefix "weight"."""
    return {
        "weight.q_weight": _tensor_4x2(torch.int32),
        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
        "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
    }


def _marlin_checkpoint():
    """Return fresh tensors for a Marlin checkpoint under prefix "weight"."""
    return {
        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
    }


# Maps a fake "filename" to the tensors that file contains. Every entry is
# built from its own, freshly created tensors. The mock hands out the stored
# tensor objects themselves (no copy), so an in-place change made while
# loading one entry cannot leak into another entry.
# NOTE(review): the exl2 expectations suggest Exl2Weight rescales
# q_scale_max on construction (possibly in place) — confirm before sharing
# tensors between entries.
dummy_file_system = {
    "test_weights": {
        "layer.0.weight": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
    },
    "test_weights_2": {
        "layer.1337.weight": torch.tensor(
            [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float32
        ),
    },
    "test_get_weights_col_packed": {"weight.weight": _tensor_4x2(torch.float32)},
    "test_get_multi_weights_col": {"weight.weight": _tensor_4x2(torch.float32)},
    "test_get_weights_row": {"weight.weight": _tensor_4x2(torch.float32)},
    # Unlike the other GPTQ entries, qweight is float32 here; the tests that
    # read this entry compare against float expectations.
    "test_get_weights_col_gptq": _gptq_checkpoint(qweight_dtype=torch.float32),
    "test_get_weights_col_marlin": _marlin_checkpoint(),
    "test_get_weights_row_gptq": _gptq_checkpoint(),
    "test_get_multi_weights_col_gptq": _gptq_checkpoint(),
    "test_get_weights_col_packed_gptq": _gptq_checkpoint(),
    "test_get_weights_col_packed_exl2": _exl2_checkpoint(),
    "test_get_weights_row_exl2": _exl2_checkpoint(),
    "test_get_multi_weights_col_exl2": _exl2_checkpoint(),
    "test_get_weights_col_exl2": _exl2_checkpoint(),
    "test_get_weights_row_marlin": _marlin_checkpoint(),
    "test_get_multi_weights_col_marlin": _marlin_checkpoint(),
    "test_get_weights_col_packed_marlin": _marlin_checkpoint(),
}


# ---------------------------------------------------------------------------
# Mocks standing in for safetensors handles and the Weights class
# ---------------------------------------------------------------------------


class MockSlice:
    """Minimal slice object: exposes ``get_shape`` and indexing over a tensor."""

    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        return self.tensor.shape

    def __getitem__(self, idx):
        # Indexing a tensor may return a view that shares storage with it.
        return self.tensor[idx]


def mock_get_slice(tensor_name, filename):
    """Wrap ``dummy_file_system[filename][tensor_name]`` in a MockSlice."""
    tensor = dummy_file_system[filename][tensor_name]
    return MockSlice(tensor)


def mock_handle(filename, device, dtype):
    """Return a handle whose ``get_slice`` reads from ``dummy_file_system``.

    ``device`` and ``dtype`` are accepted for call compatibility but ignored.
    """
    return SimpleNamespace(
        get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename)
    )


class MockSafeOpen:
    """Context manager listing the tensor names of one ``dummy_fs`` entry."""

    def __init__(self, filename, framework, dummy_fs):
        self.filename = filename
        self.framework = framework
        self.dummy_fs = dummy_fs

    def keys(self):
        return list(self.dummy_fs[self.filename].keys())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class MockWeights(Weights):
    """``Weights`` subclass backed by an in-memory ``dummy_fs`` mapping.

    Deliberately does not call ``Weights.__init__``; it sets the attributes
    itself. It builds the tensor-name -> filename routing table from
    ``dummy_fs`` and raises RuntimeError if a tensor name appears in more
    than one file.
    """

    def __init__(
        self,
        filenames: List[Union[Path, str]],
        device,
        dtype,
        process_group,
        dummy_fs,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
        weights_loader: Optional[WeightsLoader] = None,
    ):
        routing = {}
        self.dummy_fs = dummy_fs
        for filename in filenames:
            with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self.weights_loader = (
            # We don't need to get linear layers, so just wrap raw tensors.
            DefaultWeightsLoader(lambda x: x)
            if weights_loader is None
            else weights_loader
        )
        self._handles = {}

    def _get_handle(self, filename: Union[Path, str]):
        """Return (and cache) the mock handle for ``filename``."""
        if filename not in self._handles:
            self._handles[filename] = mock_handle(filename, self.device, self.dtype)
        return self._handles[filename]

    def get_shape(self, tensor_name: str):
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
        # Returns the stored tensor object itself: no copy, no device/dtype cast.
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).tensor


# Single-rank process group: rank 0 of size 1, so nothing is sharded.
dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1)


# ---------------------------------------------------------------------------
# Shared test helpers
# ---------------------------------------------------------------------------


def _mock_weights(*filenames, dtype=torch.float32, weights_loader=None):
    """Build a CPU MockWeights over ``filenames`` from ``dummy_file_system``."""
    return MockWeights(
        list(filenames),
        device="cpu",
        dtype=dtype,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
        weights_loader=weights_loader,
    )


def _expected_gptq_weight(*, awq, qweight_dtype=torch.int32):
    """Expected GPTQWeight for the ``_gptq_checkpoint`` data.

    Args:
        awq: True for the AWQ-configured loader, which yields ``g_idx=None``
            and ``use_awq_kernel=True``.
        qweight_dtype: dtype of the expected ``qweight``.
    """
    return GPTQWeight(
        qweight=_tensor_4x2(qweight_dtype),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=None if awq else torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        bits=8.0,
        groupsize=2.0,
        use_awq_kernel=awq,
        use_exllama=False,
    )


def _assert_gptq_weight_matches(w, expected):
    """Assert every field of GPTQWeight ``w`` matches ``expected``."""
    assert torch.allclose(w.qweight, expected.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected.scales), "scales mismatch"
    if expected.g_idx is None:
        assert w.g_idx is None, "g_idx mismatch"
    else:
        assert torch.allclose(w.g_idx, expected.g_idx), "g_idx mismatch"
    assert w.bits == expected.bits, "bits mismatch"
    assert w.groupsize == expected.groupsize, "groupsize mismatch"
    assert w.use_awq_kernel == expected.use_awq_kernel, "use_awq_kernel mismatch"
    assert w.use_exllama == expected.use_exllama, "use_exllama mismatch"


def _expected_exl2_weight():
    """Expected Exl2Weight for the ``_exl2_checkpoint`` data."""
    # The checkpoint stores q_scale_max=100. The expected value is built as
    # 0.3906 * 256, on the assumption that Exl2Weight divides by 256 on
    # construction for both the loaded and the expected weight.
    scaled_scale_max = 0.3906 * 256
    return Exl2Weight(
        q_weight=_tensor_4x2(torch.int32),
        q_scale=torch.tensor([8], dtype=torch.int32),
        q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),
        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
        q_groups=torch.tensor([4], dtype=torch.int16),
    )


def _assert_exl2_weight_matches(w, expected):
    """Assert every field of Exl2Weight ``w`` matches ``expected``."""
    assert torch.allclose(w.q_weight, expected.q_weight), "q_weight mismatch"
    assert torch.allclose(w.q_scale, expected.q_scale), "q_scale mismatch"
    assert torch.allclose(w.q_invperm, expected.q_invperm), "q_invperm mismatch"
    assert torch.allclose(
        w.q_scale_max, expected.q_scale_max
    ), "q_scale_max mismatch"
    assert torch.allclose(w.q_groups, expected.q_groups), "q_groups mismatch"


def _expected_marlin_weight():
    """Expected MarlinWeight for the ``_marlin_checkpoint`` data."""
    return MarlinWeight(
        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        s=torch.tensor([[0.5], [0.25]], dtype=torch.float16),
    )


def _assert_marlin_weight_matches(w, expected):
    """Assert both fields of MarlinWeight ``w`` match ``expected``."""
    assert torch.allclose(w.B, expected.B), "B mismatch"
    assert torch.allclose(w.s, expected.s), "s mismatch"


# ---------------------------------------------------------------------------
# Unquantized (DefaultWeightsLoader) tests
# ---------------------------------------------------------------------------


def test_weights():
    weights = _mock_weights("test_weights", "test_weights_2")
    assert weights.get_shape("layer.0.weight") == (2, 2)
    assert weights.get_tensor("layer.1337.weight").shape == (2, 4)


def test_get_tensor():
    weights = _mock_weights("test_weights", "test_weights_2")
    assert torch.allclose(
        weights.get_tensor("layer.0.weight"),
        torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
    )
    assert torch.allclose(
        weights.get_tensor("layer.1337.weight"),
        torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float32),
    )


def test_get_weights_col_packed():
    weights = _mock_weights("test_get_weights_col_packed")
    w = weights.get_weights_col_packed(prefix="weight", block_sizes=1)
    assert torch.allclose(w, _tensor_4x2(torch.float32))


def test_get_weights_col_packed_block_size():
    weights = _mock_weights("test_get_weights_col_packed")
    w = weights.get_weights_col_packed(prefix="weight", block_sizes=2)
    assert torch.allclose(w, _tensor_4x2(torch.float32))


def test_get_weights_col_packed_block_size_arr():
    weights = _mock_weights("test_get_weights_col_packed")
    w = weights.get_weights_col_packed(prefix="weight", block_sizes=[1, 1])
    assert torch.allclose(w, _tensor_4x2(torch.float32))


def test_get_multi_weights_col():
    weights = _mock_weights("test_get_multi_weights_col")
    w = weights.get_multi_weights_col(prefixes=["weight", "weight"], dim=0)
    # The same 4x2 weight requested twice, concatenated along dim 0.
    single = _tensor_4x2(torch.float32)
    assert torch.allclose(w, torch.cat([single, single], dim=0))


def test_get_weights_row():
    weights = _mock_weights("test_get_weights_row")
    w = weights.get_weights_row(prefix="weight")
    assert torch.allclose(w, _tensor_4x2(torch.float32))


# ---------------------------------------------------------------------------
# get_weights_col
# ---------------------------------------------------------------------------


def test_get_weights_col_awq(gptq_weights_loader_awq):
    weights = _mock_weights(
        "test_get_weights_col_gptq", weights_loader=gptq_weights_loader_awq
    )
    w = weights.get_weights_col(prefix="weight")
    _assert_gptq_weight_matches(
        w, _expected_gptq_weight(awq=True, qweight_dtype=torch.float32)
    )


# Name ("gtpq") kept unchanged so existing test IDs stay stable.
def test_get_weights_col_gtpq(gptq_weights_loader):
    weights = _mock_weights(
        "test_get_weights_col_gptq", weights_loader=gptq_weights_loader
    )
    w = weights.get_weights_col(prefix="weight")
    _assert_gptq_weight_matches(
        w, _expected_gptq_weight(awq=False, qweight_dtype=torch.float32)
    )


def test_get_weights_col_exl2():
    weights = _mock_weights(
        "test_get_weights_col_exl2", weights_loader=Exl2WeightsLoader()
    )
    w = weights.get_weights_col(prefix="weight")
    _assert_exl2_weight_matches(w, _expected_exl2_weight())


def test_get_weights_col_marlin(marlin_weights_loader):
    weights = _mock_weights(
        "test_get_weights_col_marlin",
        dtype=torch.float16,
        weights_loader=marlin_weights_loader,
    )
    w = weights.get_weights_col(prefix="weight")
    _assert_marlin_weight_matches(w, _expected_marlin_weight())


# ---------------------------------------------------------------------------
# get_weights_col_packed
# ---------------------------------------------------------------------------


def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
    weights = _mock_weights(
        "test_get_weights_col_packed_gptq", weights_loader=gptq_weights_loader_awq
    )
    w = weights.get_weights_col_packed(prefix="weight", block_sizes=1)
    _assert_gptq_weight_matches(w, _expected_gptq_weight(awq=True))


@pytest.mark.skip(reason="Review expected functionality")
def test_get_weights_col_packed_exl2():
    weights = _mock_weights(
        "test_get_weights_col_packed_exl2", weights_loader=Exl2WeightsLoader()
    )
    w = weights.get_weights_col_packed(prefix="weight", block_sizes=1)
    scaled_scale_max = 0.3906 * 256
    expected_weight = Exl2Weight(
        q_weight=_tensor_4x2(torch.int32),
        q_scale=torch.tensor([8], dtype=torch.int32),
        q_invperm=torch.tensor([1], dtype=torch.int16),
        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
        q_groups=torch.tensor([4], dtype=torch.int16),
    )
    _assert_exl2_weight_matches(w, expected_weight)


def test_get_weights_col_packed_gptq(gptq_weights_loader):
    weights = _mock_weights(
        "test_get_weights_col_packed_gptq", weights_loader=gptq_weights_loader
    )
    # NOTE(review): despite its name, this test exercises
    # get_multi_weights_col, not get_weights_col_packed — confirm intent.
    w = weights.get_multi_weights_col(prefixes=["weight"], dim=0)
    _assert_gptq_weight_matches(w, _expected_gptq_weight(awq=False))


def test_get_weights_col_packed_marlin(marlin_weights_loader):
    weights = _mock_weights(
        "test_get_weights_col_packed_marlin",
        dtype=torch.float16,
        weights_loader=marlin_weights_loader,
    )
    # NOTE(review): despite its name, this test exercises
    # get_multi_weights_col, not get_weights_col_packed — confirm intent.
    w = weights.get_multi_weights_col(prefixes=["weight"], dim=0)
    _assert_marlin_weight_matches(w, _expected_marlin_weight())


# ---------------------------------------------------------------------------
# get_multi_weights_col
# ---------------------------------------------------------------------------


def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
    weights = _mock_weights(
        "test_get_multi_weights_col_gptq", weights_loader=gptq_weights_loader_awq
    )
    w = weights.get_multi_weights_col(prefixes=["weight"], dim=0)
    _assert_gptq_weight_matches(w, _expected_gptq_weight(awq=True))


def test_get_multi_weights_col_exl2():
    weights = _mock_weights(
        "test_get_multi_weights_col_exl2", weights_loader=Exl2WeightsLoader()
    )
    # pytest.raises fails the test if no ValueError is raised; the previous
    # try/except silently passed in that case.
    with pytest.raises(ValueError) as excinfo:
        weights.get_multi_weights_col(prefixes=["weight"], dim=0)
    assert excinfo.value.args[0] == "get_multi_weights_col is not supported for exl2"


def test_get_multi_weights_col_gptq(gptq_weights_loader):
    weights = _mock_weights(
        "test_get_multi_weights_col_gptq", weights_loader=gptq_weights_loader
    )
    w = weights.get_multi_weights_col(prefixes=["weight"], dim=0)
    _assert_gptq_weight_matches(w, _expected_gptq_weight(awq=False))


def test_get_multi_weights_col_marlin(marlin_weights_loader):
    weights = _mock_weights(
        "test_get_multi_weights_col_marlin",
        dtype=torch.float16,
        weights_loader=marlin_weights_loader,
    )
    w = weights.get_multi_weights_col(prefixes=["weight"], dim=0)
    _assert_marlin_weight_matches(w, _expected_marlin_weight())


# ---------------------------------------------------------------------------
# get_weights_row
# ---------------------------------------------------------------------------


def test_get_weights_row_awq(gptq_weights_loader_awq):
    weights = _mock_weights(
        "test_get_weights_row_gptq", weights_loader=gptq_weights_loader_awq
    )
    w = weights.get_weights_row(prefix="weight")
    _assert_gptq_weight_matches(w, _expected_gptq_weight(awq=True))


def test_get_weights_row_exl2():
    weights = _mock_weights(
        "test_get_weights_row_exl2", weights_loader=Exl2WeightsLoader()
    )
    w = weights.get_weights_row(prefix="weight")
    _assert_exl2_weight_matches(w, _expected_exl2_weight())


def test_get_weights_row_gptq(gptq_weights_loader):
    weights = _mock_weights(
        "test_get_weights_row_gptq", weights_loader=gptq_weights_loader
    )
    w = weights.get_weights_row(prefix="weight")
    _assert_gptq_weight_matches(w, _expected_gptq_weight(awq=False))


def test_get_weights_row_marlin(marlin_weights_loader):
    weights = _mock_weights(
        "test_get_weights_row_marlin",
        dtype=torch.float16,
        weights_loader=marlin_weights_loader,
    )
    w = weights.get_weights_row(prefix="weight")
    _assert_marlin_weight_matches(w, _expected_marlin_weight())