MoE Marlin: support `desc_act` for `groupsize != -1` (#2590)
This change uses the updated Marlin MoE kernel from vLLM to support MoE with activation sorting and groups.
This commit is contained in:
parent
d1f257ac56
commit
1c84a30fe6
|
@ -978,15 +978,16 @@
|
|||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1727687740,
|
||||
"narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=",
|
||||
"lastModified": 1727710820,
|
||||
"narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=",
|
||||
"owner": "danieldk",
|
||||
"repo": "tgi-nix",
|
||||
"rev": "5e884ba50c26a7c93337bc0876f69da961c10374",
|
||||
"rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"ref": "moe-kernels-0.5.0",
|
||||
"repo": "tgi-nix",
|
||||
"type": "github"
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
};
|
||||
nix-filter.url = "github:numtide/nix-filter";
|
||||
tgi-nix.url = "github:danieldk/tgi-nix";
|
||||
tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
|
||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
rust-overlay = {
|
||||
|
|
|
@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
|
|||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
|
@ -352,7 +351,7 @@ def repack_gptq_for_marlin(
|
|||
|
||||
scales = permute_scales(scales)
|
||||
|
||||
is_full_k = not (desc_act and sharded_infeatures)
|
||||
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
|
||||
|
||||
return GPTQMarlinWeight(
|
||||
qweight=repacked,
|
||||
|
|
|
@ -249,12 +249,9 @@ class SparseMoELayer(nn.Module):
|
|||
or (
|
||||
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||
and can_use_marlin_moe_gemm(
|
||||
desc_act=weights.loader.desc_act,
|
||||
groupsize=weights.loader.groupsize,
|
||||
quant_method=weights.loader.quant_method,
|
||||
quantize=weights.loader.quantize,
|
||||
sym=weights.loader.sym,
|
||||
use_tp=weights.process_group.size() > 1,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
@ -26,12 +26,9 @@ except Exception:
|
|||
|
||||
def can_use_marlin_moe_gemm(
|
||||
*,
|
||||
desc_act: bool,
|
||||
groupsize: int,
|
||||
quant_method: str,
|
||||
quantize: str,
|
||||
sym: bool,
|
||||
use_tp: bool,
|
||||
):
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
|
@ -40,16 +37,9 @@ def can_use_marlin_moe_gemm(
|
|||
and quantize == "gptq"
|
||||
and quant_method == "gptq"
|
||||
and sym
|
||||
and is_full_k(desc_act, groupsize, use_tp)
|
||||
)
|
||||
|
||||
|
||||
def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
|
||||
if groupsize == -1:
|
||||
return True
|
||||
return not (desc_act and use_tp)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQMarlinMoEWeight:
|
||||
qweight: torch.Tensor
|
||||
|
|
Loading…
Reference in New Issue