Merge pull request #21 from ProGamerGov/patch-1

Make it possible to prune with full precision
This commit is contained in:
Victor Hall 2022-12-17 15:53:28 -08:00 committed by GitHub
commit b6abae42e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 5 deletions

View File

@ -3,7 +3,7 @@ import torch
import argparse
import glob
def prune_it(p, keep_only_ema=False):
def prune_it(p, full_precision=False, keep_only_ema=False):
print(f"prunin' in path: {p}")
size_initial = os.path.getsize(p)
nsd = dict()
@ -24,9 +24,9 @@ def prune_it(p, keep_only_ema=False):
for k in sd:
if k in ema_keys:
new_sd[k] = sd[ema_keys[k]].half()
new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]]
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
new_sd[k] = sd[k].half()
new_sd[k] = sd[k].half() if not full_precision else sd[k]
assert len(new_sd) == len(sd) - len(ema_keys)
nsd["state_dict"] = new_sd
@ -34,7 +34,7 @@ def prune_it(p, keep_only_ema=False):
sd = nsd['state_dict'].copy()
new_sd = dict()
for k in sd:
new_sd[k] = sd[k].half()
new_sd[k] = sd[k].half() if not full_precision else sd[k]
nsd['state_dict'] = new_sd
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
@ -50,6 +50,8 @@ def prune_it(p, keep_only_ema=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Pruning')
parser.add_argument('--ckpt', type=str, default=False, help='path to model ckpt')
parser.add_argument("--full", action="store_true", help="Whether or not to save the model in full precision.")
args = parser.parse_args()
ckpt = args.ckpt
prune_it(ckpt)
full_precision = args.full
prune_it(ckpt, full_precision)