Fix mistake
This commit is contained in:
parent
db203b8da1
commit
3cb3ebce66
|
@ -24,30 +24,21 @@ def prune_it(p, full_precision=False, keep_only_ema=False):
|
||||||
|
|
||||||
for k in sd:
|
for k in sd:
|
||||||
if k in ema_keys:
|
if k in ema_keys:
|
||||||
if full_precision:
|
new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]]
|
||||||
new_sd[k] = sd[ema_keys[k]]
|
|
||||||
else:
|
|
||||||
new_sd[k] = sd[ema_keys[k]].half()
|
|
||||||
new_sd = dict()
|
|
||||||
for k in sd:
|
|
||||||
if full_precision:
|
|
||||||
new_sd[k] = sd[k]
|
|
||||||
else:
|
|
||||||
new_sd[k] = sd[k].half()
|
|
||||||
nsd['state_dict'] = new_sd
|
|
||||||
|
|
||||||
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
|
|
||||||
print(f"saving pruned checkpoint at: {fn}")
|
|
||||||
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
|
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
|
||||||
if full_precision:
|
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||||
new_sd[k] = sd[k]
|
|
||||||
else:
|
|
||||||
new_sd[k] = sd[k].half()
|
|
||||||
|
|
||||||
assert len(new_sd) == len(sd) - len(ema_keys)
|
assert len(new_sd) == len(sd) - len(ema_keys)
|
||||||
nsd["state_dict"] = new_sd
|
nsd["state_dict"] = new_sd
|
||||||
else:
|
else:
|
||||||
sd = nsd['state_dict'].copy()
|
sd = nsd['state_dict'].copy()
|
||||||
|
new_sd = dict()
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||||
|
nsd['state_dict'] = new_sd
|
||||||
|
|
||||||
|
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
|
||||||
|
print(f"saving pruned checkpoint at: {fn}")
|
||||||
torch.save(nsd, fn)
|
torch.save(nsd, fn)
|
||||||
newsize = os.path.getsize(fn)
|
newsize = os.path.getsize(fn)
|
||||||
MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
|
MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
|
||||||
|
|
Loading…
Reference in New Issue