[WIP] Adding GPTQ support for llama

This commit is contained in:
Ubuntu 2023-05-02 17:07:33 +00:00
parent 4f6d038c0b
commit 2c9e1171bc
14 changed files with 1749 additions and 30 deletions

View File

@ -1,4 +1,4 @@
use clap::Parser;
use clap::{Parser, ValueEnum};
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
mod env_runtime;
/// Quantization backend selectable through the launcher's `--quantize` option.
///
/// The `Display` implementation of this type yields the lowercase identifier
/// that is forwarded to the shard process.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization{
// NOTE: plain `//` comments are used on the variants on purpose; `///` doc
// comments on `ValueEnum` variants would be picked up by clap as help text.
// Rendered as "bitsandbytes".
Bitsandbytes,
// Rendered as "gptq".
Gptq
}
impl std::fmt::Display for Quantization {
    /// Writes the lowercase backend identifier for this variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // These identifiers must stay in sync with the names `server` accepts.
        let label = match self {
            Quantization::Bitsandbytes => "bitsandbytes",
            Quantization::Gptq => "gptq",
        };
        f.write_str(label)
    }
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@ -46,10 +66,10 @@ struct Args {
#[clap(long, env)]
num_shard: Option<usize>,
/// Wether you want the model to be quantized or not. This will use bitsandbytes for
/// quantization on the fly.
#[clap(long, env)]
quantize: bool,
/// Whether you want the model to be quantized or not. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The maximum amount of concurrent requests for this particular deployment.
/// Having a low limit will refuse clients requests instead of having them
@ -218,7 +238,7 @@ enum ShardStatus {
fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: bool,
quantize: Option<Quantization>,
uds_path: String,
rank: usize,
world_size: usize,
@ -257,8 +277,9 @@ fn shard_manager(
shard_argv.push("--sharded".to_string());
}
if quantize {
shard_argv.push("--quantize".to_string())
if let Some(quantize) = quantize {
shard_argv.push("--quantize".to_string());
shard_argv.push(quantize.to_string())
}
// Model optional revision
@ -330,6 +351,7 @@ fn shard_manager(
// Start process
tracing::info!("Starting shard {rank}");
tracing::info!("Command {}", shard_argv.join(" "));
let mut p = match Popen::create(
&shard_argv,
PopenConfig {
@ -747,7 +769,6 @@ fn spawn_webserver(
) -> Result<Popen, LauncherError> {
// All shard started
// Start webserver
tracing::info!("Starting Webserver");
let mut argv = vec![
"text-generation-router".to_string(),
"--max-concurrent-requests".to_string(),
@ -811,6 +832,9 @@ fn spawn_webserver(
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
tracing::info!("Starting Webserver");
tracing::info!("Command {}", argv.join(" "));
tracing::info!("Env {:?}", env);
let mut webserver = match Popen::create(
&argv,
PopenConfig {

View File

@ -15,7 +15,7 @@ def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
quantize: Optional[str] = None,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,

View File

@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
) -> Model:
if "facebook/galactica" in model_id:
if sharded:

View File

@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import torch.distributed
@ -33,6 +34,8 @@ import flash_attn_cuda
import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
# from safetensors.torch import load_file
from safetensors import safe_open
HAS_BITS_AND_BYTES = True
try:
@ -40,6 +43,12 @@ try:
except ImportError as e:
HAS_BITS_AND_BYTES = False
HAS_GPTQ = True
try:
from text_generation_server.quant.quant_linear import QuantLinear
except ImportError as e:
HAS_GPTQ = False
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
@ -102,10 +111,10 @@ class FastLinear(nn.Linear):
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
self.qlinear = None
def prepare_weights(self, quantize: bool = False):
if quantize:
def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
@ -114,17 +123,114 @@ class FastLinear(nn.Linear):
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.qlinear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
# Copy data to qlinear
self.qlinear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
self.qlinear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
elif quantize == "gptq":
if not HAS_GPTQ:
raise ImportError(
"gptq is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install gptq`."
)
self.quantized = True
self.qlinear = QuantLinear(
bits=4,
groupsize=128,
infeatures=self.in_features,
outfeatures=self.out_features,
bias=bool(self.bias),
)
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
def get_row_slice(f, name):
slice_ = f.get_slice(name)
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
return tensor.contiguous()
def get_col_slice(f, name):
slice_ = f.get_slice(name)
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
return tensor.contiguous()
if isinstance(self, TensorParallelRowLinear):
get_slice = get_row_slice
elif isinstance(self, TensorParallelColumnLinear):
get_slice = get_col_slice
elif isinstance(self, FastLinear):
def get_slice(f, name):
return f.get_tensor(name)
else:
raise ValueError("Need a specific class of Linear (TensorParallel, or regular Linear)")
with safe_open("/home/ubuntu/src/GPTQ-for-LLaMa/oasst-4bit-128g.safetensors", framework="pt", device=f"cuda:{rank}") as f:
if name == 'self_attn.query_key_value':
query_name = f'model.layers.{layer}.self_attn'
self.qlinear.qweight[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.qweight")
self.qlinear.qweight[:, self.out_features // 3:-self.out_features // 3] = get_slice(f, f"{query_name}.k_proj.qweight")
self.qlinear.qweight[:,-self.out_features // 3: ] = get_slice(f, f"{query_name}.v_proj.qweight")
N = self.qlinear.qzeros.shape[1]
self.qlinear.qzeros[:, : N // 3] = get_slice(f, f"{query_name}.q_proj.qzeros")
self.qlinear.qzeros[:, N // 3:-N // 3] = get_slice(f, f"{query_name}.k_proj.qzeros")
self.qlinear.qzeros[:,-N // 3: ] = get_slice(f, f"{query_name}.v_proj.qzeros")
self.qlinear.scales[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.scales")
self.qlinear.scales[:, self.out_features // 3:-self.out_features // 3] = get_slice(f, f"{query_name}.k_proj.scales")
self.qlinear.scales[:,-self.out_features // 3: ] = get_slice(f, f"{query_name}.v_proj.scales")
torch.testing.assert_close(f.get_tensor(f"{query_name}.q_proj.g_idx"), f.get_tensor(f"{query_name}.k_proj.g_idx"))
torch.testing.assert_close(f.get_tensor(f"{query_name}.q_proj.g_idx"), f.get_tensor(f"{query_name}.v_proj.g_idx"))
self.qlinear.g_idx[:] = f.get_tensor(f"{query_name}.q_proj.g_idx")
elif name == "self_attn.o_proj":
self.qlinear.qweight[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.qweight")
self.qlinear.qzeros[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.qzeros")
self.qlinear.scales[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.scales")
self.qlinear.g_idx[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.g_idx")
elif name == "mlp.gate_up_proj":
N = self.qlinear.qweight.shape[1] // 2
self.qlinear.qweight[:, :N] = get_slice(f, f"model.layers.{layer}.mlp.gate_proj.qweight")
self.qlinear.qweight[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.qweight")
self.qlinear.scales[:, :N] = get_slice(f, f"model.layers.{layer}.mlp.gate_proj.scales")
self.qlinear.scales[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.scales")
torch.testing.assert_close(f.get_tensor(f"model.layers.{layer}.mlp.gate_proj.g_idx"), f.get_tensor(f"model.layers.{layer}.mlp.up_proj.g_idx"))
self.qlinear.g_idx[:] = f.get_tensor(f"model.layers.{layer}.mlp.gate_proj.g_idx")
N = self.qlinear.qzeros.shape[1] // 2
self.qlinear.qzeros[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.qzeros")
self.qlinear.qzeros[:, :N] = get_slice(f, f"model.layers.{layer}.mlp.gate_proj.qzeros")
elif name == "mlp.down_proj":
self.qlinear.qweight[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.qweight")
self.qlinear.qzeros[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.qzeros")
self.qlinear.scales[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.scales")
self.qlinear.g_idx[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.g_idx")
else:
raise ValueError("Not handled")
# Delete reference to data
self.weight = None
@ -134,7 +240,7 @@ class FastLinear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:
return self.bnb_linear(input)
return self.qlinear(input)
else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
@ -542,12 +648,12 @@ class FlashLlamaModel(torch.nn.Module):
def post_load_weights(self, load_in_8bit: bool = False):
if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx()
for layer in self.layers:
for i, layer in enumerate(self.layers):
layer: FlashLlamaLayer
layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
layer.self_attn.o_proj.prepare_weights(load_in_8bit)
layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
layer.mlp.down_proj.prepare_weights(load_in_8bit)
layer.self_attn.query_key_value.prepare_weights(i, "self_attn.query_key_value", load_in_8bit)
layer.self_attn.o_proj.prepare_weights(i, "self_attn.o_proj", load_in_8bit)
layer.mlp.gate_up_proj.prepare_weights(i, "mlp.gate_up_proj", load_in_8bit)
layer.mlp.down_proj.prepare_weights(i, "mlp.down_proj", load_in_8bit)
def forward(
self,

View File

@ -28,7 +28,7 @@ tracer = trace.get_tracer(__name__)
class FlashLlama(FlashCausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=None):
self.past_pad = None
if torch.cuda.is_available():
device = torch.device("cuda")
@ -154,7 +154,7 @@ class FlashLlama(FlashCausalLM):
class FlashLlamaSharded(FlashLlama):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self, model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None
):
self.past_pad = None
self.process_group, rank, world_size = initialize_torch_distributed()
@ -177,13 +177,13 @@ class FlashLlamaSharded(FlashLlama):
revision=revision,
)
torch.distributed.barrier(group=self.process_group)
# torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashLlamaForCausalLM(config, process_group=self.process_group)
torch.distributed.barrier(group=self.process_group)
# torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,

View File

@ -0,0 +1,4 @@
from .quantizer import Quantizer
from .fused_attn import QuantLlamaAttention, make_quant_attn
from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused
from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear

View File

@ -0,0 +1,193 @@
#https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
import builtins
import math
import time
from typing import Dict
import triton
class Autotuner(triton.KernelInterface):
    """
    Kernel wrapper that benchmarks a set of candidate configs and caches the fastest one.

    Results are cached per "key", i.e. per tuple of the values of the arguments named in
    ``key``. With ``nearest_power_of_two`` the key values are rounded to a power of two
    first, so that similar problem sizes share one tuning result.
    """

    def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
        '''
        :param fn: the kernel object; its ``run`` and ``warmup`` methods are invoked.
        :param arg_names: positional argument names of the kernel, in order.
        :param configs: candidate ``triton.Config`` objects; a single default config is used when empty.
        :param key: argument names whose values identify a tuning-cache entry.
        :param reset_to_zero: argument names whose tensors are zeroed (``zero_()``) before each launch, or None.
        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
            'perf_model': performance model used to predicate running time with different configs, returns running time
            'top_k': number of configs to bench
            'early_config_prune'(optional): a function used to prune configs early. It is called with
                (configs, nargs) and returns the pruned configs.
        :param nearest_power_of_two: whether to round key arguments to the nearest power of two when caching tuning results
        '''
        if not configs:
            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.nearest_power_of_two = nearest_power_of_two
        self.cache = {}
        # hook to reset all required tensor to zeros before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

            def _hook(args):
                for i in self.reset_idx:
                    args[i].zero_()

            self.hook = _hook
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
            # 'early_config_prune' is optional: fall back to None instead of leaving
            # the local unbound (which raised UnboundLocalError a few lines below).
            early_config_prune = prune_configs_by.get('early_config_prune')
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.early_config_prune = early_config_prune
        self.fn = fn

    def _bench(self, *args, config, **meta):
        """
        Benchmark the kernel with one config.

        Returns whatever ``triton.testing.do_bench`` returns for the requested
        percentiles, or a tuple of three infinities when the config needs more
        resources than are available.

        :raises ValueError: if a meta-parameter is given both by the caller and by ``config``.
        """
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
                             " Make sure that you don't re-define auto-tuned symbols.")
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)

        try:
            # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
            # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
            return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
        except triton.compiler.OutOfResources:
            return (float('inf'), float('inf'), float('inf'))

    def run(self, *args, **kwargs):
        """
        Launch the kernel with the best known config for the current key.

        When more than one config is available and the key has not been seen yet,
        all (pruned) configs are benchmarked first and the fastest one is cached.
        """
        self.nargs = dict(zip(self.arg_names, args))
        if len(self.configs) > 1:
            key = tuple(args[i] for i in self.key_idx)

            # This reduces the amount of autotuning by rounding the keys to the nearest power of two
            # In my testing this gives decent results, and greatly reduces the amount of tuning required
            if self.nearest_power_of_two:
                key = tuple([2**int(math.log2(x) + 0.5) for x in key])

            if key not in self.cache:
                # prune configs
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()
                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                # benchmarking ran the kernel many times; reset the tensors once more
                self.hook(args)
                self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

    def prune_configs(self, kwargs):
        """
        Return the configs worth benchmarking for the current ``self.nargs``.

        Applies ``early_config_prune`` first (if any), then keeps the ``top_k``
        configs with the lowest ``perf_model`` estimate (if a model is set).
        """
        pruned_configs = self.configs
        if self.early_config_prune:
            # Materialise the result: the pruner may be a generator, and len() is
            # taken on the pruned configs below.
            pruned_configs = list(self.early_config_prune(self.configs, self.nargs))
        if self.perf_model:
            top_k = self.configs_top_k
            # a float top_k <= 1.0 is interpreted as a fraction of all configs
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs

    def warmup(self, *args, **kwargs):
        """Call ``fn.warmup`` once for every config that survives pruning."""
        self.nargs = dict(zip(self.arg_names, args))
        for config in self.prune_configs(kwargs):
            self.fn.warmup(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **kwargs,
                **config.kwargs,
            )
        self.nargs = None
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.

    The decorated function is wrapped in an :code:`Autotuner` built from
    :code:`fn.arg_names` and the arguments given here.

    .. highlight:: python
    .. code-block:: python

        @triton.autotune(configs=[
            triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
            triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
          ],
          key=['x_size'] # the two above configs will be evaluated anytime
                         # the value of x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, **META):
            BLOCK_SIZE = META['BLOCK_SIZE']

    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensor to `zero` before running any configuration.

    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    :param nearest_power_of_two: forwarded to :code:`Autotuner`; when true the key values are rounded
        to the nearest power of two before the tuning cache is consulted.
    :type nearest_power_of_two: bool
    """

    def decorator(fn):
        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)

    return decorator
def matmul248_kernel_config_pruner(configs, nargs):
    """
    Clamp the BLOCK_SIZE_* of every config to the (power-of-two rounded) problem
    dimensions and yield each distinct resulting config once.

    :param configs: iterable of configs exposing ``kwargs``, ``num_stages`` and ``num_warps``.
    :param nargs: mapping holding the problem sizes under 'M', 'N' and 'K'.
    :return: generator of ``triton.Config`` objects, in the order of first appearance.
    """

    def _block_limit(dim):
        # next power of two that is >= dim, but never below 16
        return max(2**int(math.ceil(math.log2(dim))), 16)

    limit_m = _block_limit(nargs['M'])
    limit_n = _block_limit(nargs['N'])
    limit_k = _block_limit(nargs['K'])

    emitted = set()
    for candidate in configs:
        params = candidate.kwargs
        signature = (
            min(limit_m, params['BLOCK_SIZE_M']),
            min(limit_n, params['BLOCK_SIZE_N']),
            min(limit_k, params['BLOCK_SIZE_K']),
            params['GROUP_SIZE_M'],
            candidate.num_stages,
            candidate.num_warps,
        )
        # after clamping, several candidates may collapse onto the same config
        if signature in emitted:
            continue
        emitted.add(signature)

        block_m, block_n, block_k, group_m, stages, warps = signature
        yield triton.Config(
            {
                'BLOCK_SIZE_M': block_m,
                'BLOCK_SIZE_N': block_n,
                'BLOCK_SIZE_K': block_k,
                'GROUP_SIZE_M': group_m
            },
            num_stages=stages,
            num_warps=warps)

View File

@ -0,0 +1,123 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from .quant_linear import *
class QuantLlamaAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper, using one fused
    projection for query, key and value.

    :param hidden_size: model width; must be divisible by ``num_heads``.
    :param num_heads: number of attention heads.
    :param qkv_proj: callable producing the concatenated q/k/v states (last dim is split in chunks of ``hidden_size``).
    :param o_proj: callable applied to the attention output.
    :param rotary_emb: callable returning ``(cos, sin)`` for the rotary position embedding.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        qkv_proj,
        o_proj,
        rotary_emb,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = rotary_emb

    def _shape(self, tensor, seq_len, bsz):
        """Reshape ``tensor`` to (bsz, num_heads, seq_len, head_dim) and make it contiguous."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """
        Input shape: Batch x Time x Channel

        ``attention_mask`` is accepted for interface compatibility but not used;
        causal masking is applied when no ``past_key_value`` is given.

        :return: ``(attn_output, attn_weights, past_key_value)``. ``attn_weights`` is
            always ``None`` because the fused attention call does not return the
            weights; ``past_key_value`` is ``(key_states, value_states)`` when
            ``use_cache`` is set and ``None`` otherwise.
        """
        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        is_causal = past_key_value is None
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        past_key_value = (key_states, value_states) if use_cache else None

        with torch.backends.cuda.sdp_kernel(enable_math=False):
            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # The fused attention call yields only the output, so there are no weights to
        # hand back. Bind the name unconditionally: previously it was only assigned when
        # `output_attentions` was falsy and the return raised UnboundLocalError otherwise.
        attn_weights = None

        return attn_output, attn_weights, past_key_value
def make_quant_attn(model):
    """
    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.

    The model is modified in place; nothing is returned.
    """
    for name, module in model.named_modules():
        if not isinstance(module, LlamaAttention):
            continue

        projections = (module.q_proj, module.k_proj, module.v_proj)
        first = projections[0]
        has_bias = first.bias is not None

        # Build the fused layer with the summed output width, then overwrite its
        # tensors with the concatenation of the three source projections.
        fused = QuantLinear(
            first.bits,
            first.groupsize,
            first.infeatures,
            sum(p.outfeatures for p in projections),
            has_bias,
        )
        fused.qweight = torch.cat([p.qweight for p in projections], dim=1)
        fused.qzeros = torch.cat([p.qzeros for p in projections], dim=1)
        fused.scales = torch.cat([p.scales for p in projections], dim=1)
        # NOTE(review): g_idx is concatenated along dim 0 exactly as before; confirm
        # that the consumer expects three stacked copies rather than a single one.
        fused.g_idx = torch.cat([p.g_idx for p in projections], dim=0)
        fused.bias = torch.cat([p.bias for p in projections], dim=0) if has_bias else None

        replacement = QuantLlamaAttention(module.hidden_size, module.num_heads, fused, module.o_proj, module.rotary_emb)

        # "a.b.attn" -> parent "a.b", child "attn"; a name without a dot lives on the model itself.
        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model

        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")

        setattr(parent, child_name, replacement)

View File

@ -0,0 +1,288 @@
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaMLP
# The Triton kernels are defined inside a try block so that this module can still be
# imported when the definitions below fail (a message is printed instead).
try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based https://github.com/fpgaminer/GPTQ-triton
    # Candidate launch configurations; the autotuner keys its cache on (M, N, K),
    # rounded to the nearest power of two.
    @custom_autotune.autotune(
        configs=[
            triton.Config({
                'BLOCK_SIZE_M': 256,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 256,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 16,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 32,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 128,
                'GROUP_SIZE_M': 8
            }, num_stages=2, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 16,
                'BLOCK_SIZE_K': 64,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
            triton.Config({
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 64,
                'GROUP_SIZE_M': 8
            }, num_stages=4, num_warps=4),  # 3090
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True,
        prune_configs_by={
            'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
            'perf_model': None,
            'top_k': None,
        },
    )
    @triton.jit
    def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,
                               stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Computes: C = silu(A * B1) * (A * B2)
        A is of shape (M, K) float16
        B is of shape (K//8, N) int32
        C is of shape (M, N) float16
        scales is of shape (1, N) float16
        zeros is of shape (1, N//8) int32

        NOTE(review): the shapes above are stated for the 4-bit case (8 values per
        int32 word); the code derives the packing factor from `bits`. Scales and
        zeros are additionally indexed by the group index read through g1_ptr /
        g2_ptr, so they presumably have one row per group - confirm against the caller.
        """
        # number of quantized values packed into one 32-bit word
        infearure_per_bits = 32 // bits

        # Map the 1-D program id to an (m, n) tile, walking GROUP_SIZE_M row-tiles
        # at a time.
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        # only the M dimension is masked when loading A
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
        b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
        g1_ptrs = g1_ptr + offs_k
        g2_ptrs = g2_ptr + offs_k
        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales1_ptrs = scales1_ptr + offs_bn[None, :]
        scales2_ptrs = scales2_ptr + offs_bn[None, :]
        zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
        zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)

        # bit offsets of the packed values: along K for the weights, along N for the zeros
        shifter = (offs_k % infearure_per_bits) * bits
        zeros_shifter = (offs_bn % infearure_per_bits) * bits
        accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, num_pid_k):
            # group index of every k in this tile, for each of the two weight matrices
            g1_idx = tl.load(g1_ptrs)
            g2_idx = tl.load(g2_ptrs)

            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)

            zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
            # the unpacked zero point is incremented by one before use
            zeros1 = (zeros1 + 1)

            zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
            zeros2 = (zeros2 + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b1 = tl.load(b1_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
            b2 = tl.load(b2_ptrs)

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b1 = (b1 >> shifter[:, None]) & maxq  # Extract the N-bit values
            b1 = (b1 - zeros1) * scales1  # Scale and shift
            accumulator1 += tl.dot(a, b1)

            b2 = (b2 >> shifter[:, None]) & maxq
            b2 = (b2 - zeros2) * scales2
            accumulator2 += tl.dot(a, b2)

            # advance all pointers to the next K tile
            # NOTE(review): a_ptrs advances by BLOCK_SIZE_K elements without multiplying by
            # stride_ak, although stride_ak is used when a_ptrs is first built - this is only
            # equivalent when stride_ak == 1; confirm A is always contiguous along K.
            a_ptrs += BLOCK_SIZE_K
            b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            g1_ptrs += BLOCK_SIZE_K
            g2_ptrs += BLOCK_SIZE_K

        # combine the two products: silu(A @ B1) * (A @ B2), stored as float16
        accumulator1 = silu(accumulator1)
        c = accumulator1 * accumulator2
        c = c.to(tl.float16)
        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    @triton.jit
    def silu(x):
        """Elementwise x * sigmoid(x)."""
        return x * tl.sigmoid(x)
# NOTE(review): the bare `except` also swallows any non-import error raised while the
# kernels above are being defined, and reports it as triton being missing.
except:
    print('triton not installed.')
class QuantLlamaMLP(nn.Module):
    """
    MLP block that evaluates the gate and up projections with one fused Triton
    kernel and then applies ``down_proj``.

    The quantized tensors of ``gate_proj`` and ``up_proj`` are registered as
    buffers named ``<proj>_<field>``; ``down_proj`` is kept as given.
    """

    # buffer names, in registration order
    _FUSED_BUFFERS = (
        'gate_proj_qweight',
        'gate_proj_scales',
        'gate_proj_qzeros',
        'gate_proj_g_idx',
        'up_proj_qweight',
        'up_proj_scales',
        'up_proj_qzeros',
        'up_proj_g_idx',
    )

    def __init__(
        self,
        gate_proj,
        down_proj,
        up_proj,
    ):
        super().__init__()
        for prefix, proj in (('gate_proj', gate_proj), ('up_proj', up_proj)):
            for field in ('qweight', 'scales', 'qzeros', 'g_idx'):
                self.register_buffer(f'{prefix}_{field}', getattr(proj, field))

        # sizes and quantization parameters are taken from the gate projection
        self.infeatures = gate_proj.infeatures
        self.intermediate_size = gate_proj.outfeatures
        self.outfeatures = down_proj.outfeatures
        self.bits = gate_proj.bits
        self.maxq = gate_proj.maxq

        self.down_proj = down_proj

    def forward(self, x):
        """Apply the fused gate/up computation followed by ``down_proj``."""
        hidden = self.triton_llama_mlp(x)
        return self.down_proj(hidden)

    def triton_llama_mlp(self, x):
        """
        Run the fused kernel on ``x`` and return a float16 tensor whose shape is
        ``x.shape[:-1] + (intermediate_size,)``.
        """
        with torch.cuda.device(x.device):
            result_shape = x.shape[:-1] + (self.intermediate_size, )
            # flatten all leading dimensions into a single row dimension
            x = x.reshape(-1, x.shape[-1])
            M, K = x.shape
            N = self.intermediate_size
            c = torch.empty((M, N), device=x.device, dtype=torch.float16)

            def grid(META):
                # one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile
                return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )

            gate_qweight = self.gate_proj_qweight
            gate_scales = self.gate_proj_scales
            gate_qzeros = self.gate_proj_qzeros
            fusedmatmul_248_kernel[grid](
                x, c,
                gate_qweight, gate_scales, gate_qzeros, self.gate_proj_g_idx,
                self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, self.up_proj_g_idx,
                M, N, K, self.bits, self.maxq,
                x.stride(0), x.stride(1),
                gate_qweight.stride(0), gate_qweight.stride(1),
                c.stride(0), c.stride(1),
                gate_scales.stride(0), gate_qzeros.stride(0))
            return c.reshape(result_shape)

    def fused2cuda(self):
        """Move the eight fused buffers to CUDA (``down_proj`` is not touched)."""
        for buffer_name in self._FUSED_BUFFERS:
            setattr(self, buffer_name, getattr(self, buffer_name).cuda())

    def fused2cpu(self):
        """Move the eight fused buffers to the CPU (``down_proj`` is not touched)."""
        for buffer_name in self._FUSED_BUFFERS:
            setattr(self, buffer_name, getattr(self, buffer_name).cpu())
def make_fused_mlp(m, parent_name=''):
    """
    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.

    Children are replaced in place; the return value is the fused replacement when
    ``m`` itself is a LlamaMLP, and ``m`` otherwise. ``parent_name`` only tracks the
    dotted path while recursing.
    """
    if not isinstance(m, LlamaMLP):
        for child_name, child in m.named_children():
            converted = make_fused_mlp(child, parent_name=f"{parent_name}.{child_name}")

            # only re-attach children that were actually swapped for a fused MLP
            if isinstance(converted, QuantLlamaMLP):
                setattr(m, child_name, converted)
        return m

    return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
def autotune_warmup_fused(model):
    """
    Pre-tunes the fused MLP kernel for every distinct (K, N) size found in ``model``.

    One representative QuantLlamaMLP is kept per unique
    ``(infeatures, intermediate_size)`` pair. Each representative is moved to
    CUDA, run through ``triton_llama_mlp`` with random float16 inputs of
    1, 2, 4, ..., 2048 rows, and then moved back with ``fused2cpu``.

    Only the representatives are moved between devices, so every other module
    keeps whatever placement it had before the call. No throw-away input tensors
    are allocated during the move back to the CPU.
    """
    from tqdm import tqdm

    # Map (K, N) -> first module seen with that problem size.
    representatives = {}
    for _, module in model.named_modules():
        if not isinstance(module, QuantLlamaMLP):
            continue
        size_key = (module.infeatures, module.intermediate_size)
        if size_key not in representatives:
            # Only modules that are actually used for tuning (and moved back
            # below) are transferred to the GPU.
            module.fused2cuda()
            representatives[size_key] = module

    print(f'Found {len(representatives)} unique fused mlp KN values.')
    print('Warming up autotune cache ...')
    with torch.no_grad():
        for exponent in tqdm(range(0, 12)):
            rows = 2**exponent  # [1, 2048]
            for (k, _n), module in representatives.items():
                a = torch.randn(rows, k, dtype=torch.float16, device='cuda')
                module.triton_llama_mlp(a)

    # Undo the device move performed above.
    for module in representatives.values():
        module.fused2cpu()
    del representatives

View File

@ -0,0 +1,423 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True,
        prune_configs_by={
            'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
            'perf_model': None,
            'top_k': None,
        },
    )
    @triton.jit
    def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
                          BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B, de-quantizing B on the fly.
        A is of shape (M, K) float16
        B is of shape (K // (32 // bits), N) int32, with `32 // bits` values packed per word along K
        C is of shape (M, N) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N // (32 // bits)) int32, with `32 // bits` values packed per word along N
        g_ptr is of shape (K) int32 and maps each input feature to its group row
        """
        infearure_per_bits = 32 // bits

        # Map the 1-D program id onto an (m, n) output tile; tiles are visited in
        # bands of GROUP_SIZE_M row-blocks.
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis `32 // bits` times
        b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_k
        scales_ptrs = scales_ptr + offs_bn[None, :]
        zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)

        # shifter is used to extract the N bits of each element in the 32-bit word from B
        shifter = (offs_k % infearure_per_bits) * bits
        zeros_shifter = (offs_bn % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        # NOTE(review): the loads of g_idx, scales, zeros and B are not masked, so
        # K and N are presumably multiples of the chosen block sizes - confirm
        # against the config pruner.
        for k in range(0, num_pid_k):
            g_idx = tl.load(g_ptrs)

            # Fetch scales and zeros for the group of every row in this K-slice
            scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            # QuantLinear.pack stores `zeros - 1`, so add the one back.
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift

            accumulator += tl.dot(a, b)
            # Advance along K using the real element stride of A, so inputs with
            # a non-unit stride along K are read correctly.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            g_ptrs += BLOCK_SIZE_K

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=c_mask)

    @custom_autotune.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True)
    @triton.jit
    def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
                                    stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B^T, de-quantizing B on the fly.
        A is of shape (M, N) float16
        B is of shape (K // (32 // bits), N) int32, with `32 // bits` values packed per word along K
        C is of shape (M, K) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N // (32 // bits)) int32, with `32 // bits` values packed per word along N
        g_ptr is of shape (K) int32 and maps each input feature to its group row
        """
        infearure_per_bits = 32 // bits

        # Map the 1-D program id onto an (m, k) output tile; tiles are visited in
        # bands of GROUP_SIZE_M row-blocks.
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_k
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_k = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        offs_n = tl.arange(0, BLOCK_SIZE_N)
        # stride_ak is the stride of A along its second (here: N) axis.
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis `32 // bits` times
        b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_bk
        # The K-slice is fixed for this program, so the group rows are loaded once.
        g_idx = tl.load(g_ptrs)

        scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
        zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros

        # shifter is used to extract the N bits of each element in the 32-bit word from B
        shifter = (offs_bk % infearure_per_bits) * bits
        zeros_shifter = (offs_n % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

        # NOTE(review): the loads of g_idx, scales, zeros and B are not masked, so
        # K and N are presumably multiples of the chosen block sizes - confirm.
        for n in range(0, num_pid_n):
            # Fetch scales and zeros for the current N-slice
            scales = tl.load(scales_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            # QuantLinear.pack stores `zeros - 1`, so add the one back.
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift
            b = tl.trans(b)

            accumulator += tl.dot(a, b)
            # Advance along N using the real element strides of A and B, so
            # tensors with a non-unit stride along N are read correctly.
            a_ptrs += BLOCK_SIZE_N * stride_ak
            b_ptrs += BLOCK_SIZE_N * stride_bn
            # scales/zeros have no column-stride argument; rows are assumed dense.
            scales_ptrs += BLOCK_SIZE_N
            zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
        tl.store(c_ptrs, accumulator, mask=c_mask)
except Exception as exc:
    # Best effort: the module stays importable without a usable Triton install,
    # but the kernels above are then undefined. Catch `Exception` rather than
    # everything so Ctrl-C / SystemExit still propagate, and report the cause.
    print(f'triton not installed or failed to load ({exc!r}).')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    """Launch ``matmul_248_kernel`` to multiply ``input`` by the packed weight.

    A float16 result of shape ``(input.shape[0], qweight.shape[1])`` is allocated
    on ``input.device`` (which is made the current CUDA device for the launch),
    filled by the kernel and returned.
    """
    rows, in_features = input.shape[0], input.shape[1]
    out_features = qweight.shape[1]

    def launch_grid(meta):
        # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
        return (triton.cdiv(rows, meta['BLOCK_SIZE_M']) * triton.cdiv(out_features, meta['BLOCK_SIZE_N']), )

    with torch.cuda.device(input.device):
        result = torch.empty((rows, out_features), device=input.device, dtype=torch.float16)
        matmul_248_kernel[launch_grid](
            input, qweight, result, scales, qzeros, g_idx,
            rows, out_features, in_features, bits, maxq,
            input.stride(0), input.stride(1),
            qweight.stride(0), qweight.stride(1),
            result.stride(0), result.stride(1),
            scales.stride(0), qzeros.stride(0),
        )
        return result
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    """Launch ``transpose_matmul_248_kernel`` to multiply ``input`` by the transposed packed weight.

    The output width is the unpacked row count of ``qweight``
    (``qweight.shape[0] * 32 // bits``). A float16 result of shape
    ``(input.shape[0], output_dim)`` is allocated on ``input.device``, filled by
    the kernel and returned.
    """
    rows = input.shape[0]
    packed_rows, columns = qweight.shape[0], qweight.shape[1]

    with torch.cuda.device(input.device):
        output_dim = (packed_rows * 32) // bits
        result = torch.empty((rows, output_dim), device=input.device, dtype=torch.float16)

        def launch_grid(meta):
            # One program per (BLOCK_SIZE_M x BLOCK_SIZE_K) output tile.
            return (triton.cdiv(rows, meta['BLOCK_SIZE_M']) * triton.cdiv(output_dim, meta['BLOCK_SIZE_K']), )

        transpose_matmul_248_kernel[launch_grid](
            input, qweight, result, scales, qzeros, g_idx,
            rows, columns, output_dim, bits, maxq,
            input.stride(0), input.stride(1),
            qweight.stride(0), qweight.stride(1),
            result.stride(0), result.stride(1),
            scales.stride(0), qzeros.stride(0),
        )
        return result
class QuantLinearFunction(torch.autograd.Function):
    """Autograd wrapper around the quantized matmul.

    The forward pass calls ``matmul248``; the backward pass produces a gradient
    for ``input`` only (via ``transpose_matmul248``) and ``None`` for every
    other argument.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        result = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        # Keep everything the backward pass needs to rebuild the weight.
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.bits = bits
        ctx.maxq = maxq
        return result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        qweight, scales, qzeros, g_idx = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, ctx.bits, ctx.maxq)
        else:
            grad_input = None
        # One slot per forward argument; only `input` is differentiable.
        return grad_input, None, None, None, None, None, None
class QuantLinear(nn.Module):
    """Linear layer whose weight is stored as packed 2/4/8-bit integers.

    Buffers:
      * ``qweight``: int32, ``(infeatures // 32 * bits, outfeatures)``
      * ``qzeros``:  int32, ``(ceil(infeatures / groupsize), outfeatures // 32 * bits)``
      * ``scales``:  float16, ``(ceil(infeatures / groupsize), outfeatures)``
      * ``g_idx``:   int32, ``(infeatures,)``, the group row of each input feature
      * ``bias``:    float16, ``(outfeatures,)``, or ``None``
    """

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        """Allocate zero-filled packed buffers.

        ``groupsize == -1`` means a single group spanning all input features.
        Raises NotImplementedError when ``bits`` is not 2, 4 or 8.

        NOTE(review): the four quantization buffers are created with ``.cuda()``,
        so constructing this layer requires a CUDA device; ``bias`` is not moved.
        """
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32).cuda())
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32).cuda())
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda())
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda())
        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

    def pack(self, linear, scales, zeros, g_idx=None):
        """Quantize ``linear`` with the given parameters and store the packed result.

        ``scales`` and ``zeros`` are indexed ``(outfeatures, groups)``; ``g_idx``
        optionally overrides the feature-to-group mapping. The caller's
        ``scales`` and ``zeros`` tensors are left untouched.
        """
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        # Integer code of every weight, laid out (infeatures, outfeatures):
        # round((w + zero * scale) / scale), using each feature's group row.
        # Indexing happens on the device of the quantization parameters.
        group_rows = self.g_idx.to(scale_zeros.device).long()
        intweight = torch.round((linear.weight.data.t() + scale_zeros[group_rows]) / self.scales[group_rows]).to(torch.int)
        intweight = intweight.contiguous().numpy().astype(np.uint32)

        # `per_word` consecutive input features share one 32-bit word.
        per_word = 32 // self.bits
        qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
        for row in range(qweight.shape[0]):
            first = row * per_word
            for j in range(per_word):
                qweight[row] |= intweight[first + j] << (self.bits * j)
        self.qweight = torch.from_numpy(qweight.astype(np.int32))

        # Zero points are stored minus one. Subtract out of place: when there is
        # a single group, `zeros.t().contiguous()` above can share memory with
        # the caller's tensor, and an in-place `-=` would modify it.
        zeros = zeros - 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
        for col in range(qzeros.shape[1]):
            first = col * per_word
            for j in range(per_word):
                qzeros[:, col] |= zeros[:, first + j] << (self.bits * j)
        self.qzeros = torch.from_numpy(qzeros.astype(np.int32))

    def forward(self, x):
        """Apply the quantized linear map to the last dimension of ``x``."""
        out_shape = x.shape[:-1] + (self.outfeatures, )
        # The kernel works on 2-D inputs; flatten the leading dimensions.
        out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
def make_quant_linear(module, names, bits, groupsize, name=''):
    """Replace, in place, every attribute of ``module`` whose dotted path is in ``names`` by a QuantLinear.

    ``name`` is the dotted path of ``module`` itself ('' for the root). Each
    replacement copies ``in_features``, ``out_features`` and the presence of a
    bias from the attribute it replaces. Children are then processed
    recursively; an existing QuantLinear is left untouched.
    """
    if isinstance(module, QuantLinear):
        return

    prefix = name + '.' if name != '' else ''

    for attr in dir(module):
        current = getattr(module, attr)
        if prefix + attr in names:
            delattr(module, attr)
            replacement = QuantLinear(bits, groupsize, current.in_features, current.out_features, current.bias is not None)
            setattr(module, attr, replacement)

    for child_name, child in module.named_children():
        make_quant_linear(child, names, bits, groupsize, prefix + child_name)
def autotune_warmup_linear(model, transpose=False):
    """
    Pre-tunes the quantized matmul kernel for every distinct (K, N) layer size.

    For the first QuantLinear seen with each ``(infeatures, outfeatures)`` pair,
    CUDA copies of its packed tensors are taken. ``matmul248`` is then run on
    random float16 inputs of 1, 2, 4, ..., 2048 rows; with ``transpose=True``
    ``transpose_matmul248`` is run as well.
    """
    from tqdm import tqdm

    tuning_args = {}
    quant_layers = (layer for _, layer in model.named_modules() if isinstance(layer, QuantLinear))
    for layer in quant_layers:
        size_key = (layer.infeatures, layer.outfeatures)
        if size_key not in tuning_args:
            tuning_args[size_key] = (layer.qweight.cuda(), layer.scales.cuda(), layer.qzeros.cuda(), layer.g_idx.cuda(), layer.bits, layer.maxq)

    print(f'Found {len(tuning_args)} unique KN Linear values.')
    print('Warming up autotune cache ...')
    with torch.no_grad():
        for exponent in tqdm(range(0, 12)):
            rows = 2**exponent  # [1, 2048]
            for (k, n), packed in tuning_args.items():
                qweight, scales, qzeros, g_idx, bits, maxq = packed
                a = torch.randn(rows, k, dtype=torch.float16, device='cuda')
                matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                if transpose:
                    a = torch.randn(rows, n, dtype=torch.float16, device='cuda')
                    transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
    del tuning_args

View File

@ -0,0 +1,127 @@
import numpy as np
import torch
import torch.nn as nn
import math
class Quantizer(nn.Module):
    """Finds and applies uniform quantization parameters (scale / zero point).

    ``maxq`` is the largest integer level (``2**bits - 1``); a negative ``maxq``
    selects the ternary ("trits") mode handled separately in ``_quantize``.
    """

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        for buffer_name in ('scale', 'zero'):
            self.register_buffer(buffer_name, torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
        """Store the quantization settings and reset ``scale`` to zeros."""
        maxq = torch.tensor(2**bits - 1)
        if trits:
            maxq = torch.tensor(-1)
        self.maxq = maxq

        self.perchannel, self.sym, self.mse = perchannel, sym, mse
        self.norm, self.grid, self.maxshrink = norm, grid, maxshrink
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        """Quantize then de-quantize ``x`` with the given parameters."""
        if maxq < 0:
            # Ternary mode: `scale` / `zero` hold the positive / negative level.
            upper = (x > scale / 2).float() * scale
            lower = (x < zero / 2).float() * zero
            return upper + lower
        levels = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (levels - zero)

    def find_params(self, x, weight=False):
        """Compute ``self.scale`` and ``self.zero`` from the value range of ``x``.

        With ``perchannel`` one parameter pair is computed per row of the
        flattened view of ``x``, otherwise a single pair is computed and
        repeated. The results are finally reshaped so they broadcast against
        tensors shaped like ``x``.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        ndim = len(shape)

        # Bring x into a 2-D layout with one row per quantization channel.
        if not self.perchannel:
            x = x.flatten().unsqueeze(0)
        elif weight:
            x = x.flatten(1)
        elif ndim == 4:
            x = x.permute([1, 0, 2, 3])
            x = x.flatten(1)
        elif ndim == 3:
            x = x.reshape((-1, shape[-1])).t()
        elif ndim == 2:
            x = x.t()

        # The range always includes zero.
        zero_row = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], zero_row)
        xmax = torch.maximum(x.max(1)[0], zero_row)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            negative = xmin < 0
            if torch.any(negative):
                xmin[negative] = -xmax[negative]
        # Avoid a zero-width range for all-zero rows.
        degenerate = (xmin == 0) & (xmax == 0)
        xmin[degenerate] = -1
        xmax[degenerate] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Grid search over shrunken ranges, keeping the best per row.
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for step in range(int(self.maxshrink * self.grid)):
                shrink = 1 - step / self.grid
                xmin1 = shrink * xmin
                xmax1 = shrink * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = self.zero if self.sym else torch.round(-xmin1 / scale1)
                q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                improved = err < best
                if torch.any(improved):
                    best[improved] = err[improved]
                    self.scale[improved] = scale1[improved]
                    self.zero[improved] = zero1[improved]

        if not self.perchannel:
            if weight:
                repeats = shape[0]
            elif ndim != 3:
                repeats = shape[1]
            else:
                repeats = shape[2]
            self.scale = self.scale.repeat(repeats)
            self.zero = self.zero.repeat(repeats)

        if weight:
            broadcast_shape = [-1] + [1] * (ndim - 1)
            self.scale = self.scale.reshape(broadcast_shape)
            self.zero = self.zero.reshape(broadcast_shape)
            return

        if ndim == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif ndim == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif ndim == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Return the quantized-dequantized ``x``, or ``x`` itself if no parameters are set yet."""
        if not self.ready():
            return x
        return self._quantize(x, self.scale, self.zero, self.maxq)

    def enabled(self):
        """True when ``maxq`` is positive."""
        return self.maxq > 0

    def ready(self):
        """True when every entry of ``scale`` is non-zero."""
        return torch.all(self.scale != 0)

View File

@ -0,0 +1,423 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True,
        prune_configs_by={
            'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
            'perf_model': None,
            'top_k': None,
        },
    )
    @triton.jit
    def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
                          BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B, de-quantizing B on the fly.
        A is of shape (M, K) float16
        B is of shape (K // (32 // bits), N) int32, with `32 // bits` values packed per word along K
        C is of shape (M, N) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N // (32 // bits)) int32, with `32 // bits` values packed per word along N
        g_ptr is of shape (K) int32 and maps each input feature to its group row
        """
        infearure_per_bits = 32 // bits

        # Map the 1-D program id onto an (m, n) output tile; tiles are visited in
        # bands of GROUP_SIZE_M row-blocks.
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis `32 // bits` times
        b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_k
        scales_ptrs = scales_ptr + offs_bn[None, :]
        zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)

        # shifter is used to extract the N bits of each element in the 32-bit word from B
        shifter = (offs_k % infearure_per_bits) * bits
        zeros_shifter = (offs_bn % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        # NOTE(review): the loads of g_idx, scales, zeros and B are not masked, so
        # K and N are presumably multiples of the chosen block sizes - confirm
        # against the config pruner.
        for k in range(0, num_pid_k):
            g_idx = tl.load(g_ptrs)

            # Fetch scales and zeros for the group of every row in this K-slice
            scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            # QuantLinear.pack stores `zeros - 1`, so add the one back.
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift

            accumulator += tl.dot(a, b)
            # Advance along K using the real element stride of A, so inputs with
            # a non-unit stride along K are read correctly.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            g_ptrs += BLOCK_SIZE_K

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=c_mask)

    @custom_autotune.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
        ],
        key=['M', 'N', 'K'],
        nearest_power_of_two=True)
    @triton.jit
    def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
                                    stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
        """
        Compute the matrix multiplication C = A x B^T, de-quantizing B on the fly.
        A is of shape (M, N) float16
        B is of shape (K // (32 // bits), N) int32, with `32 // bits` values packed per word along K
        C is of shape (M, K) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N // (32 // bits)) int32, with `32 // bits` values packed per word along N
        g_ptr is of shape (K) int32 and maps each input feature to its group row
        """
        infearure_per_bits = 32 // bits

        # Map the 1-D program id onto an (m, k) output tile; tiles are visited in
        # bands of GROUP_SIZE_M row-blocks.
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_k
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_k = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        offs_n = tl.arange(0, BLOCK_SIZE_N)
        # stride_ak is the stride of A along its second (here: N) axis.
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
        a_mask = (offs_am[:, None] < M)
        # b_ptrs is set up such that it repeats elements along the K axis `32 // bits` times
        b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_bk
        # The K-slice is fixed for this program, so the group rows are loaded once.
        g_idx = tl.load(g_ptrs)

        scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
        zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros

        # shifter is used to extract the N bits of each element in the 32-bit word from B
        shifter = (offs_bk % infearure_per_bits) * bits
        zeros_shifter = (offs_n % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

        # NOTE(review): the loads of g_idx, scales, zeros and B are not masked, so
        # K and N are presumably multiples of the chosen block sizes - confirm.
        for n in range(0, num_pid_n):
            # Fetch scales and zeros for the current N-slice
            scales = tl.load(scales_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(zeros_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            # QuantLinear.pack stores `zeros - 1`, so add the one back.
            zeros = (zeros + 1)

            a = tl.load(a_ptrs, mask=a_mask, other=0.)  # (BLOCK_SIZE_M, BLOCK_SIZE_N)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift
            b = tl.trans(b)

            accumulator += tl.dot(a, b)
            # Advance along N using the real element strides of A and B, so
            # tensors with a non-unit stride along N are read correctly.
            a_ptrs += BLOCK_SIZE_N * stride_ak
            b_ptrs += BLOCK_SIZE_N * stride_bn
            # scales/zeros have no column-stride argument; rows are assumed dense.
            scales_ptrs += BLOCK_SIZE_N
            zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
        tl.store(c_ptrs, accumulator, mask=c_mask)
except Exception as exc:
    # Best effort: the module stays importable without a usable Triton install,
    # but the kernels above are then undefined. Catch `Exception` rather than
    # everything so Ctrl-C / SystemExit still propagate, and report the cause.
    print(f'triton not installed or failed to load ({exc!r}).')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    """Launch ``matmul_248_kernel`` to multiply ``input`` by the packed weight.

    A float16 result of shape ``(input.shape[0], qweight.shape[1])`` is allocated
    on ``input.device`` (which is made the current CUDA device for the launch),
    filled by the kernel and returned.
    """
    rows, in_features = input.shape[0], input.shape[1]
    out_features = qweight.shape[1]

    def launch_grid(meta):
        # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
        return (triton.cdiv(rows, meta['BLOCK_SIZE_M']) * triton.cdiv(out_features, meta['BLOCK_SIZE_N']), )

    with torch.cuda.device(input.device):
        result = torch.empty((rows, out_features), device=input.device, dtype=torch.float16)
        matmul_248_kernel[launch_grid](
            input, qweight, result, scales, qzeros, g_idx,
            rows, out_features, in_features, bits, maxq,
            input.stride(0), input.stride(1),
            qweight.stride(0), qweight.stride(1),
            result.stride(0), result.stride(1),
            scales.stride(0), qzeros.stride(0),
        )
        return result
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    """Launch ``transpose_matmul_248_kernel`` to multiply ``input`` by the transposed packed weight.

    The output width is the unpacked row count of ``qweight``
    (``qweight.shape[0] * 32 // bits``). A float16 result of shape
    ``(input.shape[0], output_dim)`` is allocated on ``input.device``, filled by
    the kernel and returned.
    """
    rows = input.shape[0]
    packed_rows, columns = qweight.shape[0], qweight.shape[1]

    with torch.cuda.device(input.device):
        output_dim = (packed_rows * 32) // bits
        result = torch.empty((rows, output_dim), device=input.device, dtype=torch.float16)

        def launch_grid(meta):
            # One program per (BLOCK_SIZE_M x BLOCK_SIZE_K) output tile.
            return (triton.cdiv(rows, meta['BLOCK_SIZE_M']) * triton.cdiv(output_dim, meta['BLOCK_SIZE_K']), )

        transpose_matmul_248_kernel[launch_grid](
            input, qweight, result, scales, qzeros, g_idx,
            rows, columns, output_dim, bits, maxq,
            input.stride(0), input.stride(1),
            qweight.stride(0), qweight.stride(1),
            result.stride(0), result.stride(1),
            scales.stride(0), qzeros.stride(0),
        )
        return result
class QuantLinearFunction(torch.autograd.Function):
    """Autograd wrapper around the quantized matmul.

    The forward pass calls ``matmul248``; the backward pass produces a gradient
    for ``input`` only (via ``transpose_matmul248``) and ``None`` for every
    other argument.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        result = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        # Keep everything the backward pass needs to rebuild the weight.
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.bits = bits
        ctx.maxq = maxq
        return result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        qweight, scales, qzeros, g_idx = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, ctx.bits, ctx.maxq)
        else:
            grad_input = None
        # One slot per forward argument; only `input` is differentiable.
        return grad_input, None, None, None, None, None, None
class QuantLinear(nn.Module):
    """Linear layer whose weight is stored as packed 2/4/8-bit integers.

    Buffers:
      * ``qweight``: int32, ``(infeatures // 32 * bits, outfeatures)``
      * ``qzeros``:  int32, ``(ceil(infeatures / groupsize), outfeatures // 32 * bits)``
      * ``scales``:  float16, ``(ceil(infeatures / groupsize), outfeatures)``
      * ``g_idx``:   int32, ``(infeatures,)``, the group row of each input feature
      * ``bias``:    float16, ``(outfeatures,)``, or ``None``
    """

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        """Allocate zero-filled packed buffers.

        ``groupsize == -1`` means a single group spanning all input features.
        Raises NotImplementedError when ``bits`` is not 2, 4 or 8.
        """
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

    def pack(self, linear, scales, zeros, g_idx=None):
        """Quantize ``linear`` with the given parameters and store the packed result.

        ``scales`` and ``zeros`` are indexed ``(outfeatures, groups)``; ``g_idx``
        optionally overrides the feature-to-group mapping. The caller's
        ``scales`` and ``zeros`` tensors are left untouched.
        """
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        # Integer code of every weight, laid out (infeatures, outfeatures):
        # round((w + zero * scale) / scale), using each feature's group row.
        # Indexing happens on the device of the quantization parameters.
        group_rows = self.g_idx.to(scale_zeros.device).long()
        intweight = torch.round((linear.weight.data.t() + scale_zeros[group_rows]) / self.scales[group_rows]).to(torch.int)
        intweight = intweight.contiguous().numpy().astype(np.uint32)

        # `per_word` consecutive input features share one 32-bit word.
        per_word = 32 // self.bits
        qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
        for row in range(qweight.shape[0]):
            first = row * per_word
            for j in range(per_word):
                qweight[row] |= intweight[first + j] << (self.bits * j)
        self.qweight = torch.from_numpy(qweight.astype(np.int32))

        # Zero points are stored minus one. Subtract out of place: when there is
        # a single group, `zeros.t().contiguous()` above can share memory with
        # the caller's tensor, and an in-place `-=` would modify it.
        zeros = zeros - 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
        for col in range(qzeros.shape[1]):
            first = col * per_word
            for j in range(per_word):
                qzeros[:, col] |= zeros[:, first + j] << (self.bits * j)
        self.qzeros = torch.from_numpy(qzeros.astype(np.int32))

    def forward(self, x):
        """Apply the quantized linear map to the last dimension of ``x``."""
        out_shape = x.shape[:-1] + (self.outfeatures, )
        # The kernel works on 2-D inputs; flatten the leading dimensions.
        out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
def make_quant_linear(module, names, bits, groupsize, name=''):
    """Replace, in place, every attribute of ``module`` whose dotted path is in ``names`` by a QuantLinear.

    ``name`` is the dotted path of ``module`` itself ('' for the root). Each
    replacement copies ``in_features``, ``out_features`` and the presence of a
    bias from the attribute it replaces. Children are then processed
    recursively; an existing QuantLinear is left untouched.
    """
    if isinstance(module, QuantLinear):
        return

    prefix = name + '.' if name != '' else ''

    for attr in dir(module):
        current = getattr(module, attr)
        if prefix + attr in names:
            delattr(module, attr)
            replacement = QuantLinear(bits, groupsize, current.in_features, current.out_features, current.bias is not None)
            setattr(module, attr, replacement)

    for child_name, child in module.named_children():
        make_quant_linear(child, names, bits, groupsize, prefix + child_name)
def autotune_warmup_linear(model, transpose=False):
    """
    Pre-tunes the quantized matmul kernel for every distinct (K, N) layer size.

    For the first QuantLinear seen with each ``(infeatures, outfeatures)`` pair,
    CUDA copies of its packed tensors are taken. ``matmul248`` is then run on
    random float16 inputs of 1, 2, 4, ..., 2048 rows; with ``transpose=True``
    ``transpose_matmul248`` is run as well.
    """
    from tqdm import tqdm

    tuning_args = {}
    quant_layers = (layer for _, layer in model.named_modules() if isinstance(layer, QuantLinear))
    for layer in quant_layers:
        size_key = (layer.infeatures, layer.outfeatures)
        if size_key not in tuning_args:
            tuning_args[size_key] = (layer.qweight.cuda(), layer.scales.cuda(), layer.qzeros.cuda(), layer.g_idx.cuda(), layer.bits, layer.maxq)

    print(f'Found {len(tuning_args)} unique KN Linear values.')
    print('Warming up autotune cache ...')
    with torch.no_grad():
        for exponent in tqdm(range(0, 12)):
            rows = 2**exponent  # [1, 2048]
            for (k, n), packed in tuning_args.items():
                qweight, scales, qzeros, g_idx, bits, maxq = packed
                a = torch.randn(rows, k, dtype=torch.float16, device='cuda')
                matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                if transpose:
                    a = torch.randn(rows, n, dtype=torch.float16, device='cuda')
                    transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
    del tuning_args

View File

@ -100,14 +100,14 @@ def serve(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: bool,
quantize: Optional[str],
uds_path: Path,
):
async def serve_inner(
model_id: str,
revision: Optional[str],
sharded: bool = False,
quantize: bool = False,
quantize: Optional[str] = None,
):
unix_socket_template = "unix://{}-{}"
if sharded:

View File

@ -4,6 +4,13 @@ import torch
from datetime import timedelta
class Fake:
    """Minimal object exposing ``size()`` and ``rank()`` read from the environment."""

    def size(self):
        """Return ``WORLD_SIZE`` from the environment as an int (1 when unset)."""
        world_size = os.environ.get("WORLD_SIZE", "1")
        return int(world_size)

    def rank(self):
        """Return ``RANK`` from the environment as an int (0 when unset)."""
        rank = os.environ.get("RANK", "0")
        return int(rank)
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
@ -33,3 +40,4 @@ def initialize_torch_distributed():
)
return torch.distributed.group.WORLD, rank, world_size
# return Fake(), rank, world_size