From 4d38a1c4ad9e262617a3f36e1d01e8c57693b6ef Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 18 Jul 2023 12:19:05 +0200
Subject: [PATCH] feat(server): Reworking the quantization script so it's
 still universal (not llama specific) (#587)

but should work on more configurations (no need for 2 GPUs, less RAM usage).

# What does this PR do?

Reworking the quantization script so it's still universal (not llama specific)
but should work on more configurations (no need for 2 GPUs, less RAM usage).

Still need to investigate the potential differences in quantization results.

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 server/text_generation_server/cli.py          |   3 +
 .../utils/gptq/quantize.py                    | 140 ++++++++++++++++--
 2 files changed, 131 insertions(+), 12 deletions(-)

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 7a55e919..e74c0331 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -194,6 +194,8 @@ def quantize(
     percdamp: float = 0.01,
     act_order: bool = False,
 ):
+    if revision is None:
+        revision = "main"
     download_weights(
         model_id=model_id,
         revision=revision,
@@ -207,6 +209,7 @@ def quantize(
         bits=4,
         groupsize=128,
         output_dir=output_dir,
+        revision=revision,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
         percdamp=percdamp,
diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index 5a4ed8da..d182456f 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -13,6 +13,9 @@ import transformers
 from huggingface_hub import HfApi
 import numpy as np
 import torch
+from accelerate import init_empty_weights
+from text_generation_server.utils import initialize_torch_distributed, Weights
+from text_generation_server.utils.hub import weight_files
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from loguru import logger
 from typing import Optional
@@ -38,7 +41,6 @@ class Quantizer(nn.Module):
         maxshrink=0.8,
         trits=False,
     ):
-
         self.maxq = torch.tensor(2**bits - 1)
         self.perchannel = perchannel
         self.sym = sym
@@ -600,6 +602,8 @@ def sequential(
     nsamples,
     bits,
     groupsize,
+    *,
+    hooks,
     percdamp=0.01,
     sym: bool = False,
     act_order: bool = False,
@@ -637,7 +641,7 @@
     layers[0] = Catcher(layers[0])
     for batch in dataloader:
         try:
-            model(batch[0])
+            model(batch[0].cuda())
         except ValueError:
             pass
     layers[0] = layers[0].module
@@ -646,6 +650,8 @@
     # model.model.embed_tokens = model.model.embed_tokens.cpu()
     # model.model.norm = model.model.norm.cpu()
     torch.cuda.empty_cache()
+    for hook in hooks:
+        hook.remove()
 
     outs = torch.zeros_like(inps)
 
@@ -662,10 +668,8 @@
         print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
         print("+==================+==============+============+===========+=======+")
 
-        from accelerate.hooks import remove_hook_from_submodules
-
-        layer = layers[i].to(dev)
-        remove_hook_from_submodules(layer)
+        layer = layers[i]
+        layer.load()
         full = find_layers(layer)
         sequential = [list(full.keys())]
 
@@ -677,6 +681,7 @@
                 gptq[name].quantizer.configure(
                     bits, perchannel=True, sym=sym, mse=False
                 )
+            pass
 
             def add_batch(name):
                 def tmp(_, inp, out):
@@ -688,7 +693,6 @@
             for name in subset:
                 handles.append(subset[name].register_forward_hook(add_batch(name)))
             for j in range(nsamples):
-
                 outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
             for h in handles:
                 h.remove()
@@ -714,7 +718,7 @@
         for j in range(nsamples):
             outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
 
-        layers[i] = layer.cpu()
+        layer.unload()
         del layer
         del gptq
         torch.cuda.empty_cache()
@@ -768,24 +772,136 @@
     return model
 
 
+def setdeepattr(module, full_name, tensor):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens[:-1]:
+        current = getattr(current, token)
+    setattr(current, tokens[-1], tensor)
+
+
+def getdeepattr(module, full_name):
+    current = module
+    tokens = full_name.split(".")
+    for token in tokens:
+        current = getattr(current, token)
+    return current
+
+
+def load_weights_pre_hook(module_name, weights, recursive=False):
+    def inner(module, args):
+        print(f"Pre hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+
+        for local_param in local_params:
+            current_tensor = getdeepattr(module, local_param)
+            if current_tensor.device == torch.device("meta"):
+                # print(f"Loading {local_param}")
+                if module_name:
+                    tensor_name = f"{module_name}.{local_param}"
+                else:
+                    tensor_name = local_param
+                tensor = weights.get_tensor(tensor_name)
+                setdeepattr(module, local_param, nn.Parameter(tensor))
+            else:
+                setdeepattr(
+                    module,
+                    local_param,
+                    nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
+                )
+
+    return inner
+
+
+def load_weights_post_hook(module_name, weights, recursive=False):
+    def inner(module, args, output):
+        print(f"Post hook {module_name}")
+        local_params = {}
+        for k, v in module.named_parameters():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for k, v in module.named_buffers():
+            if not recursive and k.count(".") != 1:
+                continue
+            local_params[k] = v
+        for local_param in local_params:
+            # print(f"Unloading {local_param}")
+            current_tensor = getdeepattr(module, local_param)
+            setdeepattr(
+                module,
+                local_param,
+                nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
+            )
+        return output
+
+    return inner
+
+
 def quantize(
     model_id: str,
     bits: int,
     groupsize: int,
     output_dir: str,
+    revision: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
     percdamp: float,
     act_order: bool,
 ):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(
         model_id,
-        torch_dtype=torch.float16,
-        device_map="balanced_low_0",
         trust_remote_code=trust_remote_code,
     )
+
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
+    model = model.eval()
+    print("LOADED model")
+    files = weight_files(model_id, revision, extension=".safetensors")
+    process_group, _, _ = initialize_torch_distributed()
+    weights = Weights(
+        files,
+        device=torch.device("cuda:0"),
+        dtype=torch.float16,
+        process_group=process_group,
+        aliases={"embed_tokens.weight": ["lm_head.weight"]},
+    )
+    hooks = []
+    for name, module in model.named_modules():
+
+        def load(module, name):
+            def _load():
+                load_weights_pre_hook(name, weights, recursive=True)(module, None)
+
+            return _load
+
+        def unload(module, name):
+            def _unload():
+                load_weights_post_hook(name, weights, recursive=True)(
+                    module, None, None
+                )
+
+            return _unload
+
+        module.load = load(module, name)
+        module.unload = unload(module, name)
+        hooks.append(
+            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
+        )
+        hooks.append(
+            module.register_forward_hook(load_weights_post_hook(name, weights))
+        )
     model.seqlen = 2048
 
     dataset = "wikitext2"
@@ -806,6 +922,7 @@
         groupsize,
         percdamp=percdamp,
         act_order=act_order,
+        hooks=hooks,
     )
     print(time.time() - tick)
 
@@ -858,7 +975,6 @@ def quantize(
     logger.info("Saved tokenizer")
 
     if upload_to_model_id:
-
         api = HfApi()
         api.upload_folder(
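
Note (illustration only, not part of the patch): the reason a single GPU now suffices is visible in the `quantize()` changes above. The model is built on the `meta` device with `init_empty_weights()`, and every module gets a forward pre-hook (`load_weights_pre_hook`) that materializes its parameters on `cuda:0` (reading still-empty `meta` tensors from the safetensors files via `weights.get_tensor`) plus a forward hook (`load_weights_post_hook`) that moves them to CPU once the module has run. The sketch below shows only that streaming hook pattern on a toy model; the toy `nn.Sequential` and the helper names `to_device_pre_hook` / `to_cpu_post_hook` are made up for illustration, and the meta-tensor and buffer handling of the real hooks is omitted.

```python
# Standalone sketch of the pre-/post-hook weight streaming idea (toy model,
# made-up helper names). The actual patch additionally materializes meta
# tensors from safetensors inside its pre-hook via Weights.get_tensor.
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def to_device_pre_hook(module, args):
    # Right before this module runs, move its parameters onto the GPU.
    module.to(device)


def to_cpu_post_hook(module, args, output):
    # Once it has produced its output, push the parameters back to CPU.
    module.to("cpu")
    return output


model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])  # toy stand-in

hooks = []
for layer in model:
    hooks.append(layer.register_forward_pre_hook(to_device_pre_hook))
    hooks.append(layer.register_forward_hook(to_cpu_post_hook))

x = torch.randn(4, 1024, device=device)
with torch.no_grad():
    y = model(x)  # only one Linear is resident on the GPU at any time

for hook in hooks:
    hook.remove()
```

The trade-off is a host-to-device copy every time a module runs, which is usually acceptable for a one-off calibration pass. In the patch itself, `sequential()` removes these global hooks (the new `hooks` argument) after the calibration inputs have been captured, and from then on streams whole decoder layers with the per-module `layer.load()` / `layer.unload()` methods instead.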