diff --git a/scripts/prune_ckpt.py b/scripts/prune_ckpt.py index 9532cdf..ba81a14 100644 --- a/scripts/prune_ckpt.py +++ b/scripts/prune_ckpt.py @@ -24,30 +24,21 @@ def prune_it(p, full_precision=False, keep_only_ema=False): for k in sd: if k in ema_keys: - if full_precision: - new_sd[k] = sd[ema_keys[k]] - else: - new_sd[k] = sd[ema_keys[k]].half() - new_sd = dict() - for k in sd: - if full_precision: - new_sd[k] = sd[k] - else: - new_sd[k] = sd[k].half() - nsd['state_dict'] = new_sd - - fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt" - print(f"saving pruned checkpoint at: {fn}") + new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]] elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]: - if full_precision: - new_sd[k] = sd[k] - else: - new_sd[k] = sd[k].half() + new_sd[k] = sd[k].half() if not full_precision else sd[k] assert len(new_sd) == len(sd) - len(ema_keys) nsd["state_dict"] = new_sd else: sd = nsd['state_dict'].copy() + new_sd = dict() + for k in sd: + new_sd[k] = sd[k].half() if not full_precision else sd[k] + nsd['state_dict'] = new_sd + + fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt" + print(f"saving pruned checkpoint at: {fn}") torch.save(nsd, fn) newsize = os.path.getsize(fn) MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \