feat(server): Reworking the quantization script so it's still universal (not llama specific) (#587)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).


# What does this PR do?

Reworking the quantization script so it's still universal (not llama
specific)

but should work on more configurations (no need for 2 GPUs, less RAM
usage).

Still need to investigate the potential differences in quantization
results.


<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2023-07-18 12:19:05 +02:00 committed by GitHub
parent 44acf72a73
commit 4d38a1c4ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 12 deletions

View File

@ -194,6 +194,8 @@ def quantize(
percdamp: float = 0.01, percdamp: float = 0.01,
act_order: bool = False, act_order: bool = False,
): ):
if revision is None:
revision = "main"
download_weights( download_weights(
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
@ -207,6 +209,7 @@ def quantize(
bits=4, bits=4,
groupsize=128, groupsize=128,
output_dir=output_dir, output_dir=output_dir,
revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
upload_to_model_id=upload_to_model_id, upload_to_model_id=upload_to_model_id,
percdamp=percdamp, percdamp=percdamp,

View File

@ -13,6 +13,9 @@ import transformers
from huggingface_hub import HfApi from huggingface_hub import HfApi
import numpy as np import numpy as np
import torch import torch
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.gptq.quant_linear import QuantLinear
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
@ -38,7 +41,6 @@ class Quantizer(nn.Module):
maxshrink=0.8, maxshrink=0.8,
trits=False, trits=False,
): ):
self.maxq = torch.tensor(2**bits - 1) self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel self.perchannel = perchannel
self.sym = sym self.sym = sym
@ -600,6 +602,8 @@ def sequential(
nsamples, nsamples,
bits, bits,
groupsize, groupsize,
*,
hooks,
percdamp=0.01, percdamp=0.01,
sym: bool = False, sym: bool = False,
act_order: bool = False, act_order: bool = False,
@ -637,7 +641,7 @@ def sequential(
layers[0] = Catcher(layers[0]) layers[0] = Catcher(layers[0])
for batch in dataloader: for batch in dataloader:
try: try:
model(batch[0]) model(batch[0].cuda())
except ValueError: except ValueError:
pass pass
layers[0] = layers[0].module layers[0] = layers[0].module
@ -646,6 +650,8 @@ def sequential(
# model.model.embed_tokens = model.model.embed_tokens.cpu() # model.model.embed_tokens = model.model.embed_tokens.cpu()
# model.model.norm = model.model.norm.cpu() # model.model.norm = model.model.norm.cpu()
torch.cuda.empty_cache() torch.cuda.empty_cache()
for hook in hooks:
hook.remove()
outs = torch.zeros_like(inps) outs = torch.zeros_like(inps)
@ -662,10 +668,8 @@ def sequential(
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
print("+==================+==============+============+===========+=======+") print("+==================+==============+============+===========+=======+")
from accelerate.hooks import remove_hook_from_submodules layer = layers[i]
layer.load()
layer = layers[i].to(dev)
remove_hook_from_submodules(layer)
full = find_layers(layer) full = find_layers(layer)
sequential = [list(full.keys())] sequential = [list(full.keys())]
@ -677,6 +681,7 @@ def sequential(
gptq[name].quantizer.configure( gptq[name].quantizer.configure(
bits, perchannel=True, sym=sym, mse=False bits, perchannel=True, sym=sym, mse=False
) )
pass
def add_batch(name): def add_batch(name):
def tmp(_, inp, out): def tmp(_, inp, out):
@ -688,7 +693,6 @@ def sequential(
for name in subset: for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name))) handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(nsamples): for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
for h in handles: for h in handles:
h.remove() h.remove()
@ -714,7 +718,7 @@ def sequential(
for j in range(nsamples): for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
layers[i] = layer.cpu() layer.unload()
del layer del layer
del gptq del gptq
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
return model return model
def setdeepattr(module, full_name, tensor):
current = module
tokens = full_name.split(".")
for token in tokens[:-1]:
current = getattr(current, token)
setattr(current, tokens[-1], tensor)
def getdeepattr(module, full_name):
current = module
tokens = full_name.split(".")
for token in tokens:
current = getattr(current, token)
return current
def load_weights_pre_hook(module_name, weights, recursive=False):
def inner(module, args):
print(f"Pre hook {module_name}")
local_params = {}
for k, v in module.named_parameters():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for k, v in module.named_buffers():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for local_param in local_params:
current_tensor = getdeepattr(module, local_param)
if current_tensor.device == torch.device("meta"):
# print(f"Loading {local_param}")
if module_name:
tensor_name = f"{module_name}.{local_param}"
else:
tensor_name = local_param
tensor = weights.get_tensor(tensor_name)
setdeepattr(module, local_param, nn.Parameter(tensor))
else:
setdeepattr(
module,
local_param,
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
)
return inner
def load_weights_post_hook(module_name, weights, recursive=False):
def inner(module, args, output):
print(f"Post hook {module_name}")
local_params = {}
for k, v in module.named_parameters():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for k, v in module.named_buffers():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for local_param in local_params:
# print(f"Unloading {local_param}")
current_tensor = getdeepattr(module, local_param)
setdeepattr(
module,
local_param,
nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
)
return output
return inner
def quantize( def quantize(
model_id: str, model_id: str,
bits: int, bits: int,
groupsize: int, groupsize: int,
output_dir: str, output_dir: str,
revision: str,
trust_remote_code: bool, trust_remote_code: bool,
upload_to_model_id: Optional[str], upload_to_model_id: Optional[str],
percdamp: float, percdamp: float,
act_order: bool, act_order: bool,
): ):
print("loading model") print("loading model")
model = AutoModelForCausalLM.from_pretrained( config = AutoConfig.from_pretrained(
model_id, model_id,
torch_dtype=torch.float16,
device_map="balanced_low_0",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
model = model.eval()
print("LOADED model") print("LOADED model")
files = weight_files(model_id, revision, extension=".safetensors")
process_group, _, _ = initialize_torch_distributed()
weights = Weights(
files,
device=torch.device("cuda:0"),
dtype=torch.float16,
process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]},
)
hooks = []
for name, module in model.named_modules():
def load(module, name):
def _load():
load_weights_pre_hook(name, weights, recursive=True)(module, None)
return _load
def unload(module, name):
def _unload():
load_weights_post_hook(name, weights, recursive=True)(
module, None, None
)
return _unload
module.load = load(module, name)
module.unload = unload(module, name)
hooks.append(
module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
)
hooks.append(
module.register_forward_hook(load_weights_post_hook(name, weights))
)
model.seqlen = 2048 model.seqlen = 2048
dataset = "wikitext2" dataset = "wikitext2"
@ -806,6 +922,7 @@ def quantize(
groupsize, groupsize,
percdamp=percdamp, percdamp=percdamp,
act_order=act_order, act_order=act_order,
hooks=hooks,
) )
print(time.time() - tick) print(time.time() - tick)
@ -858,7 +975,6 @@ def quantize(
logger.info("Saved tokenizer") logger.info("Saved tokenizer")
if upload_to_model_id: if upload_to_model_id:
api = HfApi() api = HfApi()
api.upload_folder( api.upload_folder(