MoE Marlin: support `desc_act` for `groupsize != -1` (#2590)

This change uses the updated Marlin MoE kernel from vLLM to support
MoE with activation reordering (`desc_act`) combined with group sizes other than -1.
This commit is contained in:
Daniël de Kok 2024-09-30 19:40:25 +02:00 committed by GitHub
parent d1f257ac56
commit 1c84a30fe6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 6 additions and 19 deletions

View File

@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1727687740, "lastModified": 1727710820,
"narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=", "narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=",
"owner": "danieldk", "owner": "danieldk",
"repo": "tgi-nix", "repo": "tgi-nix",
"rev": "5e884ba50c26a7c93337bc0876f69da961c10374", "rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "danieldk", "owner": "danieldk",
"ref": "moe-kernels-0.5.0",
"repo": "tgi-nix", "repo": "tgi-nix",
"type": "github" "type": "github"
} }

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
}; };
nix-filter.url = "github:numtide/nix-filter"; nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:danieldk/tgi-nix"; tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
nixpkgs.follows = "tgi-nix/nixpkgs"; nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
rust-overlay = { rust-overlay = {

View File

@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
prefix: str, prefix: str,
block_sizes: Union[int, List[int]], block_sizes: Union[int, List[int]],
): ):
try: try:
qweight = weights.get_packed_sharded( qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@ -352,7 +351,7 @@ def repack_gptq_for_marlin(
scales = permute_scales(scales) scales = permute_scales(scales)
is_full_k = not (desc_act and sharded_infeatures) is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
return GPTQMarlinWeight( return GPTQMarlinWeight(
qweight=repacked, qweight=repacked,

View File

@ -249,12 +249,9 @@ class SparseMoELayer(nn.Module):
or ( or (
isinstance(weights.loader, GPTQMarlinWeightsLoader) isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm( and can_use_marlin_moe_gemm(
desc_act=weights.loader.desc_act,
groupsize=weights.loader.groupsize,
quant_method=weights.loader.quant_method, quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize, quantize=weights.loader.quantize,
sym=weights.loader.sym, sym=weights.loader.sym,
use_tp=weights.process_group.size() > 1,
) )
) )
) )

View File

@ -26,12 +26,9 @@ except Exception:
def can_use_marlin_moe_gemm( def can_use_marlin_moe_gemm(
*, *,
desc_act: bool,
groupsize: int,
quant_method: str, quant_method: str,
quantize: str, quantize: str,
sym: bool, sym: bool,
use_tp: bool,
): ):
return ( return (
SYSTEM == "cuda" SYSTEM == "cuda"
@ -40,16 +37,9 @@ def can_use_marlin_moe_gemm(
and quantize == "gptq" and quantize == "gptq"
and quant_method == "gptq" and quant_method == "gptq"
and sym and sym
and is_full_k(desc_act, groupsize, use_tp)
) )
def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
if groupsize == -1:
return True
return not (desc_act and use_tp)
@dataclass @dataclass
class GPTQMarlinMoEWeight: class GPTQMarlinMoEWeight:
qweight: torch.Tensor qweight: torch.Tensor