diff --git a/flake.lock b/flake.lock index fd6f3a54..e6361fda 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1727687740, - "narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=", + "lastModified": 1727710820, + "narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "5e884ba50c26a7c93337bc0876f69da961c10374", + "rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2", "type": "github" }, "original": { "owner": "danieldk", + "ref": "moe-kernels-0.5.0", "repo": "tgi-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index 42fb3c6a..be19e908 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:danieldk/tgi-nix"; + tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index c7663b60..0a785d94 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): - try: qweight = weights.get_packed_sharded( f"{prefix}.qweight", dim=1, block_sizes=block_sizes @@ -352,7 +351,7 @@ def repack_gptq_for_marlin( scales = permute_scales(scales) - is_full_k = not (desc_act and sharded_infeatures) + is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures) return GPTQMarlinWeight( qweight=repacked, diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py index ca71ebab..2c46ca02 100644 --- a/server/text_generation_server/layers/moe/__init__.py +++ b/server/text_generation_server/layers/moe/__init__.py @@ -249,12 +249,9 @@ class SparseMoELayer(nn.Module): or ( isinstance(weights.loader, GPTQMarlinWeightsLoader) and can_use_marlin_moe_gemm( - desc_act=weights.loader.desc_act, - groupsize=weights.loader.groupsize, quant_method=weights.loader.quant_method, quantize=weights.loader.quantize, sym=weights.loader.sym, - use_tp=weights.process_group.size() > 1, ) ) ) diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py index 3fc06cb2..3217cdc2 100644 --- a/server/text_generation_server/layers/moe/gptq_marlin.py +++ b/server/text_generation_server/layers/moe/gptq_marlin.py @@ -26,12 +26,9 @@ except Exception: def can_use_marlin_moe_gemm( *, - desc_act: bool, - groupsize: int, quant_method: str, quantize: str, sym: bool, - use_tp: bool, ): return ( SYSTEM == "cuda" @@ -40,16 +37,9 @@ def can_use_marlin_moe_gemm( and quantize == "gptq" and quant_method == "gptq" and sym - and is_full_k(desc_act, groupsize, use_tp) ) -def is_full_k(desc_act: bool, groupsize: int, use_tp: bool): - if groupsize == -1: - return True - return not (desc_act and use_tp) - - @dataclass class GPTQMarlinMoEWeight: qweight: torch.Tensor