Merge branch 'main' of https://github.com/victorchall/everydream-trainer into main
This commit is contained in:
commit
ee9819ab1a
|
@ -22,7 +22,7 @@ This trainer is focused on enabling fine tuning with new training data plus weav
|
|||
|
||||
To get the most out of this trainer, you will need to curate your data with captions. Luckily, there are additional tools below to help enable that, and will grow over time.
|
||||
|
||||
Check out the tools repo here: [Every Dream Tools](https://www.github.com/victorchall/everydream) for automated captioning and Laion web scraper tools so you can use real images for model preservation if you wish to step beyond micro models.
|
||||
Check out the tools repo here: [Every Dream Tools](https://github.com/victorchall/everydream#tools) for automated captioning and Laion web scraper tools so you can use real images for model preservation if you wish to step beyond micro models.
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
12
main.py
12
main.py
|
@ -29,7 +29,15 @@ def load_model_from_config(config, ckpt, verbose=False):
|
|||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"ckpt: {ckpt} has {pl_sd['global_step']} steps")
|
||||
sd = pl_sd["state_dict"]
|
||||
|
||||
## sd = pl_sd["state_dict"]
|
||||
if "state_dict" in pl_sd:
|
||||
print("load_state_dict from state_dict")
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
print("load_state_dict from directly")
|
||||
sd = pl_sd
|
||||
|
||||
config.model.params.ckpt_path = ckpt
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
@ -774,4 +782,4 @@ if __name__ == "__main__":
|
|||
os.rename(logdir, dst)
|
||||
if trainer.global_rank == 0:
|
||||
print("Training complete. max_steps or max_epochs reached, or we blew up.")
|
||||
print(trainer.profiler.summary())
|
||||
print(trainer.profiler.summary())
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import argparse
|
||||
import glob
|
||||
|
||||
def prune_it(p, keep_only_ema=False):
|
||||
def prune_it(p, full_precision=False, keep_only_ema=False):
|
||||
print(f"prunin' in path: {p}")
|
||||
size_initial = os.path.getsize(p)
|
||||
nsd = dict()
|
||||
|
@ -24,9 +24,9 @@ def prune_it(p, keep_only_ema=False):
|
|||
|
||||
for k in sd:
|
||||
if k in ema_keys:
|
||||
new_sd[k] = sd[ema_keys[k]].half()
|
||||
new_sd[k] = sd[ema_keys[k]].half() if not full_precision else sd[ema_keys[k]]
|
||||
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
|
||||
new_sd[k] = sd[k].half()
|
||||
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||
|
||||
assert len(new_sd) == len(sd) - len(ema_keys)
|
||||
nsd["state_dict"] = new_sd
|
||||
|
@ -34,7 +34,7 @@ def prune_it(p, keep_only_ema=False):
|
|||
sd = nsd['state_dict'].copy()
|
||||
new_sd = dict()
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k].half()
|
||||
new_sd[k] = sd[k].half() if not full_precision else sd[k]
|
||||
nsd['state_dict'] = new_sd
|
||||
|
||||
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
|
||||
|
@ -50,6 +50,8 @@ def prune_it(p, keep_only_ema=False):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Pruning')
|
||||
parser.add_argument('--ckpt', type=str, default=False, help='path to model ckpt')
|
||||
parser.add_argument("--full", action="store_true", help="Whether or not to save the model in full precision.")
|
||||
args = parser.parse_args()
|
||||
ckpt = args.ckpt
|
||||
prune_it(ckpt)
|
||||
full_precision = args.full
|
||||
prune_it(ckpt, full_precision)
|
||||
|
|
Loading…
Reference in New Issue