Merge pull request #21 from ProGamerGov/patch-1
Make it possible to prune with full precision
This commit is contained in:
commit
b6abae42e4
|
@ -3,7 +3,7 @@ import torch
|
|||
import argparse
|
||||
import glob
|
||||
|
||||
def prune_it(p, keep_only_ema=False):
|
||||
def prune_it(p, full_precision=False, keep_only_ema=False):
|
||||
print(f"prunin' in path: {p}")
|
||||
size_initial = os.path.getsize(p)
|
||||
nsd = dict()
|
||||
|
@ -24,9 +24,9 @@ def prune_it(p, keep_only_ema=False):
|
|||
|
||||
for k in sd:
|
||||
if k in ema_keys:
|
||||
new_sd[k] = sd[ema_keys[k]].half()
|
||||
new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]]
|
||||
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
|
||||
new_sd[k] = sd[k].half()
|
||||
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||
|
||||
assert len(new_sd) == len(sd) - len(ema_keys)
|
||||
nsd["state_dict"] = new_sd
|
||||
|
@ -34,7 +34,7 @@ def prune_it(p, keep_only_ema=False):
|
|||
sd = nsd['state_dict'].copy()
|
||||
new_sd = dict()
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k].half()
|
||||
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||
nsd['state_dict'] = new_sd
|
||||
|
||||
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
|
||||
|
@ -50,6 +50,8 @@ def prune_it(p, keep_only_ema=False):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Pruning')
|
||||
parser.add_argument('--ckpt', type=str, default=False, help='path to model ckpt')
|
||||
parser.add_argument("--full", action="store_true", help="Whether or not to save the model in full precision.")
|
||||
args = parser.parse_args()
|
||||
ckpt = args.ckpt
|
||||
prune_it(ckpt)
|
||||
full_precision = args.full
|
||||
prune_it(ckpt, full_precision)
|
||||
|
|
Loading…
Reference in New Issue