Fix mistake
This commit is contained in:
parent
db203b8da1
commit
3cb3ebce66
|
@ -24,30 +24,21 @@ def prune_it(p, full_precision=False, keep_only_ema=False):
|
|||
|
||||
for k in sd:
|
||||
if k in ema_keys:
|
||||
if full_precision:
|
||||
new_sd[k] = sd[ema_keys[k]]
|
||||
else:
|
||||
new_sd[k] = sd[ema_keys[k]].half()
|
||||
new_sd = dict()
|
||||
for k in sd:
|
||||
if full_precision:
|
||||
new_sd[k] = sd[k]
|
||||
else:
|
||||
new_sd[k] = sd[k].half()
|
||||
nsd['state_dict'] = new_sd
|
||||
|
||||
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
|
||||
print(f"saving pruned checkpoint at: {fn}")
|
||||
new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]]
|
||||
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
|
||||
if full_precision:
|
||||
new_sd[k] = sd[k]
|
||||
else:
|
||||
new_sd[k] = sd[k].half()
|
||||
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||
|
||||
assert len(new_sd) == len(sd) - len(ema_keys)
|
||||
nsd["state_dict"] = new_sd
|
||||
else:
|
||||
sd = nsd['state_dict'].copy()
|
||||
new_sd = dict()
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||
nsd['state_dict'] = new_sd
|
||||
|
||||
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
|
||||
print(f"saving pruned checkpoint at: {fn}")
|
||||
torch.save(nsd, fn)
|
||||
newsize = os.path.getsize(fn)
|
||||
MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
|
||||
|
|
Loading…
Reference in New Issue