[WIP] Adding GPTQ support for llama

parent 4f6d038c0b
commit 2c9e1171bc
@@ -1,4 +1,4 @@
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 mod env_runtime;
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Quantization {
+    Bitsandbytes,
+    Gptq,
+}
+
+impl std::fmt::Display for Quantization {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep in sync with `server`.
+        match self {
+            Quantization::Bitsandbytes => {
+                write!(f, "bitsandbytes")
+            }
+            Quantization::Gptq => {
+                write!(f, "gptq")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -46,10 +66,10 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Whether you want the model to be quantized or not. This will use bitsandbytes for
-    /// quantization on the fly.
-    #[clap(long, env)]
-    quantize: bool,
+    /// Whether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// quantization on the fly, or `gptq`.
+    #[clap(long, env, value_enum)]
+    quantize: Option<Quantization>,
 
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
@@ -218,7 +238,7 @@ enum ShardStatus {
 fn shard_manager(
     model_id: String,
     revision: Option<String>,
-    quantize: bool,
+    quantize: Option<Quantization>,
     uds_path: String,
     rank: usize,
     world_size: usize,
@@ -257,8 +277,9 @@ fn shard_manager(
         shard_argv.push("--sharded".to_string());
     }
 
-    if quantize {
-        shard_argv.push("--quantize".to_string())
+    if let Some(quantize) = quantize {
+        shard_argv.push("--quantize".to_string());
+        shard_argv.push(quantize.to_string())
     }
 
     // Model optional revision
@@ -330,6 +351,7 @@ fn shard_manager(
 
     // Start process
    tracing::info!("Starting shard {rank}");
+    tracing::info!("Command {}", shard_argv.join(" "));
     let mut p = match Popen::create(
         &shard_argv,
         PopenConfig {
@@ -747,7 +769,6 @@ fn spawn_webserver(
 ) -> Result<Popen, LauncherError> {
     // All shard started
     // Start webserver
-    tracing::info!("Starting Webserver");
     let mut argv = vec![
         "text-generation-router".to_string(),
         "--max-concurrent-requests".to_string(),
@@ -811,6 +832,9 @@ fn spawn_webserver(
         env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
+    tracing::info!("Starting Webserver");
+    tracing::info!("Command {}", argv.join(" "));
+    tracing::info!("Env {:?}", env);
     let mut webserver = match Popen::create(
         &argv,
         PopenConfig {
@@ -15,7 +15,7 @@ def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
-    quantize: bool = False,
+    quantize: Optional[str] = None,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
 import torch.distributed
 
@@ -33,6 +34,8 @@ import flash_attn_cuda
 import dropout_layer_norm
 
 from flash_attn.layers.rotary import RotaryEmbedding
+# from safetensors.torch import load_file
+from safetensors import safe_open
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -40,6 +43,12 @@ try:
 except ImportError as e:
     HAS_BITS_AND_BYTES = False
 
+HAS_GPTQ = True
+try:
+    from text_generation_server.quant.quant_linear import QuantLinear
+except ImportError as e:
+    HAS_GPTQ = False
+
 
 class LlamaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -102,10 +111,10 @@ class FastLinear(nn.Linear):
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
         self.quantized = False
-        self.bnb_linear = None
+        self.qlinear = None
 
-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -114,17 +123,114 @@ class FastLinear(nn.Linear):
                 )
 
             self.quantized = True
-            self.bnb_linear = Linear8bitLt(
+            self.qlinear = Linear8bitLt(
                 self.in_features,
                 self.out_features,
                 has_fp16_weights=False,
                 threshold=6.0,
                 bias=False,
             )
-            # Copy data to bnb_linear
-            self.bnb_linear.weight.data = self.weight.data
+            # Copy data to qlinear
+            self.qlinear.weight.data = self.weight.data
             if self.bias is not None:
-                self.bnb_linear.bias = nn.Parameter(self.bias)
+                self.qlinear.bias = nn.Parameter(self.bias)
 
             # Delete reference to data
             self.weight = None
             self.bias = None
+        elif quantize == "gptq":
+            if not HAS_GPTQ:
+                raise ImportError(
+                    "gptq is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install gptq`."
+                )
+            self.quantized = True
+            self.qlinear = QuantLinear(
+                bits=4,
+                groupsize=128,
+                infeatures=self.in_features,
+                outfeatures=self.out_features,
+                bias=bool(self.bias),
+            )
+            rank = int(os.getenv("RANK"))
+            world_size = int(os.getenv("WORLD_SIZE"))
+
+            def get_row_slice(f, name):
+                slice_ = f.get_slice(name)
+                size = slice_.get_shape()[0]
+                block_size = size // world_size
+                start = rank * block_size
+                stop = (rank + 1) * block_size
+                tensor = slice_[start:stop]
+                return tensor.contiguous()
+
+            def get_col_slice(f, name):
+                slice_ = f.get_slice(name)
+                size = slice_.get_shape()[1]
+                block_size = size // world_size
+                start = rank * block_size
+                stop = (rank + 1) * block_size
+                tensor = slice_[:, start:stop]
+                return tensor.contiguous()
+
+            if isinstance(self, TensorParallelRowLinear):
+                get_slice = get_row_slice
+            elif isinstance(self, TensorParallelColumnLinear):
+                get_slice = get_col_slice
+            elif isinstance(self, FastLinear):
+                def get_slice(f, name):
+                    return f.get_tensor(name)
+            else:
+                raise ValueError("Need a specific class of Linear (TensorParallel, or regular Linear)")
+
+            with safe_open("/home/ubuntu/src/GPTQ-for-LLaMa/oasst-4bit-128g.safetensors", framework="pt", device=f"cuda:{rank}") as f:
+                if name == 'self_attn.query_key_value':
+                    query_name = f'model.layers.{layer}.self_attn'
+                    self.qlinear.qweight[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.qweight")
+                    self.qlinear.qweight[:, self.out_features // 3 : -self.out_features // 3] = get_slice(f, f"{query_name}.k_proj.qweight")
+                    self.qlinear.qweight[:, -self.out_features // 3 :] = get_slice(f, f"{query_name}.v_proj.qweight")
+
+                    N = self.qlinear.qzeros.shape[1]
+                    self.qlinear.qzeros[:, : N // 3] = get_slice(f, f"{query_name}.q_proj.qzeros")
+                    self.qlinear.qzeros[:, N // 3 : -N // 3] = get_slice(f, f"{query_name}.k_proj.qzeros")
+                    self.qlinear.qzeros[:, -N // 3 :] = get_slice(f, f"{query_name}.v_proj.qzeros")
+
+                    self.qlinear.scales[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.scales")
+                    self.qlinear.scales[:, self.out_features // 3 : -self.out_features // 3] = get_slice(f, f"{query_name}.k_proj.scales")
+                    self.qlinear.scales[:, -self.out_features // 3 :] = get_slice(f, f"{query_name}.v_proj.scales")
+                    torch.testing.assert_close(f.get_tensor(f"{query_name}.q_proj.g_idx"), f.get_tensor(f"{query_name}.k_proj.g_idx"))
+                    torch.testing.assert_close(f.get_tensor(f"{query_name}.q_proj.g_idx"), f.get_tensor(f"{query_name}.v_proj.g_idx"))
+                    self.qlinear.g_idx[:] = f.get_tensor(f"{query_name}.q_proj.g_idx")
+
+                elif name == "self_attn.o_proj":
+                    self.qlinear.qweight[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.qweight")
+                    self.qlinear.qzeros[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.qzeros")
+                    self.qlinear.scales[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.scales")
+                    self.qlinear.g_idx[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.g_idx")
+
+                elif name == "mlp.gate_up_proj":
+                    N = self.qlinear.qweight.shape[1] // 2
+                    self.qlinear.qweight[:, :N] = get_slice(f, f"model.layers.{layer}.mlp.gate_proj.qweight")
+                    self.qlinear.qweight[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.qweight")
+
+                    self.qlinear.scales[:, :N] = get_slice(f, f"model.layers.{layer}.mlp.gate_proj.scales")
+                    self.qlinear.scales[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.scales")
+
+                    torch.testing.assert_close(f.get_tensor(f"model.layers.{layer}.mlp.gate_proj.g_idx"), f.get_tensor(f"model.layers.{layer}.mlp.up_proj.g_idx"))
+                    self.qlinear.g_idx[:] = f.get_tensor(f"model.layers.{layer}.mlp.gate_proj.g_idx")
+
+                    N = self.qlinear.qzeros.shape[1] // 2
+                    self.qlinear.qzeros[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.qzeros")
+                    self.qlinear.qzeros[:, :N] = get_slice(f, f"model.layers.{layer}.mlp.gate_proj.qzeros")
+
+                elif name == "mlp.down_proj":
+                    self.qlinear.qweight[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.qweight")
+                    self.qlinear.qzeros[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.qzeros")
+                    self.qlinear.scales[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.scales")
+                    self.qlinear.g_idx[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.g_idx")
+                else:
+                    raise ValueError("Not handled")
+
+            # Delete reference to data
+            self.weight = None
@@ -134,7 +240,7 @@ class FastLinear(nn.Linear):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
-            return self.bnb_linear(input)
+            return self.qlinear(input)
         else:
             if self.bias is not None:
                 return torch.addmm(self.bias, input, self.weight)
@@ -542,12 +648,12 @@ class FlashLlamaModel(torch.nn.Module):
     def post_load_weights(self, load_in_8bit: bool = False):
         if isinstance(self.embed_tokens, TensorParallelEmbedding):
             self.embed_tokens.add_null_idx()
-        for layer in self.layers:
+        for i, layer in enumerate(self.layers):
             layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
-            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
-            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
-            layer.mlp.down_proj.prepare_weights(load_in_8bit)
+            layer.self_attn.query_key_value.prepare_weights(i, "self_attn.query_key_value", load_in_8bit)
+            layer.self_attn.o_proj.prepare_weights(i, "self_attn.o_proj", load_in_8bit)
+            layer.mlp.gate_up_proj.prepare_weights(i, "mlp.gate_up_proj", load_in_8bit)
+            layer.mlp.down_proj.prepare_weights(i, "mlp.down_proj", load_in_8bit)
 
     def forward(
         self,
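Note on the loading scheme above: the server keeps q/k/v as a single fused `query_key_value` linear, so the three GPTQ tensors are written into column-wise thirds of the fused buffers. Below is a minimal sketch of that layout with hypothetical shapes (a 4096-wide model, 4-bit packing), not the actual loader:

import torch

# Fill a fused QKV buffer from three per-projection tensors, as the diff
# above does for qweight/qzeros/scales. Shapes are illustrative only.
out_features = 12288   # fused width = 3 * 4096 for a hypothetical 4096-d model
packed_rows = 512      # rows of the packed int32 qweight: in_features // 8 for 4-bit

fused = torch.zeros(packed_rows, out_features, dtype=torch.int32)
q = torch.ones(packed_rows, out_features // 3, dtype=torch.int32)
k = torch.full((packed_rows, out_features // 3), 2, dtype=torch.int32)
v = torch.full((packed_rows, out_features // 3), 3, dtype=torch.int32)

fused[:, : out_features // 3] = q                      # first third  <- q_proj
fused[:, out_features // 3 : -out_features // 3] = k   # middle third <- k_proj
fused[:, -out_features // 3 :] = v                     # last third   <- v_proj

assert (fused[:, :4096] == 1).all() and (fused[:, -4096:] == 3).all()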
@@ -28,7 +28,7 @@ tracer = trace.get_tracer(__name__)
 
 
 class FlashLlama(FlashCausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=None):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -154,7 +154,7 @@ class FlashLlama(FlashCausalLM):
 
 class FlashLlamaSharded(FlashLlama):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None
     ):
         self.past_pad = None
         self.process_group, rank, world_size = initialize_torch_distributed()
@@ -177,13 +177,13 @@ class FlashLlamaSharded(FlashLlama):
             revision=revision,
         )
 
-        torch.distributed.barrier(group=self.process_group)
+        # torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config, process_group=self.process_group)
 
-        torch.distributed.barrier(group=self.process_group)
+        # torch.distributed.barrier(group=self.process_group)
         self.load_weights(
             model,
             filenames,
@@ -0,0 +1,4 @@
+from .quantizer import Quantizer
+from .fused_attn import QuantLlamaAttention, make_quant_attn
+from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused
+from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear
@@ -0,0 +1,193 @@
+# https://github.com/fpgaminer/GPTQ-triton
+"""
+Mostly the same as the autotuner in Triton, but with a few changes, like using 40 runs instead of 100.
+"""
+
+import builtins
+import math
+import time
+from typing import Dict
+
+import triton
+
+
+class Autotuner(triton.KernelInterface):
+
+    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
+        '''
+        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+            'perf_model': performance model used to predicate running time with different configs, returns running time
+            'top_k': number of configs to bench
+            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
+            'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
+        '''
+        if not configs:
+            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
+        else:
+            self.configs = configs
+        self.key_idx = [arg_names.index(k) for k in key]
+        self.nearest_power_of_two = nearest_power_of_two
+        self.cache = {}
+        # hook to reset all required tensors to zero before relaunching a kernel
+        self.hook = lambda args: 0
+        if reset_to_zero is not None:
+            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
+            def _hook(args):
+                for i in self.reset_idx:
+                    args[i].zero_()
+
+            self.hook = _hook
+        self.arg_names = arg_names
+        # prune configs
+        if prune_configs_by:
+            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
+            early_config_prune = prune_configs_by.get('early_config_prune')
+        else:
+            perf_model, top_k, early_config_prune = None, None, None
+        self.perf_model, self.configs_top_k = perf_model, top_k
+        self.early_config_prune = early_config_prune
+        self.fn = fn
+
+    def _bench(self, *args, config, **meta):
+        # check for conflicts, i.e. meta-parameters both provided
+        # as kwargs and by the autotuner
+        conflicts = meta.keys() & config.kwargs.keys()
+        if conflicts:
+            raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
+                             " Make sure that you don't re-define auto-tuned symbols.")
+        # augment meta-parameters with tunable ones
+        current = dict(meta, **config.kwargs)
+
+        def kernel_call():
+            if config.pre_hook:
+                config.pre_hook(self.nargs)
+            self.hook(args)
+            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
+
+        try:
+            # In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
+            # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
+            return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
+        except triton.compiler.OutOfResources:
+            return (float('inf'), float('inf'), float('inf'))
+
+    def run(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
+        if len(self.configs) > 1:
+            key = tuple(args[i] for i in self.key_idx)
+
+            # This reduces the amount of autotuning by rounding the keys to the nearest power of two
+            # In my testing this gives decent results, and greatly reduces the amount of tuning required
+            if self.nearest_power_of_two:
+                key = tuple([2**int(math.log2(x) + 0.5) for x in key])
+
+            if key not in self.cache:
+                # prune configs
+                pruned_configs = self.prune_configs(kwargs)
+                bench_start = time.time()
+                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+                bench_end = time.time()
+                self.bench_time = bench_end - bench_start
+                self.cache[key] = builtins.min(timings, key=timings.get)
+                self.hook(args)
+                self.configs_timings = timings
+            config = self.cache[key]
+        else:
+            config = self.configs[0]
+        self.best_config = config
+        if config.pre_hook is not None:
+            config.pre_hook(self.nargs)
+        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
+
+    def prune_configs(self, kwargs):
+        pruned_configs = self.configs
+        if self.early_config_prune:
+            pruned_configs = self.early_config_prune(self.configs, self.nargs)
+        if self.perf_model:
+            top_k = self.configs_top_k
+            if isinstance(top_k, float) and top_k <= 1.0:
+                top_k = int(len(self.configs) * top_k)
+            if len(pruned_configs) > top_k:
+                est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+        return pruned_configs
+
+    def warmup(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
+        for config in self.prune_configs(kwargs):
+            self.fn.warmup(
+                *args,
+                num_warps=config.num_warps,
+                num_stages=config.num_stages,
+                **kwargs,
+                **config.kwargs,
+            )
+        self.nargs = None
+
+
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
+    """
+    Decorator for auto-tuning a :code:`triton.jit`'d function.
+    .. highlight:: python
+    .. code-block:: python
+        @triton.autotune(configs=[
+            triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
+            triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
+          ],
+          key=['x_size']  # the two above configs will be evaluated anytime
+                          # the value of x_size changes
+        )
+        @triton.jit
+        def kernel(x_ptr, x_size, **META):
+            BLOCK_SIZE = META['BLOCK_SIZE']
+    :note: When all the configurations are evaluated, the kernel will run multiple times.
+           This means that whatever value the kernel updates will be updated multiple times.
+           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+           resets the value of the provided tensor to `zero` before running any configuration.
+    :param configs: a list of :code:`triton.Config` objects
+    :type configs: list[triton.Config]
+    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+    :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predicate running time with different configs, returns running time
+        'top_k': number of configs to bench
+        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
+    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+    :type reset_to_zero: list[str]
+    """
+
+    def decorator(fn):
+        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
+
+    return decorator
+
+
+def matmul248_kernel_config_pruner(configs, nargs):
+    """
+    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
+    """
+    m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
+    n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
+    k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
+
+    used = set()
+    for config in configs:
+        block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
+        block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
+        block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
+        group_size_m = config.kwargs['GROUP_SIZE_M']
+
+        if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
+            continue
+
+        used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
+        yield triton.Config({
+            'BLOCK_SIZE_M': block_size_m,
+            'BLOCK_SIZE_N': block_size_n,
+            'BLOCK_SIZE_K': block_size_k,
+            'GROUP_SIZE_M': group_size_m
+        },
+                            num_stages=config.num_stages,
+                            num_warps=config.num_warps)
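A note on `nearest_power_of_two` above: the cache key is rounded in log2 space, so nearby problem sizes share one tuning result. A standalone check of the rounding expression used in `run` (pure Python, no Triton required):

import math

def round_key(x: int) -> int:
    # Same rounding as Autotuner.run: nearest power of two in log2 space.
    return 2 ** int(math.log2(x) + 0.5)

assert round_key(40) == 32       # log2(40) ~ 5.32 -> rounds down to 2**5
assert round_key(48) == 64       # log2(48) ~ 5.58 -> rounds up to 2**6
assert round_key(2048) == 2048   # exact powers of two are unchanged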
@@ -0,0 +1,123 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+from .quant_linear import *
+
+
+class QuantLlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        qkv_proj,
+        o_proj,
+        rotary_emb,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+
+        if (self.head_dim * num_heads) != self.hidden_size:
+            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                             f" and `num_heads`: {num_heads}).")
+        self.qkv_proj = qkv_proj
+        self.o_proj = o_proj
+        self.rotary_emb = rotary_emb
+
+    def _shape(self, tensor, seq_len, bsz):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.qkv_proj(hidden_states)
+        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # [bsz, nh, t, hd]
+
+        is_causal = past_key_value is None
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        if use_cache:
+            # Since qkv_proj is fused, query_states etc. will hold a reference to the original qkv_states tensor
+            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to work around this.
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        with torch.backends.cuda.sdp_kernel(enable_math=False):
+            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        # The fused SDPA path does not return attention weights.
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+def make_quant_attn(model):
+    """
+    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
+    """
+    for name, m in model.named_modules():
+        if not isinstance(m, LlamaAttention):
+            continue
+
+        q_proj = m.q_proj
+        k_proj = m.k_proj
+        v_proj = m.v_proj
+
+        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+        qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, q_proj.bias is not None)
+        qkv_layer.qweight = qweights
+        qkv_layer.qzeros = qzeros
+        qkv_layer.scales = scales
+        qkv_layer.g_idx = g_idx
+        qkv_layer.bias = bias
+
+        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)
+
+        if '.' in name:
+            parent_name = name.rsplit('.', 1)[0]
+            child_name = name[len(parent_name) + 1:]
+            parent = model.get_submodule(parent_name)
+        else:
+            parent_name = ''
+            parent = model
+            child_name = name
+
+        # print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
+
+        setattr(parent, child_name, attn)
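The fusion performed by `make_quant_attn` relies on a basic identity: concatenating the q/k/v weight matrices along the output dimension and splitting the projected activations afterwards is equivalent to three separate projections. A small self-contained check with plain float weights (illustrative sizes, not the packed GPTQ buffers):

import torch

hidden = 64
x = torch.randn(2, 5, hidden)                    # (batch, time, channel)
wq, wk, wv = (torch.randn(hidden, hidden) for _ in range(3))

fused = torch.cat([wq, wk, wv], dim=1)           # (hidden, 3 * hidden)
q, k, v = torch.split(x @ fused, hidden, dim=2)  # same split as in forward()

torch.testing.assert_close(q, x @ wq)
torch.testing.assert_close(k, x @ wk)
torch.testing.assert_close(v, x @ wv)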
@@ -0,0 +1,288 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.cuda.amp import custom_bwd, custom_fwd
+from transformers.models.llama.modeling_llama import LlamaMLP
+
+try:
+    import triton
+    import triton.language as tl
+    from . import custom_autotune
+
+    # code based on https://github.com/fpgaminer/GPTQ-triton
+    @custom_autotune.autotune(
+        configs=[
+            triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
+            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),  # 3090
+            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),  # 3090
+            triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),  # 3090
+            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),  # 3090
+            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),  # 3090
+        ],
+        key=['M', 'N', 'K'],
+        nearest_power_of_two=True,
+        prune_configs_by={
+            'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
+            'perf_model': None,
+            'top_k': None,
+        },
+    )
+    @triton.jit
+    def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,
+                               stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+        """
+        Computes: C = silu(A * B1) * (A * B2)
+        A is of shape (M, K) float16
+        B1 and B2 are of shape (K//8, N) int32
+        C is of shape (M, N) float16
+        scales is of shape (G, N) float16
+        zeros is of shape (G, N//8) int32
+        """
+        infearure_per_bits = 32 // bits
+
+        pid = tl.program_id(axis=0)
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+        a_mask = (offs_am[:, None] < M)
+        # b_ptrs is set up such that it repeats elements along the K axis 8 times
+        b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
+        b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
+        g1_ptrs = g1_ptr + offs_k
+        g2_ptrs = g2_ptr + offs_k
+        # shifter is used to extract the N bits of each element in the 32-bit word from B
+        scales1_ptrs = scales1_ptr + offs_bn[None, :]
+        scales2_ptrs = scales2_ptr + offs_bn[None, :]
+        zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
+        zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)
+
+        shifter = (offs_k % infearure_per_bits) * bits
+        zeros_shifter = (offs_bn % infearure_per_bits) * bits
+        accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, num_pid_k):
+            g1_idx = tl.load(g1_ptrs)
+            g2_idx = tl.load(g2_ptrs)
+
+            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+            scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+            scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)
+
+            zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+            zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
+            zeros1 = (zeros1 + 1)
+
+            zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+            zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
+            zeros2 = (zeros2 + 1)
+
+            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+            b1 = tl.load(b1_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+            b2 = tl.load(b2_ptrs)
+
+            # Now we need to unpack b (which is N-bit values) into 32-bit values
+            b1 = (b1 >> shifter[:, None]) & maxq  # Extract the N-bit values
+            b1 = (b1 - zeros1) * scales1  # Scale and shift
+            accumulator1 += tl.dot(a, b1)
+
+            b2 = (b2 >> shifter[:, None]) & maxq
+            b2 = (b2 - zeros2) * scales2
+            accumulator2 += tl.dot(a, b2)
+
+            a_ptrs += BLOCK_SIZE_K
+            b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+            b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+            g1_ptrs += BLOCK_SIZE_K
+            g2_ptrs += BLOCK_SIZE_K
+
+        accumulator1 = silu(accumulator1)
+        c = accumulator1 * accumulator2
+        c = c.to(tl.float16)
+        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+        tl.store(c_ptrs, c, mask=c_mask)
+
+    @triton.jit
+    def silu(x):
+        return x * tl.sigmoid(x)
+except:
+    print('triton not installed.')
+
+
+class QuantLlamaMLP(nn.Module):
+
+    def __init__(
+        self,
+        gate_proj,
+        down_proj,
+        up_proj,
+    ):
+        super().__init__()
+        self.register_buffer('gate_proj_qweight', gate_proj.qweight)
+        self.register_buffer('gate_proj_scales', gate_proj.scales)
+        self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
+        self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)
+        self.register_buffer('up_proj_qweight', up_proj.qweight)
+        self.register_buffer('up_proj_scales', up_proj.scales)
+        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
+        self.register_buffer('up_proj_g_idx', up_proj.g_idx)
+
+        self.infeatures = gate_proj.infeatures
+        self.intermediate_size = gate_proj.outfeatures
+        self.outfeatures = down_proj.outfeatures
+        self.bits = gate_proj.bits
+        self.maxq = gate_proj.maxq
+
+        self.down_proj = down_proj
+
+    def forward(self, x):
+        return self.down_proj(self.triton_llama_mlp(x))
+
+    def triton_llama_mlp(self, x):
+        with torch.cuda.device(x.device):
+            out_shape = x.shape[:-1] + (self.intermediate_size, )
+            x = x.reshape(-1, x.shape[-1])
+            M, K = x.shape
+            N = self.intermediate_size
+            c = torch.empty((M, N), device=x.device, dtype=torch.float16)
+            grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+            fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales,
+                                         self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0),
+                                         self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0))
+            c = c.reshape(out_shape)
+            return c
+
+    def fused2cuda(self):
+        self.gate_proj_qweight = self.gate_proj_qweight.cuda()
+        self.gate_proj_scales = self.gate_proj_scales.cuda()
+        self.gate_proj_qzeros = self.gate_proj_qzeros.cuda()
+        self.gate_proj_g_idx = self.gate_proj_g_idx.cuda()
+        self.up_proj_qweight = self.up_proj_qweight.cuda()
+        self.up_proj_scales = self.up_proj_scales.cuda()
+        self.up_proj_qzeros = self.up_proj_qzeros.cuda()
+        self.up_proj_g_idx = self.up_proj_g_idx.cuda()
+
+    def fused2cpu(self):
+        self.gate_proj_qweight = self.gate_proj_qweight.cpu()
+        self.gate_proj_scales = self.gate_proj_scales.cpu()
+        self.gate_proj_qzeros = self.gate_proj_qzeros.cpu()
+        self.gate_proj_g_idx = self.gate_proj_g_idx.cpu()
+        self.up_proj_qweight = self.up_proj_qweight.cpu()
+        self.up_proj_scales = self.up_proj_scales.cpu()
+        self.up_proj_qzeros = self.up_proj_qzeros.cpu()
+        self.up_proj_g_idx = self.up_proj_g_idx.cpu()
+
+
+def make_fused_mlp(m, parent_name=''):
+    """
+    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuse many of the operations.
+    """
+    if isinstance(m, LlamaMLP):
+        return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
+
+    for name, child in m.named_children():
+        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
+
+        if isinstance(child, QuantLlamaMLP):
+            setattr(m, name, child)
+    return m
+
+
+def autotune_warmup_fused(model):
+    """
+    Pre-tunes the quantized kernel
+    """
+    from tqdm import tqdm
+
+    kn_values = {}
+
+    for _, m in model.named_modules():
+        if not isinstance(m, QuantLlamaMLP):
+            continue
+
+        k = m.infeatures
+        n = m.intermediate_size
+
+        m.fused2cuda()
+        if (k, n) not in kn_values:
+            kn_values[(k, n)] = m
+
+    print(f'Found {len(kn_values)} unique fused mlp KN values.')
+
+    print('Warming up autotune cache ...')
+    with torch.no_grad():
+        for m in tqdm(range(0, 12)):
+            m = 2**m  # [1, 2048]
+            for (k, n), modules in kn_values.items():
+                a = torch.randn(m, k, dtype=torch.float16, device='cuda')
+                modules.triton_llama_mlp(a)
+
+    for (k, n), modules in kn_values.items():
+        modules.fused2cpu()
+    del kn_values
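For reference, the math `fusedmatmul_248_kernel` implements, written as plain PyTorch over already-dequantized weights; a sketch with hypothetical sizes, handy as a correctness oracle for the Triton path:

import torch
import torch.nn.functional as F

# One fused pass computes silu(A @ B1) * (A @ B2); down_proj is applied afterwards.
M, K, N = 4, 64, 128              # hypothetical sizes
a = torch.randn(M, K)
b1 = torch.randn(K, N)            # stands in for the dequantized gate_proj weight
b2 = torch.randn(K, N)            # stands in for the dequantized up_proj weight

fused = F.silu(a @ b1) * (a @ b2)

gate = a @ b1
up = a @ b2
reference = (gate * torch.sigmoid(gate)) * up   # silu(x) = x * sigmoid(x)
torch.testing.assert_close(fused, reference)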
@ -0,0 +1,423 @@
|
|||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from . import custom_autotune
|
||||
|
||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=3, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
|
||||
'perf_model': None,
|
||||
'top_k': None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, num_pid_k):
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
g_ptrs += BLOCK_SIZE_K
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
@custom_autotune.autotune(configs=[
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 256,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=3, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True)
|
||||
@triton.jit
|
||||
def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
|
||||
stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, N) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, K) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_k
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_k = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_bk
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
||||
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||
|
||||
shifter = (offs_bk % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
for n in range(0, num_pid_n):
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
b = tl.trans(b)
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_N
|
||||
b_ptrs += BLOCK_SIZE_N
|
||||
scales_ptrs += BLOCK_SIZE_N
|
||||
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
except:
|
||||
print('trioton not installed.')
|
||||
|
||||
|
||||
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
|
||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
|
||||
matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
|
||||
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
|
||||
return output
|
||||
|
||||
|
||||
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output_dim = (qweight.shape[0] * 32) // bits
|
||||
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
|
||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
|
||||
transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
|
||||
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinearFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
|
||||
ctx.bits, ctx.maxq = bits, maxq
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
qweight, scales, qzeros, g_idx = ctx.saved_tensors
|
||||
bits, maxq = ctx.bits, ctx.maxq
|
||||
grad_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
return grad_input, None, None, None, None, None, None
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
|
||||
super().__init__()
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize if groupsize != -1 else infeatures
|
||||
|
||||
self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32).cuda())
|
||||
self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32).cuda())
|
||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda())
|
||||
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda())
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures, )
|
||||
out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
|
||||
|
||||
|
||||
def make_quant_linear(module, names, bits, groupsize, name=''):
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
for attr in dir(module):
|
||||
tmp = getattr(module, attr)
|
||||
name1 = name + '.' + attr if name != '' else attr
|
||||
if name1 in names:
|
||||
delattr(module, attr)
|
||||
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
|
||||
for name1, child in module.named_children():
|
||||
make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
|
||||
|
||||
|
||||
def autotune_warmup_linear(model, transpose=False):
    """
    Pre-tunes the quantized kernel
    """
    from tqdm import tqdm

    kn_values = {}

    for _, m in model.named_modules():
        if not isinstance(m, QuantLinear):
            continue

        k = m.infeatures
        n = m.outfeatures

        if (k, n) not in kn_values:
            kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)

    print(f'Found {len(kn_values)} unique KN Linear values.')

    print('Warming up autotune cache ...')
    with torch.no_grad():
        for m in tqdm(range(0, 12)):
            m = 2**m  # [1, 2048]
            for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
                a = torch.randn(m, k, dtype=torch.float16, device='cuda')
                matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                if transpose:
                    a = torch.randn(m, n, dtype=torch.float16, device='cuda')
                    transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
    del kn_values
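The warm-up is a one-off call after the quantized model is assembled and before serving traffic; a sketch (assumes CUDA and packed layers):

autotune_warmup_linear(model, transpose=False)  # transpose=True only if backward passes are needed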
@@ -0,0 +1,127 @@
import numpy as np
import torch
import torch.nn as nn
import math


class Quantizer(nn.Module):

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):

        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)
        self.scale = torch.zeros_like(self.scale)
    def _quantize(self, x, scale, zero, maxq):
        if maxq < 0:
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)
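A quick worked example of that asymmetric round trip (numbers chosen for illustration):

scale, zero, maxq = 0.1, 8.0, 15                  # 4-bit, asymmetric
x = 0.34
q = min(max(round(x / scale) + zero, 0), maxq)    # clamp(round(3.4) + 8) -> 11.0
x_hat = scale * (q - zero)                        # 0.1 * 3 -> 0.3, error 0.04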
    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)

        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)
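Put together, a minimal per-channel calibration sketch (synthetic weight, illustrative sizes):

quantizer = Quantizer()
quantizer.configure(bits=4, perchannel=True, sym=False)
w = torch.randn(4096, 4096)
quantizer.find_params(w, weight=True)   # one scale/zero per output channel
w_q = quantizer.quantize(w)             # fake-quantized copy, same shape
print((w - w_q).abs().max())            # worst-case rounding error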
@@ -0,0 +1,423 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based on https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True,
        prune_configs_by={
            'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
            'perf_model': None,
            'top_k': None,
        },
    )
    @triton.jit
    def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
                          BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B.
        A is of shape (M, K) float16
        B is of shape (K//8, N) int32
        C is of shape (M, N) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N) float16
        g_ptr is of shape (K) int32
        """
        infeature_per_bits = 32 // bits

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b_ptrs = b_ptr + ((offs_k[:, None] // infeature_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_k
        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales_ptrs = scales_ptr + offs_bn[None, :]
        zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infeature_per_bits)

        shifter = (offs_k % infeature_per_bits) * bits
        zeros_shifter = (offs_bn % infeature_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, num_pid_k):
            g_idx = tl.load(g_ptrs)

            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift

            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K
            b_ptrs += (BLOCK_SIZE_K // infeature_per_bits) * stride_bk
            g_ptrs += BLOCK_SIZE_K

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=c_mask)
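The shifter arithmetic is easier to see scalar-side. A hedged PyTorch equivalent of the in-kernel unpack for one packed column (illustrative only):

packed = torch.tensor([0x77654321], dtype=torch.int32)   # eight 4-bit values in one word
k = torch.arange(8)                                      # positions within the column
vals = (packed[k // 8] >> (4 * (k % 8))) & 15            # (word >> shifter) & maxq
print(vals.tolist())                                     # [1, 2, 3, 4, 5, 6, 7, 7]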
    @custom_autotune.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True)
    @triton.jit
    def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
                                    stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B^T, with B dequantized on the fly.
        A is of shape (M, N) float16
        B is of shape (K//8, N) int32
        C is of shape (M, K) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N) float16
        g_ptr is of shape (K) int32
        """
        infeature_per_bits = 32 // bits

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_k
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_k = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        offs_n = tl.arange(0, BLOCK_SIZE_N)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b_ptrs = b_ptr + ((offs_bk[:, None] // infeature_per_bits) * stride_bk + offs_n[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_bk
        g_idx = tl.load(g_ptrs)

        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
        zeros_ptrs = zeros_ptr + (offs_n[None, :] // infeature_per_bits) + g_idx[:, None] * stride_zeros

        shifter = (offs_bk % infeature_per_bits) * bits
        zeros_shifter = (offs_n % infeature_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

        for n in range(0, num_pid_n):
            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales = tl.load(scales_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift
            b = tl.trans(b)

            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_N
            b_ptrs += BLOCK_SIZE_N
            scales_ptrs += BLOCK_SIZE_N
            zeros_ptrs += (BLOCK_SIZE_N // infeature_per_bits)

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
        tl.store(c_ptrs, accumulator, mask=c_mask)
except ImportError:
    print('triton is not installed.')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
        grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
        matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
                                qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
        return output


def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output_dim = (qweight.shape[0] * 32) // bits
        output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
        grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
        transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
                                          qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
        return output
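For a feel of the launch geometry (illustrative numbers, one autotune candidate):

import math
M, N = 16, 4096                       # batch rows and out-features
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 64
grid_size = math.ceil(M / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N)
print(grid_size)                      # 64 programs, each computing a 64x64 tile of C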
class QuantLinearFunction(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.bits, ctx.maxq = bits, maxq
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        qweight, scales, qzeros, g_idx = ctx.saved_tensors
        bits, maxq = ctx.bits, ctx.maxq
        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
        return grad_input, None, None, None, None, None, None
class QuantLinear(nn.Module):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None
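To make the buffer shapes concrete (illustrative sizes: bits=4, groupsize=128, a 4096x4096 layer):

layer = QuantLinear(4, 128, 4096, 4096, bias=False)
print(layer.qweight.shape)   # torch.Size([512, 4096])  -> 8 packed weights per int32
print(layer.qzeros.shape)    # torch.Size([32, 512])    -> 32 groups of 128 in-features
print(layer.scales.shape)    # torch.Size([32, 4096])
print(layer.g_idx.shape)     # torch.Size([4096])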
    def pack(self, linear, scales, zeros, g_idx=None):
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures, )
        out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
def make_quant_linear(module, names, bits, groupsize, name=''):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
    for name1, child in module.named_children():
        make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
def autotune_warmup_linear(model, transpose=False):
    """
    Pre-tunes the quantized kernel
    """
    from tqdm import tqdm

    kn_values = {}

    for _, m in model.named_modules():
        if not isinstance(m, QuantLinear):
            continue

        k = m.infeatures
        n = m.outfeatures

        if (k, n) not in kn_values:
            kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)

    print(f'Found {len(kn_values)} unique KN Linear values.')

    print('Warming up autotune cache ...')
    with torch.no_grad():
        for m in tqdm(range(0, 12)):
            m = 2**m  # [1, 2048]
            for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
                a = torch.randn(m, k, dtype=torch.float16, device='cuda')
                matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                if transpose:
                    a = torch.randn(m, n, dtype=torch.float16, device='cuda')
                    transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
    del kn_values
@@ -100,14 +100,14 @@ def serve(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: bool,
    quantize: Optional[str],
    uds_path: Path,
):
    async def serve_inner(
        model_id: str,
        revision: Optional[str],
        sharded: bool = False,
        quantize: bool = False,
        quantize: Optional[str] = None,
    ):
        unix_socket_template = "unix://{}-{}"
        if sharded:
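For reference, a hedged sketch of how the new signature is exercised; in the real flow the launcher shells out to the text-generation-server CLI, which forwards the value here (model id and path below are purely illustrative):

from pathlib import Path

serve(
    model_id="TheBloke/Llama-7B-GPTQ",   # hypothetical GPTQ checkpoint
    revision=None,
    sharded=False,
    quantize="gptq",                     # was a bool; now names the backend
    uds_path=Path("/tmp/text-generation-server"),
)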
@@ -4,6 +4,13 @@ import torch
from datetime import timedelta


# Debug stand-in for a torch.distributed process group: exposes the same
# size()/rank() API, reading both from the environment (see the commented-out
# return at the end of initialize_torch_distributed).
class Fake:
    def size(self):
        return int(os.getenv("WORLD_SIZE", "1"))

    def rank(self):
        return int(os.getenv("RANK", "0"))


def initialize_torch_distributed():
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -33,3 +40,4 @@ def initialize_torch_distributed():
    )

    return torch.distributed.group.WORLD, rank, world_size
    # return Fake(), rank, world_size