From 3cb3ebce6644eeab166b01e32a6fc0b0e65ce122 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 10 Dec 2022 12:59:27 -0700 Subject: [PATCH] Fix mistake --- scripts/prune_ckpt.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/scripts/prune_ckpt.py b/scripts/prune_ckpt.py index 9532cdf..ba81a14 100644 --- a/scripts/prune_ckpt.py +++ b/scripts/prune_ckpt.py @@ -24,30 +24,21 @@ def prune_it(p, full_precision=False, keep_only_ema=False): for k in sd: if k in ema_keys: - if full_precision: - new_sd[k] = sd[ema_keys[k]] - else: - new_sd[k] = sd[ema_keys[k]].half() - new_sd = dict() - for k in sd: - if full_precision: - new_sd[k] = sd[k] - else: - new_sd[k] = sd[k].half() - nsd['state_dict'] = new_sd - - fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt" - print(f"saving pruned checkpoint at: {fn}") + new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]] elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]: - if full_precision: - new_sd[k] = sd[k] - else: - new_sd[k] = sd[k].half() + new_sd[k] = sd[k].half() if not full_precision else sd[k] assert len(new_sd) == len(sd) - len(ema_keys) nsd["state_dict"] = new_sd else: sd = nsd['state_dict'].copy() + new_sd = dict() + for k in sd: + new_sd[k] = sd[k].half() if not full_precision else sd[k] + nsd['state_dict'] = new_sd + + fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt" + print(f"saving pruned checkpoint at: {fn}") torch.save(nsd, fn) newsize = os.path.getsize(fn) MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \