This commit is contained in:
Victor Hall 2022-12-28 13:11:41 -05:00
commit ee9819ab1a
3 changed files with 18 additions and 8 deletions

View File

@@ -22,7 +22,7 @@ This trainer is focused on enabling fine tuning with new training data plus weav
To get the most out of this trainer, you will need to curate your data with captions. Luckily, there are additional tools below to help enable that, and they will grow over time.
Check out the tools repo here: [Every Dream Tools](https://www.github.com/victorchall/everydream) for automated captioning and Laion web scraper tools so you can use real images for model preservation if you wish to step beyond micro models.
Check out the tools repo here: [Every Dream Tools](https://github.com/victorchall/everydream#tools) for automated captioning and Laion web scraper tools so you can use real images for model preservation if you wish to step beyond micro models.
## Installation

12
main.py
View File

@@ -29,7 +29,15 @@ def load_model_from_config(config, ckpt, verbose=False):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"ckpt: {ckpt} has {pl_sd['global_step']} steps")
sd = pl_sd["state_dict"]
## sd = pl_sd["state_dict"]
if "state_dict" in pl_sd:
print("load_state_dict from state_dict")
sd = pl_sd["state_dict"]
else:
print("load_state_dict from directly")
sd = pl_sd
config.model.params.ckpt_path = ckpt
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
@@ -774,4 +782,4 @@ if __name__ == "__main__":
os.rename(logdir, dst)
if trainer.global_rank == 0:
print("Training complete. max_steps or max_epochs reached, or we blew up.")
print(trainer.profiler.summary())
print(trainer.profiler.summary())

View File

@@ -3,7 +3,7 @@ import torch
import argparse
import glob
def prune_it(p, keep_only_ema=False):
def prune_it(p, full_precision=False, keep_only_ema=False):
print(f"prunin' in path: {p}")
size_initial = os.path.getsize(p)
nsd = dict()
@@ -24,9 +24,9 @@ def prune_it(p, keep_only_ema=False):
for k in sd:
if k in ema_keys:
new_sd[k] = sd[ema_keys[k]].half()
new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]]
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
new_sd[k] = sd[k].half()
new_sd[k] = sd[k].half() if not full_precision else sd[k]
assert len(new_sd) == len(sd) - len(ema_keys)
nsd["state_dict"] = new_sd
@@ -34,7 +34,7 @@ def prune_it(p, keep_only_ema=False):
sd = nsd['state_dict'].copy()
new_sd = dict()
for k in sd:
new_sd[k] = sd[k].half()
new_sd[k] = sd[k].half() if not full_precision else sd[k]
nsd['state_dict'] = new_sd
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
@@ -50,6 +50,8 @@ def prune_it(p, keep_only_ema=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Pruning')
parser.add_argument('--ckpt', type=str, default=False, help='path to model ckpt')
parser.add_argument("--full", action="store_true", help="Whether or not to save the model in full precision.")
args = parser.parse_args()
ckpt = args.ckpt
prune_it(ckpt)
full_precision = args.full
prune_it(ckpt, full_precision)