Fix mistake

This commit is contained in:
ProGamerGov 2022-12-10 12:59:27 -07:00 committed by GitHub
parent db203b8da1
commit 3cb3ebce66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 18 deletions

View File

@ -24,30 +24,21 @@ def prune_it(p, full_precision=False, keep_only_ema=False):
for k in sd: for k in sd:
if k in ema_keys: if k in ema_keys:
if full_precision: new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]]
new_sd[k] = sd[ema_keys[k]]
else:
new_sd[k] = sd[ema_keys[k]].half()
new_sd = dict()
for k in sd:
if full_precision:
new_sd[k] = sd[k]
else:
new_sd[k] = sd[k].half()
nsd['state_dict'] = new_sd
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
print(f"saving pruned checkpoint at: {fn}")
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]: elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
if full_precision: new_sd[k] = sd[k].half() if not full_precision else sd[k]
new_sd[k] = sd[k]
else:
new_sd[k] = sd[k].half()
assert len(new_sd) == len(sd) - len(ema_keys) assert len(new_sd) == len(sd) - len(ema_keys)
nsd["state_dict"] = new_sd nsd["state_dict"] = new_sd
else: else:
sd = nsd['state_dict'].copy() sd = nsd['state_dict'].copy()
new_sd = dict()
for k in sd:
new_sd[k] = sd[k].half() if not full_precision else sd[k]
nsd['state_dict'] = new_sd
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
print(f"saving pruned checkpoint at: {fn}")
torch.save(nsd, fn) torch.save(nsd, fn)
newsize = os.path.getsize(fn) newsize = os.path.getsize(fn)
MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \ MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \