MoE Marlin: support `desc_act` for `groupsize != -1` (#2590)
This change uses the updated Marlin MoE kernel from vLLM to support MoE models that combine activation sorting (`desc_act`) with grouped quantization (`groupsize != -1`).
This commit is contained in:
parent
d1f257ac56
commit
1c84a30fe6
|
@ -978,15 +978,16 @@
|
||||||
"nixpkgs": "nixpkgs_6"
|
"nixpkgs": "nixpkgs_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1727687740,
|
"lastModified": 1727710820,
|
||||||
"narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=",
|
"narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=",
|
||||||
"owner": "danieldk",
|
"owner": "danieldk",
|
||||||
"repo": "tgi-nix",
|
"repo": "tgi-nix",
|
||||||
"rev": "5e884ba50c26a7c93337bc0876f69da961c10374",
|
"rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "danieldk",
|
"owner": "danieldk",
|
||||||
|
"ref": "moe-kernels-0.5.0",
|
||||||
"repo": "tgi-nix",
|
"repo": "tgi-nix",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
};
|
};
|
||||||
nix-filter.url = "github:numtide/nix-filter";
|
nix-filter.url = "github:numtide/nix-filter";
|
||||||
tgi-nix.url = "github:danieldk/tgi-nix";
|
tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
|
||||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
rust-overlay = {
|
rust-overlay = {
|
||||||
|
|
|
@ -109,7 +109,6 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||||
prefix: str,
|
prefix: str,
|
||||||
block_sizes: Union[int, List[int]],
|
block_sizes: Union[int, List[int]],
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = weights.get_packed_sharded(
|
qweight = weights.get_packed_sharded(
|
||||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||||
|
@ -352,7 +351,7 @@ def repack_gptq_for_marlin(
|
||||||
|
|
||||||
scales = permute_scales(scales)
|
scales = permute_scales(scales)
|
||||||
|
|
||||||
is_full_k = not (desc_act and sharded_infeatures)
|
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
|
||||||
|
|
||||||
return GPTQMarlinWeight(
|
return GPTQMarlinWeight(
|
||||||
qweight=repacked,
|
qweight=repacked,
|
||||||
|
|
|
@ -249,12 +249,9 @@ class SparseMoELayer(nn.Module):
|
||||||
or (
|
or (
|
||||||
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||||
and can_use_marlin_moe_gemm(
|
and can_use_marlin_moe_gemm(
|
||||||
desc_act=weights.loader.desc_act,
|
|
||||||
groupsize=weights.loader.groupsize,
|
|
||||||
quant_method=weights.loader.quant_method,
|
quant_method=weights.loader.quant_method,
|
||||||
quantize=weights.loader.quantize,
|
quantize=weights.loader.quantize,
|
||||||
sym=weights.loader.sym,
|
sym=weights.loader.sym,
|
||||||
use_tp=weights.process_group.size() > 1,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,12 +26,9 @@ except Exception:
|
||||||
|
|
||||||
def can_use_marlin_moe_gemm(
|
def can_use_marlin_moe_gemm(
|
||||||
*,
|
*,
|
||||||
desc_act: bool,
|
|
||||||
groupsize: int,
|
|
||||||
quant_method: str,
|
quant_method: str,
|
||||||
quantize: str,
|
quantize: str,
|
||||||
sym: bool,
|
sym: bool,
|
||||||
use_tp: bool,
|
|
||||||
):
|
):
|
||||||
return (
|
return (
|
||||||
SYSTEM == "cuda"
|
SYSTEM == "cuda"
|
||||||
|
@ -40,16 +37,9 @@ def can_use_marlin_moe_gemm(
|
||||||
and quantize == "gptq"
|
and quantize == "gptq"
|
||||||
and quant_method == "gptq"
|
and quant_method == "gptq"
|
||||||
and sym
|
and sym
|
||||||
and is_full_k(desc_act, groupsize, use_tp)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
|
|
||||||
if groupsize == -1:
|
|
||||||
return True
|
|
||||||
return not (desc_act and use_tp)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTQMarlinMoEWeight:
|
class GPTQMarlinMoEWeight:
|
||||||
qweight: torch.Tensor
|
qweight: torch.Tensor
|
||||||
|
|
Loading…
Reference in New Issue