feat(server): Reworking the quantization script so it's still universal (not llama specific) (#587)
but should work on more configurations (no need for 2 GPUs, less RAM usage). # What does this PR do? Reworking the quantization script so it's still universal (not llama specific) but should work on more configurations (no need for 2 GPUs, less RAM usage). Still need to investigate the potential differences in quantization results. <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
44acf72a73
commit
4d38a1c4ad
|
@ -194,6 +194,8 @@ def quantize(
|
||||||
percdamp: float = 0.01,
|
percdamp: float = 0.01,
|
||||||
act_order: bool = False,
|
act_order: bool = False,
|
||||||
):
|
):
|
||||||
|
if revision is None:
|
||||||
|
revision = "main"
|
||||||
download_weights(
|
download_weights(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
@ -207,6 +209,7 @@ def quantize(
|
||||||
bits=4,
|
bits=4,
|
||||||
groupsize=128,
|
groupsize=128,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
|
revision=revision,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
upload_to_model_id=upload_to_model_id,
|
upload_to_model_id=upload_to_model_id,
|
||||||
percdamp=percdamp,
|
percdamp=percdamp,
|
||||||
|
|
|
@ -13,6 +13,9 @@ import transformers
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from text_generation_server.utils import initialize_torch_distributed, Weights
|
||||||
|
from text_generation_server.utils.hub import weight_files
|
||||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
@ -38,7 +41,6 @@ class Quantizer(nn.Module):
|
||||||
maxshrink=0.8,
|
maxshrink=0.8,
|
||||||
trits=False,
|
trits=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.maxq = torch.tensor(2**bits - 1)
|
self.maxq = torch.tensor(2**bits - 1)
|
||||||
self.perchannel = perchannel
|
self.perchannel = perchannel
|
||||||
self.sym = sym
|
self.sym = sym
|
||||||
|
@ -600,6 +602,8 @@ def sequential(
|
||||||
nsamples,
|
nsamples,
|
||||||
bits,
|
bits,
|
||||||
groupsize,
|
groupsize,
|
||||||
|
*,
|
||||||
|
hooks,
|
||||||
percdamp=0.01,
|
percdamp=0.01,
|
||||||
sym: bool = False,
|
sym: bool = False,
|
||||||
act_order: bool = False,
|
act_order: bool = False,
|
||||||
|
@ -637,7 +641,7 @@ def sequential(
|
||||||
layers[0] = Catcher(layers[0])
|
layers[0] = Catcher(layers[0])
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
try:
|
try:
|
||||||
model(batch[0])
|
model(batch[0].cuda())
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
layers[0] = layers[0].module
|
layers[0] = layers[0].module
|
||||||
|
@ -646,6 +650,8 @@ def sequential(
|
||||||
# model.model.embed_tokens = model.model.embed_tokens.cpu()
|
# model.model.embed_tokens = model.model.embed_tokens.cpu()
|
||||||
# model.model.norm = model.model.norm.cpu()
|
# model.model.norm = model.model.norm.cpu()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
outs = torch.zeros_like(inps)
|
outs = torch.zeros_like(inps)
|
||||||
|
|
||||||
|
@ -662,10 +668,8 @@ def sequential(
|
||||||
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
|
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
|
||||||
print("+==================+==============+============+===========+=======+")
|
print("+==================+==============+============+===========+=======+")
|
||||||
|
|
||||||
from accelerate.hooks import remove_hook_from_submodules
|
layer = layers[i]
|
||||||
|
layer.load()
|
||||||
layer = layers[i].to(dev)
|
|
||||||
remove_hook_from_submodules(layer)
|
|
||||||
full = find_layers(layer)
|
full = find_layers(layer)
|
||||||
sequential = [list(full.keys())]
|
sequential = [list(full.keys())]
|
||||||
|
|
||||||
|
@ -677,6 +681,7 @@ def sequential(
|
||||||
gptq[name].quantizer.configure(
|
gptq[name].quantizer.configure(
|
||||||
bits, perchannel=True, sym=sym, mse=False
|
bits, perchannel=True, sym=sym, mse=False
|
||||||
)
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
def add_batch(name):
|
def add_batch(name):
|
||||||
def tmp(_, inp, out):
|
def tmp(_, inp, out):
|
||||||
|
@ -688,7 +693,6 @@ def sequential(
|
||||||
for name in subset:
|
for name in subset:
|
||||||
handles.append(subset[name].register_forward_hook(add_batch(name)))
|
handles.append(subset[name].register_forward_hook(add_batch(name)))
|
||||||
for j in range(nsamples):
|
for j in range(nsamples):
|
||||||
|
|
||||||
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
|
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
|
||||||
for h in handles:
|
for h in handles:
|
||||||
h.remove()
|
h.remove()
|
||||||
|
@ -714,7 +718,7 @@ def sequential(
|
||||||
for j in range(nsamples):
|
for j in range(nsamples):
|
||||||
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
|
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
|
||||||
|
|
||||||
layers[i] = layer.cpu()
|
layer.unload()
|
||||||
del layer
|
del layer
|
||||||
del gptq
|
del gptq
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def setdeepattr(module, full_name, tensor):
|
||||||
|
current = module
|
||||||
|
tokens = full_name.split(".")
|
||||||
|
for token in tokens[:-1]:
|
||||||
|
current = getattr(current, token)
|
||||||
|
setattr(current, tokens[-1], tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def getdeepattr(module, full_name):
|
||||||
|
current = module
|
||||||
|
tokens = full_name.split(".")
|
||||||
|
for token in tokens:
|
||||||
|
current = getattr(current, token)
|
||||||
|
return current
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights_pre_hook(module_name, weights, recursive=False):
|
||||||
|
def inner(module, args):
|
||||||
|
print(f"Pre hook {module_name}")
|
||||||
|
local_params = {}
|
||||||
|
for k, v in module.named_parameters():
|
||||||
|
if not recursive and k.count(".") != 1:
|
||||||
|
continue
|
||||||
|
local_params[k] = v
|
||||||
|
for k, v in module.named_buffers():
|
||||||
|
if not recursive and k.count(".") != 1:
|
||||||
|
continue
|
||||||
|
local_params[k] = v
|
||||||
|
|
||||||
|
for local_param in local_params:
|
||||||
|
current_tensor = getdeepattr(module, local_param)
|
||||||
|
if current_tensor.device == torch.device("meta"):
|
||||||
|
# print(f"Loading {local_param}")
|
||||||
|
if module_name:
|
||||||
|
tensor_name = f"{module_name}.{local_param}"
|
||||||
|
else:
|
||||||
|
tensor_name = local_param
|
||||||
|
tensor = weights.get_tensor(tensor_name)
|
||||||
|
setdeepattr(module, local_param, nn.Parameter(tensor))
|
||||||
|
else:
|
||||||
|
setdeepattr(
|
||||||
|
module,
|
||||||
|
local_param,
|
||||||
|
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
|
||||||
|
)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights_post_hook(module_name, weights, recursive=False):
|
||||||
|
def inner(module, args, output):
|
||||||
|
print(f"Post hook {module_name}")
|
||||||
|
local_params = {}
|
||||||
|
for k, v in module.named_parameters():
|
||||||
|
if not recursive and k.count(".") != 1:
|
||||||
|
continue
|
||||||
|
local_params[k] = v
|
||||||
|
for k, v in module.named_buffers():
|
||||||
|
if not recursive and k.count(".") != 1:
|
||||||
|
continue
|
||||||
|
local_params[k] = v
|
||||||
|
for local_param in local_params:
|
||||||
|
# print(f"Unloading {local_param}")
|
||||||
|
current_tensor = getdeepattr(module, local_param)
|
||||||
|
setdeepattr(
|
||||||
|
module,
|
||||||
|
local_param,
|
||||||
|
nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def quantize(
|
def quantize(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
bits: int,
|
bits: int,
|
||||||
groupsize: int,
|
groupsize: int,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
|
revision: str,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
upload_to_model_id: Optional[str],
|
upload_to_model_id: Optional[str],
|
||||||
percdamp: float,
|
percdamp: float,
|
||||||
act_order: bool,
|
act_order: bool,
|
||||||
):
|
):
|
||||||
print("loading model")
|
print("loading model")
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device_map="balanced_low_0",
|
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
print("LOADED model")
|
print("LOADED model")
|
||||||
|
files = weight_files(model_id, revision, extension=".safetensors")
|
||||||
|
process_group, _, _ = initialize_torch_distributed()
|
||||||
|
weights = Weights(
|
||||||
|
files,
|
||||||
|
device=torch.device("cuda:0"),
|
||||||
|
dtype=torch.float16,
|
||||||
|
process_group=process_group,
|
||||||
|
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
||||||
|
)
|
||||||
|
hooks = []
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
|
||||||
|
def load(module, name):
|
||||||
|
def _load():
|
||||||
|
load_weights_pre_hook(name, weights, recursive=True)(module, None)
|
||||||
|
|
||||||
|
return _load
|
||||||
|
|
||||||
|
def unload(module, name):
|
||||||
|
def _unload():
|
||||||
|
load_weights_post_hook(name, weights, recursive=True)(
|
||||||
|
module, None, None
|
||||||
|
)
|
||||||
|
|
||||||
|
return _unload
|
||||||
|
|
||||||
|
module.load = load(module, name)
|
||||||
|
module.unload = unload(module, name)
|
||||||
|
hooks.append(
|
||||||
|
module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
|
||||||
|
)
|
||||||
|
hooks.append(
|
||||||
|
module.register_forward_hook(load_weights_post_hook(name, weights))
|
||||||
|
)
|
||||||
model.seqlen = 2048
|
model.seqlen = 2048
|
||||||
|
|
||||||
dataset = "wikitext2"
|
dataset = "wikitext2"
|
||||||
|
@ -806,6 +922,7 @@ def quantize(
|
||||||
groupsize,
|
groupsize,
|
||||||
percdamp=percdamp,
|
percdamp=percdamp,
|
||||||
act_order=act_order,
|
act_order=act_order,
|
||||||
|
hooks=hooks,
|
||||||
)
|
)
|
||||||
print(time.time() - tick)
|
print(time.time() - tick)
|
||||||
|
|
||||||
|
@ -858,7 +975,6 @@ def quantize(
|
||||||
logger.info("Saved tokenizer")
|
logger.info("Saved tokenizer")
|
||||||
|
|
||||||
if upload_to_model_id:
|
if upload_to_model_id:
|
||||||
|
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
|
|
||||||
api.upload_folder(
|
api.upload_folder(
|
||||||
|
|
Loading…
Reference in New Issue