import argparse import time import numpy as np import torch import torch.nn as nn import math import json import os from texttable import Texttable from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer import transformers from huggingface_hub import HfApi import numpy as np import torch from accelerate import init_empty_weights from text_generation_server.utils import initialize_torch_distributed, Weights from text_generation_server.utils.hub import weight_files from text_generation_server.utils.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional DEV = torch.device("cuda:0") class Quantizer(nn.Module): def __init__(self, shape=1): super(Quantizer, self).__init__() self.register_buffer("maxq", torch.tensor(0)) self.register_buffer("scale", torch.zeros(shape)) self.register_buffer("zero", torch.zeros(shape)) def configure( self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8, trits=False, ): self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym self.mse = mse self.norm = norm self.grid = grid self.maxshrink = maxshrink if trits: self.maxq = torch.tensor(-1) self.scale = torch.zeros_like(self.scale) def _quantize(self, x, scale, zero, maxq): if maxq < 0: return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) return scale * (q - zero) def find_params(self, x, weight=False): dev = x.device self.maxq = self.maxq.to(dev) shape = x.shape if self.perchannel: if weight: x = x.flatten(1) else: if len(shape) == 4: x = x.permute([1, 0, 2, 3]) x = x.flatten(1) if len(shape) == 3: x = x.reshape((-1, shape[-1])).t() if len(shape) == 2: x = x.t() else: x = x.flatten().unsqueeze(0) tmp = torch.zeros(x.shape[0], device=dev) xmin = torch.minimum(x.min(1)[0], tmp) xmax = torch.maximum(x.max(1)[0], tmp) if self.sym: xmax = torch.maximum(torch.abs(xmin), xmax) tmp = xmin < 0 if torch.any(tmp): xmin[tmp] = -xmax[tmp] tmp = (xmin == 0) & (xmax == 0) xmin[tmp] = -1 xmax[tmp] = +1 if self.maxq < 0: self.scale = xmax self.zero = xmin else: self.scale = (xmax - xmin) / self.maxq if self.sym: self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) else: self.zero = torch.round(-xmin / self.scale) if self.mse: best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid xmin1 = p * xmin xmax1 = p * xmax scale1 = (xmax1 - xmin1) / self.maxq zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero q = self._quantize( x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq ) q -= x q.abs_() q.pow_(self.norm) err = torch.sum(q, 1) tmp = err < best if torch.any(tmp): best[tmp] = err[tmp] self.scale[tmp] = scale1[tmp] self.zero[tmp] = zero1[tmp] if not self.perchannel: if weight: tmp = shape[0] else: tmp = shape[1] if len(shape) != 3 else shape[2] self.scale = self.scale.repeat(tmp) self.zero = self.zero.repeat(tmp) if weight: shape = [-1] + [1] * (len(shape) - 1) self.scale = self.scale.reshape(shape) self.zero = self.zero.reshape(shape) return if len(shape) == 4: self.scale = self.scale.reshape((1, -1, 1, 1)) self.zero = self.zero.reshape((1, -1, 1, 1)) if len(shape) == 3: self.scale = self.scale.reshape((1, 1, -1)) self.zero = self.zero.reshape((1, 1, -1)) if len(shape) == 2: self.scale = self.scale.unsqueeze(0) self.zero = self.zero.unsqueeze(0) def quantize(self, x): if self.ready(): return self._quantize(x, self.scale, self.zero, self.maxq) return x def enabled(self): return self.maxq > 0 def ready(self): return torch.all(self.scale != 0) class GPTQ: def __init__(self, layer, observe=False): self.layer = layer self.dev = self.layer.weight.device W = layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): W = W.t() self.rows = W.shape[0] self.columns = W.shape[1] self.H = torch.zeros((self.columns, self.columns), device=self.dev) self.nsamples = 0 self.quantizer = Quantizer() self.observe = observe def add_batch(self, inp, out): # Hessian H = 2 X XT + λ I if self.observe: self.inp1 = inp self.out1 = out else: self.inp1 = None self.out1 = None if len(inp.shape) == 2: inp = inp.unsqueeze(0) tmp = inp.shape[0] if isinstance(self.layer, nn.Linear) or isinstance( self.layer, transformers.Conv1D ): if len(inp.shape) == 3: inp = inp.reshape((-1, inp.shape[-1])) inp = inp.t() if isinstance(self.layer, nn.Conv2d): unfold = nn.Unfold( self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride, ) inp = unfold(inp) inp = inp.permute([1, 0, 2]) inp = inp.flatten(1) self.H *= self.nsamples / (self.nsamples + tmp) self.nsamples += tmp # inp = inp.float() inp = math.sqrt(2 / self.nsamples) * inp.float() # self.H += 2 / self.nsamples * inp.matmul(inp.t()) self.H += inp.matmul(inp.t()) def print_loss(self, name, q_weight, weight_error, timecost): table = Texttable() length = 28 name = ( (name + " " * (length - len(name))) if len(name) <= length else name[:length] ) table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) # assign weight self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( self.layer.weight.data.dtype ) if self.inp1 is not None: # quantize input to int8 quantizer = Quantizer() quantizer.configure(8, perchannel=False, sym=True, mse=False) quantizer.find_params(self.inp1) q_in = quantizer.quantize(self.inp1).type(torch.float16) q_out = self.layer(q_in) # get kinds of SNR q_SNR = torch_snr_error(q_out, self.out1).item() fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() else: q_SNR = "-" fp_SNR = "-" table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) print(table.draw().split("\n")[-2]) def fasterquant( self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name="" ): self.layer.to(self.dev) W = self.layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): W = W.t() W = W.float() tick = time.time() if not self.quantizer.ready(): self.quantizer.find_params(W, weight=True) H = self.H if not self.observe: del self.H dead = torch.diag(H) == 0 H[dead, dead] = 1 W[:, dead] = 0 if act_order: perm = torch.argsort(torch.diag(H), descending=True) W = W[:, perm] H = H[perm][:, perm] Losses = torch.zeros_like(W) Q = torch.zeros_like(W) damp = percdamp * torch.mean(torch.diag(H)) diag = torch.arange(self.columns, device=self.dev) H[diag, diag] += damp H = torch.linalg.cholesky(H) H = torch.cholesky_inverse(H) try: H = torch.linalg.cholesky(H, upper=True) except Exception: # Addition because Falcon fails on h_to_4h H = torch.linalg.cholesky( H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True ) Hinv = H g_idx = [] scale = [] zero = [] now_idx = 1 for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) count = i2 - i1 W1 = W[:, i1:i2].clone() Q1 = torch.zeros_like(W1) Err1 = torch.zeros_like(W1) Losses1 = torch.zeros_like(W1) Hinv1 = Hinv[i1:i2, i1:i2] for i in range(count): w = W1[:, i] d = Hinv1[i, i] if groupsize != -1: if (i1 + i) % groupsize == 0: self.quantizer.find_params( W[:, (i1 + i) : (i1 + i + groupsize)], weight=True ) if ((i1 + i) // groupsize) - now_idx == -1: scale.append(self.quantizer.scale) zero.append(self.quantizer.zero) now_idx += 1 q = self.quantizer.quantize(w.unsqueeze(1)).flatten() Q1[:, i] = q Losses1[:, i] = (w - q) ** 2 / d**2 err1 = (w - q) / d W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) Err1[:, i] = err1 Q[:, i1:i2] = Q1 Losses[:, i1:i2] = Losses1 / 2 W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) torch.cuda.synchronize() error = torch.sum(Losses).item() groupsize = groupsize if groupsize != -1 else self.columns g_idx = [i // groupsize for i in range(self.columns)] g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) if act_order: invperm = torch.argsort(perm) Q = Q[:, invperm] g_idx = g_idx[invperm] if isinstance(self.layer, transformers.Conv1D): Q = Q.t() self.print_loss( name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) ) if scale == []: scale.append(self.quantizer.scale) zero.append(self.quantizer.zero) scale = torch.cat(scale, dim=1) zero = torch.cat(zero, dim=1) return scale, zero, g_idx, error def free(self): self.inp1 = None self.out1 = None self.H = None self.Losses = None self.Trace = None torch.cuda.empty_cache() def get_wikitext2(nsamples, seed, seqlen, model_id): from datasets import load_dataset traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") import random random.seed(seed) trainloader = [] for _ in range(nsamples): i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) j = i + seqlen inp = trainenc.input_ids[:, i:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) return trainloader, testenc def get_ptb(nsamples, seed, seqlen, model_id): from datasets import load_dataset traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") from transformers import AutoTokenizer try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") import random random.seed(seed) trainloader = [] for _ in range(nsamples): i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) j = i + seqlen inp = trainenc.input_ids[:, i:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) return trainloader, testenc def get_c4(nsamples, seed, seqlen, model_id): from datasets import load_dataset traindata = load_dataset( "allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train", use_auth_token=False, ) valdata = load_dataset( "allenai/c4", "allenai--c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation", use_auth_token=False, ) from transformers import AutoTokenizer try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) import random random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(traindata) - 1) trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") if trainenc.input_ids.shape[1] >= seqlen: break i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) j = i + seqlen inp = trainenc.input_ids[:, i:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) import random random.seed(0) valenc = [] for _ in range(256): while True: i = random.randint(0, len(valdata) - 1) tmp = tokenizer(valdata[i]["text"], return_tensors="pt") if tmp.input_ids.shape[1] >= seqlen: break i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) j = i + seqlen valenc.append(tmp.input_ids[:, i:j]) valenc = torch.hstack(valenc) class TokenizerWrapper: def __init__(self, input_ids): self.input_ids = input_ids valenc = TokenizerWrapper(valenc) return trainloader, valenc def get_ptb_new(nsamples, seed, seqlen, model_id): from datasets import load_dataset traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") from transformers import AutoTokenizer try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") import random random.seed(seed) trainloader = [] for _ in range(nsamples): i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) j = i + seqlen inp = trainenc.input_ids[:, i:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) return trainloader, testenc def get_c4_new(nsamples, seed, seqlen, model_id): from datasets import load_dataset traindata = load_dataset( "allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train", ) valdata = load_dataset( "allenai/c4", "allenai--c4", data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, split="validation", ) from transformers import AutoTokenizer try: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) except: tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) import random random.seed(seed) trainloader = [] for _ in range(nsamples): while True: i = random.randint(0, len(traindata) - 1) trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") if trainenc.input_ids.shape[1] >= seqlen: break i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) j = i + seqlen inp = trainenc.input_ids[:, i:j] tar = inp.clone() tar[:, :-1] = -100 trainloader.append((inp, tar)) valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") valenc = valenc.input_ids[:, : (256 * seqlen)] class TokenizerWrapper: def __init__(self, input_ids): self.input_ids = input_ids valenc = TokenizerWrapper(valenc) return trainloader, valenc def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""): if "wikitext2" in name: return get_wikitext2(nsamples, seed, seqlen, model_id) if "ptb" in name: if "new" in name: return get_ptb_new(nsamples, seed, seqlen, model_id) return get_ptb(nsamples, seed, seqlen, model_id) if "c4" in name: if "new" in name: return get_c4_new(nsamples, seed, seqlen, model_id) return get_c4(nsamples, seed, seqlen, model_id) def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""): # Skip last lm_head linear # Need isintance Falcon is inheriting Linear. if isinstance(module, layers) and "lm_head" not in name: return {name: module} res = {} for name1, child in module.named_children(): res.update( find_layers( child, layers=layers, name=name + "." + name1 if name != "" else name1 ) ) return res @torch.no_grad() def sequential( model, dataloader, dev, nsamples, bits, groupsize, *, hooks, percdamp=0.01, sym: bool = False, act_order: bool = False, ): print("Starting ...") use_cache = model.config.use_cache model.config.use_cache = False try: layers = model.model.layers prefix = "model.layers" except Exception: layers = model.transformer.h prefix = "transformer.h" dtype = next(iter(model.parameters())).dtype inps = torch.zeros( (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev ) cache = {"i": 0} extra = {} class Catcher(nn.Module): def __init__(self, module): super().__init__() self.module = module def forward(self, inp, **kwargs): inps[cache["i"]] = inp cache["i"] += 1 extra.update(kwargs.copy()) raise ValueError layers[0] = Catcher(layers[0]) for batch in dataloader: try: model(batch[0].cuda()) except ValueError: pass layers[0] = layers[0].module # layers[0] = layers[0].cpu() # model.model.embed_tokens = model.model.embed_tokens.cpu() # model.model.norm = model.model.norm.cpu() torch.cuda.empty_cache() for hook in hooks: hook.remove() outs = torch.zeros_like(inps) extra = { k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items() } print("Ready.") quantizers = {} for i in range(len(layers)): print(f"Quantizing layer {i+1}/{len(layers)}..") print("+------------------+--------------+------------+-----------+-------+") print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") print("+==================+==============+============+===========+=======+") layer = layers[i] layer.load() full = find_layers(layer) sequential = [list(full.keys())] for names in sequential: subset = {n: full[n] for n in names} gptq = {} for name in subset: gptq[name] = GPTQ(subset[name]) gptq[name].quantizer.configure( bits, perchannel=True, sym=sym, mse=False ) pass def add_batch(name): def tmp(_, inp, out): gptq[name].add_batch(inp[0].data, out.data) return tmp handles = [] for name in subset: handles.append(subset[name].register_forward_hook(add_batch(name))) for j in range(nsamples): outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] for h in handles: h.remove() for name in subset: scale, zero, g_idx, error = gptq[name].fasterquant( percdamp=percdamp, groupsize=groupsize, act_order=act_order, name=name, ) quantizers[f"{prefix}.{i}.{name}"] = ( gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize, ) gptq[name].free() for j in range(nsamples): outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] layer.unload() del layer del gptq torch.cuda.empty_cache() inps, outs = outs, inps print("+------------------+--------------+------------+-----------+-------+") print("\n") model.config.use_cache = use_cache return quantizers def make_quant_linear(module, names, bits, groupsize, name=""): if isinstance(module, QuantLinear): return for attr in dir(module): tmp = getattr(module, attr) name1 = name + "." + attr if name != "" else attr if name1 in names: delattr(module, attr) setattr( module, attr, QuantLinear.new( bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None, ), ) for name1, child in module.named_children(): make_quant_linear( child, names, bits, groupsize, name + "." + name1 if name != "" else name1 ) # TODO: perform packing on GPU def pack(model, quantizers, bits, groupsize): layers = find_layers(model) layers = {n: layers[n] for n in quantizers} make_quant_linear(model, quantizers, bits, groupsize) qlayers = find_layers(model, (QuantLinear,)) print("Packing ...") for name in qlayers: print(name) quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] qlayers[name].pack(layers[name], scale, zero, g_idx) print("Done.") return model def setdeepattr(module, full_name, tensor): current = module tokens = full_name.split(".") for token in tokens[:-1]: current = getattr(current, token) setattr(current, tokens[-1], tensor) def getdeepattr(module, full_name): current = module tokens = full_name.split(".") for token in tokens: current = getattr(current, token) return current def load_weights_pre_hook(module_name, weights, recursive=False): def inner(module, args): print(f"Pre hook {module_name}") local_params = {} for k, v in module.named_parameters(): if not recursive and k.count(".") != 1: continue local_params[k] = v for k, v in module.named_buffers(): if not recursive and k.count(".") != 1: continue local_params[k] = v for local_param in local_params: current_tensor = getdeepattr(module, local_param) if current_tensor.device == torch.device("meta"): # print(f"Loading {local_param}") if module_name: tensor_name = f"{module_name}.{local_param}" else: tensor_name = local_param tensor = weights.get_tensor(tensor_name) setdeepattr(module, local_param, nn.Parameter(tensor)) else: setdeepattr( module, local_param, nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))), ) return inner def load_weights_post_hook(module_name, weights, recursive=False): def inner(module, args, output): print(f"Post hook {module_name}") local_params = {} for k, v in module.named_parameters(): if not recursive and k.count(".") != 1: continue local_params[k] = v for k, v in module.named_buffers(): if not recursive and k.count(".") != 1: continue local_params[k] = v for local_param in local_params: # print(f"Unloading {local_param}") current_tensor = getdeepattr(module, local_param) setdeepattr( module, local_param, nn.Parameter(current_tensor.to(device=torch.device("cpu"))), ) return output return inner def quantize( model_id: str, bits: int, groupsize: int, output_dir: str, revision: str, trust_remote_code: bool, upload_to_model_id: Optional[str], percdamp: float, act_order: bool, ): print("loading model") config = AutoConfig.from_pretrained( model_id, trust_remote_code=trust_remote_code, ) with init_empty_weights(): model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16) model = model.eval() print("LOADED model") files = weight_files(model_id, revision, extension=".safetensors") process_group, _, _ = initialize_torch_distributed() weights = Weights( files, device=torch.device("cuda:0"), dtype=torch.float16, process_group=process_group, aliases={"embed_tokens.weight": ["lm_head.weight"]}, ) hooks = [] for name, module in model.named_modules(): def load(module, name): def _load(): load_weights_pre_hook(name, weights, recursive=True)(module, None) return _load def unload(module, name): def _unload(): load_weights_post_hook(name, weights, recursive=True)( module, None, None ) return _unload module.load = load(module, name) module.unload = unload(module, name) hooks.append( module.register_forward_pre_hook(load_weights_pre_hook(name, weights)) ) hooks.append( module.register_forward_hook(load_weights_post_hook(name, weights)) ) model.seqlen = 2048 dataset = "wikitext2" nsamples = 128 seed = None dataloader, testloader = get_loaders( dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen ) tick = time.time() quantizers = sequential( model, dataloader, DEV, nsamples, bits, groupsize, percdamp=percdamp, act_order=act_order, hooks=hooks, ) print(time.time() - tick) pack(model, quantizers, bits, groupsize) from safetensors.torch import save_file from transformers.modeling_utils import shard_checkpoint state_dict = model.state_dict() state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) max_shard_size = "10GB" shards, index = shard_checkpoint( state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors" ) os.makedirs(output_dir, exist_ok=True) for shard_file, shard in shards.items(): save_file( shard, os.path.join(output_dir, shard_file), metadata={ "format": "pt", "quantized": "gptq", "origin": "text-generation-inference", }, ) if index is None: path_to_weights = os.path.join(output_dir, "model.safetensors") logger.info(f"Model weights saved in {path_to_weights}") else: save_index_file = "model.safetensors.index.json" save_index_file = os.path.join(output_dir, save_index_file) with open(save_index_file, "w", encoding="utf-8") as f: content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) logger.info( f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " f"index located at {save_index_file}." ) config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) config.save_pretrained(output_dir) logger.info("Saved config") logger.info("Saving tokenizer") tokenizer = AutoTokenizer.from_pretrained( model_id, trust_remote_code=trust_remote_code ) tokenizer.save_pretrained(output_dir) logger.info("Saved tokenizer") if upload_to_model_id: api = HfApi() api.upload_folder( folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model" )