import time
import torch.nn as nn
import math
import json
import os
import torch
import transformers

from texttable import Texttable
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from huggingface_hub import HfApi
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.utils.gptq.utils import torch_snr_error

from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight

DEV = torch.device("cuda:0")
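
# The Quantizer below implements standard affine round-to-nearest quantization.
# As a sketch of the math it applies (with maxq = 2**bits - 1):
#
#     q          = clamp(round(x / scale) + zero, 0, maxq)
#     dequant(q) = scale * (q - zero)
#
# find_params() searches for `scale` and `zero` per tensor or per channel, and
# with mse=True refines them by a grid-shrink search minimizing the L_norm
# reconstruction error.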

class Quantizer(nn.Module):
    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        if maxq < 0:
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(
                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
                )
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)

        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)
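
# Illustrative usage (a sketch, not executed here; `linear` is a hypothetical
# nn.Linear). Per-output-channel 4-bit parameters are found once, then reused:
#
#     quantizer = Quantizer()
#     quantizer.configure(4, perchannel=True, sym=False, mse=False)
#     quantizer.find_params(linear.weight.data, weight=True)
#     w_hat = quantizer.quantize(linear.weight.data)  # fake-quantized weights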

class GPTQ:
    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()
        self.observe = observe

    def add_batch(self, inp, out):
        # Hessian H = 2 X X^T + λ I
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())
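
# Note on the running update above: rescaling H by n/(n+m) before adding the
# sqrt(2/(n+m))-scaled new columns keeps the invariant
#
#     H = (2 / nsamples) * sum_i x_i x_i^T
#
# over all calibration samples seen so far, without ever storing them.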

    def print_loss(self, name, q_weight, weight_error, timecost):
        table = Texttable()
        length = 28
        name = (
            (name + " " * (length - len(name)))
            if len(name) <= length
            else name[:length]
        )

        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR
            q_SNR = torch_snr_error(q_out, self.out1).item()
            fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = "-"
            fp_SNR = "-"

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        print(table.draw().split("\n")[-2])

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
    ):
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if act_order:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        try:
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            # Addition because Falcon fails on h_to_4h
            H = torch.linalg.cholesky(
                H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
            )
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if act_order:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(
            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
        )

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

    def free(self):
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
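
# Illustrative per-layer flow (a sketch, not executed here; `layer`, `inp` and
# `out` are hypothetical placeholders for a linear module and its captured
# activations):
#
#     gptq = GPTQ(layer)
#     gptq.quantizer.configure(4, perchannel=True, sym=False, mse=False)
#     gptq.add_batch(inp, out)  # repeat once per calibration batch
#     scale, zero, g_idx, error = gptq.fasterquant(groupsize=128)
#     gptq.free()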

def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
    from datasets import load_dataset

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc
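
# Each loader above and below returns (trainloader, testenc): `trainloader` is
# a list of (input_ids, target) pairs of length `seqlen`, with every target
# token except the last masked to -100, and `testenc` holds the tokenized
# evaluation split.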

def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
    from datasets import load_dataset

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
        use_auth_token=False,
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
        use_auth_token=False,
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        # Upper bound is shape[1] - seqlen (not - 1) so that texts of exactly
        # seqlen tokens, accepted by the check above, cannot crash randint.
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    random.seed(0)
    valenc = []
    for _ in range(256):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
            if tmp.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen)
        j = i + seqlen
        valenc.append(tmp.input_ids[:, i:j])
    valenc = torch.hstack(valenc)

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc

def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
    from datasets import load_dataset

    traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
    testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
    testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    import random

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        # Same guard as in get_c4: allow texts of exactly seqlen tokens.
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
    valenc = valenc.input_ids[:, : (256 * seqlen)]

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc

def get_loaders(
    name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
):
    if "wikitext2" in name:
        return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
    if "ptb" in name:
        if "new" in name:
            return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)
        return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)
    if "c4" in name:
        if "new" in name:
            return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code)
        return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)
    # Fail loudly instead of silently returning None for an unknown dataset.
    raise ValueError(f"Unknown dataset: {name}")

def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    # Skip the final lm_head linear.
    # An isinstance check is needed because Falcon layers inherit from nn.Linear.
    if isinstance(module, layers) and "lm_head" not in name:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(
            find_layers(
                child, layers=layers, name=name + "." + name1 if name != "" else name1
            )
        )
    return res
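
# Example of what find_layers returns (a sketch; the exact names depend on the
# architecture). For a LLaMA-style decoder layer it would map dotted module
# paths to their nn.Linear instances:
#
#     full = find_layers(layer)
#     # {"self_attn.q_proj": Linear(...), "mlp.down_proj": Linear(...), ...}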

@torch.no_grad()
def sequential(
    model,
    dataloader,
    dev,
    nsamples,
    bits,
    groupsize,
    *,
    hooks,
    percdamp=0.01,
    sym: bool = False,
    act_order: bool = False,
):
    print("Starting ...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        layers = model.model.layers
        prefix = "model.layers"
    except Exception:
        layers = model.transformer.h
        prefix = "transformer.h"

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )

    cache = {"i": 0}
    extra = {}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            extra.update(kwargs.copy())
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].cuda())
        except ValueError:
            pass
    layers[0] = layers[0].module

    # layers[0] = layers[0].cpu()
    # model.model.embed_tokens = model.model.embed_tokens.cpu()
    # model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()
    for hook in hooks:
        hook.remove()

    outs = torch.zeros_like(inps)

    extra = {
        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
    }

    print("Ready.")

    quantizers = {}
    for i in range(len(layers)):
        print(f"Quantizing layer {i+1}/{len(layers)}..")
        print("+------------------+--------------+------------+-----------+-------+")
        print("|       name       | weight_error | fp_inp_SNR | q_inp_SNR |  time |")
        print("+==================+==============+============+===========+=======+")

        layer = layers[i]
        layer.load()
        full = find_layers(layer)
        sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}
            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer.configure(
                    bits, perchannel=True, sym=sym, mse=False
                )

            def add_batch(name):
                nonlocal gptq

                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)

                return tmp

            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
            for h in handles:
                h.remove()

            for name in subset:
                scale, zero, g_idx, error = gptq[name].fasterquant(
                    percdamp=percdamp,
                    groupsize=groupsize,
                    act_order=act_order,
                    name=name,
                )
                quantizers[f"{prefix}.{i}.{name}"] = (
                    gptq[name].quantizer.cpu(),
                    scale.cpu(),
                    zero.cpu(),
                    g_idx.cpu(),
                    bits,
                    groupsize,
                )

                gptq[name].free()

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]

        layer.unload()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps
        print("+------------------+--------------+------------+-----------+-------+")
        print("\n")

    model.config.use_cache = use_cache

    return quantizers
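
# Each entry of the returned `quantizers` dict maps a parameter prefix, e.g.
# "model.layers.3.self_attn.q_proj", to a tuple
# (quantizer, scale, zero, g_idx, bits, groupsize), which pack() below consumes.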

def make_quant_linear(module, names, bits, groupsize, name=""):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + "." + attr if name != "" else attr
        if name1 in names:
            delattr(module, attr)
            setattr(
                module,
                attr,
                QuantLinear.new(
                    bits,
                    groupsize,
                    tmp.in_features,
                    tmp.out_features,
                    tmp.bias is not None,
                ),
            )
    for name1, child in module.named_children():
        make_quant_linear(
            child, names, bits, groupsize, name + "." + name1 if name != "" else name1
        )

# TODO: perform packing on GPU
def pack(model, quantizers, bits, groupsize):
    layers = find_layers(model)
    layers = {n: layers[n] for n in quantizers}
    make_quant_linear(model, quantizers, bits, groupsize)
    qlayers = find_layers(model, (QuantLinear,))
    print("Packing ...")
    for name in qlayers:
        print(name)
        quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]
        qlayers[name].pack(layers[name], scale, zero, g_idx)
    print("Done.")
    return model
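
# pack() mutates `model` in place: every quantized nn.Linear is swapped for a
# QuantLinear holding the packed integer weights, so the returned object is the
# same model instance, ready to be serialized with save_file() further below.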

def setdeepattr(module, full_name, tensor):
    current = module
    tokens = full_name.split(".")
    for token in tokens[:-1]:
        current = getattr(current, token)
    setattr(current, tokens[-1], tensor)


def getdeepattr(module, full_name):
    current = module
    tokens = full_name.split(".")
    for token in tokens:
        current = getattr(current, token)
    return current
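
# Example (a sketch): getdeepattr(model, "model.layers.0.self_attn.q_proj.weight")
# walks the attribute chain one dotted token at a time, and setdeepattr()
# replaces the leaf attribute, which a plain setattr(model, "a.b.c", t) cannot do.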

def load_weights_pre_hook(module_name, weights, recursive=False):
    def inner(module, args):
        print(f"Pre hook {module_name}")
        local_params = {}
        for k, v in module.named_parameters():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v
        for k, v in module.named_buffers():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v

        for local_param in local_params:
            current_tensor = getdeepattr(module, local_param)
            if current_tensor.device == torch.device("meta"):
                # print(f"Loading {local_param}")
                if module_name:
                    tensor_name = f"{module_name}.{local_param}"
                else:
                    tensor_name = local_param
                tensor = weights.get_tensor(tensor_name)
                setdeepattr(module, local_param, nn.Parameter(tensor))
            else:
                tensor = current_tensor.to(device=torch.device("cuda:0"))
                if current_tensor.requires_grad:
                    tensor = nn.Parameter(tensor)
                setdeepattr(module, local_param, tensor)

    return inner


def load_weights_post_hook(module_name, weights, recursive=False):
    def inner(module, args, output):
        print(f"Post hook {module_name}")
        local_params = {}
        for k, v in module.named_parameters():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v
        for k, v in module.named_buffers():
            if not recursive and k.count(".") != 1:
                continue
            local_params[k] = v
        for local_param in local_params:
            # print(f"Unloading {local_param}")
            current_tensor = getdeepattr(module, local_param)
            setdeepattr(
                module,
                local_param,
                nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
            )
        return output

    return inner
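
# Together these hooks implement layer-by-layer offloading: the pre-hook
# materializes a module's weights on cuda:0 (pulling them from the safetensors
# files if they are still on the meta device) right before its forward pass,
# and the post-hook moves them back to the CPU afterwards, so only the module
# currently executing needs to be resident in GPU memory during calibration.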

def quantize(
    model_id: str,
    bits: int,
    groupsize: int,
    output_dir: str,
    revision: str,
    trust_remote_code: bool,
    upload_to_model_id: Optional[str],
    percdamp: float,
    act_order: bool,
    sym: bool,
):
    print("Loading model")
    config = AutoConfig.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
    )

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )
    model = model.eval()

    print("Model loaded")
    files = weight_files(model_id, revision, extension=".safetensors")
    process_group, _, _ = initialize_torch_distributed()
    weights = Weights(
        files,
        device=torch.device("cuda:0"),
        dtype=torch.float16,
        process_group=process_group,
        aliases={"embed_tokens.weight": ["lm_head.weight"]},
        weights_loader=DefaultWeightsLoader(UnquantizedWeight),
    )
    hooks = []
    for name, module in model.named_modules():

        def load(module, name):
            def _load():
                load_weights_pre_hook(name, weights, recursive=True)(module, None)

            return _load

        def unload(module, name):
            def _unload():
                load_weights_post_hook(name, weights, recursive=True)(
                    module, None, None
                )

            return _unload

        module.load = load(module, name)
        module.unload = unload(module, name)
        hooks.append(
            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
        )
        hooks.append(
            module.register_forward_hook(load_weights_post_hook(name, weights))
        )
    model.seqlen = 2048

    dataset = "wikitext2"
    nsamples = 128
    seed = None

    dataloader, testloader = get_loaders(
        dataset,
        nsamples=nsamples,
        seed=seed,
        model_id=model_id,
        seqlen=model.seqlen,
        trust_remote_code=trust_remote_code,
    )

    tick = time.time()
    quantizers = sequential(
        model,
        dataloader,
        DEV,
        nsamples,
        bits,
        groupsize,
        percdamp=percdamp,
        act_order=act_order,
        hooks=hooks,
        sym=sym,
    )
    print(time.time() - tick)

    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
    from transformers.modeling_utils import shard_checkpoint

    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}

    max_shard_size = "10GB"
    shards, index = shard_checkpoint(
        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
    )
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
            shard,
            os.path.join(output_dir, shard_file),
            metadata={
                "format": "pt",
                "quantized": "gptq",
                "origin": "text-generation-inference",
            },
        )
    if index is None:
        path_to_weights = os.path.join(output_dir, "model.safetensors")
        logger.info(f"Model weights saved in {path_to_weights}")
    else:
        save_index_file = "model.safetensors.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and has been "
            f"split into {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
            f"index located at {save_index_file}."
        )
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    config.quantization_config = {
        "bits": bits,
        "group_size": groupsize,
        "damp_percent": percdamp,
        "desc_act": act_order,
        "static_groups": False,
        "sym": sym,
        "quant_method": "gptq",
    }
    config.save_pretrained(output_dir)
    logger.info("Saved config")
    logger.info("Saving tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=trust_remote_code
    )
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved tokenizer")

    if upload_to_model_id:
        api = HfApi()

        api.upload_folder(
            folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
        )
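
# Illustrative invocation (a sketch, not executed on import; the model id and
# output path are hypothetical, and a CUDA device is required):
#
#     quantize(
#         model_id="huggyllama/llama-7b",
#         bits=4,
#         groupsize=128,
#         output_dir="./llama-7b-gptq",
#         revision="main",
#         trust_remote_code=False,
#         upload_to_model_id=None,
#         percdamp=0.01,
#         act_order=False,
#         sym=False,
#     )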