MoE Marlin: support `desc_act` for `groupsize != -1` (#2590)

This change uses the updated Marlin MoE kernel from vLLM to support
MoE with activation sorting and groups.
This commit is contained in:
Daniël de Kok 2024-09-30 19:40:25 +02:00 committed by GitHub
parent d1f257ac56
commit 1c84a30fe6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 6 additions and 19 deletions

View File

@@ -978,15 +978,16 @@
 "nixpkgs": "nixpkgs_6"
 },
 "locked": {
-"lastModified": 1727687740,
-"narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=",
+"lastModified": 1727710820,
+"narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=",
 "owner": "danieldk",
 "repo": "tgi-nix",
-"rev": "5e884ba50c26a7c93337bc0876f69da961c10374",
+"rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2",
 "type": "github"
 },
 "original": {
 "owner": "danieldk",
+"ref": "moe-kernels-0.5.0",
 "repo": "tgi-nix",
 "type": "github"
 }

View File

@@ -5,7 +5,7 @@
 inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
 };
 nix-filter.url = "github:numtide/nix-filter";
-tgi-nix.url = "github:danieldk/tgi-nix";
+tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
 nixpkgs.follows = "tgi-nix/nixpkgs";
 flake-utils.url = "github:numtide/flake-utils";
 rust-overlay = {

View File

@@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -352,7 +351,7 @@ def repack_gptq_for_marlin(
 scales = permute_scales(scales)
-is_full_k = not (desc_act and sharded_infeatures)
+is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
 return GPTQMarlinWeight(
 qweight=repacked,

View File

@@ -249,12 +249,9 @@ class SparseMoELayer(nn.Module):
 or (
 isinstance(weights.loader, GPTQMarlinWeightsLoader)
 and can_use_marlin_moe_gemm(
-desc_act=weights.loader.desc_act,
-groupsize=weights.loader.groupsize,
 quant_method=weights.loader.quant_method,
 quantize=weights.loader.quantize,
 sym=weights.loader.sym,
-use_tp=weights.process_group.size() > 1,
 )
 )
 )

View File

@@ -26,12 +26,9 @@ except Exception:
 def can_use_marlin_moe_gemm(
 *,
-desc_act: bool,
-groupsize: int,
 quant_method: str,
 quantize: str,
 sym: bool,
-use_tp: bool,
 ):
 return (
 SYSTEM == "cuda"
@@ -40,16 +37,9 @@ def can_use_marlin_moe_gemm(
 and quantize == "gptq"
 and quant_method == "gptq"
 and sym
-and is_full_k(desc_act, groupsize, use_tp)
 )
-def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
-if groupsize == -1:
-return True
-return not (desc_act and use_tp)
 @dataclass
 class GPTQMarlinMoEWeight:
 qweight: torch.Tensor