Minor updates

David Bielejeski 2022-09-26 12:41:03 -05:00
parent 2671369079
commit 67154ad9b5
2 changed files with 12 additions and 38 deletions


@@ -385,21 +385,20 @@
"source": [
"# START THE TRAINING\n",
"project_name = \"project_name\"\n",
"batch_size = 3000\n",
"class_word = \"person\" # << match this word to the class word from regularization images above\n",
"batch_size = 1000\n",
"class_word = \"woman\" # << match this word to the class word from regularization images above\n",
"\n",
"!rm -rf training_samples/.ipynb_checkpoints\n",
"!python \"main.py\" \\\n",
" --base configs/stable-diffusion/v1-finetune_unfrozen.yaml \\\n",
" -t \\\n",
" --actual_resume \"model.ckpt\" \\\n",
" --reg_data_root \"/workspace/Dreambooth-Stable-Diffusion/outputs/txt2img-samples/samples/\" + {dataset} \\\n",
" --reg_data_root \"/workspace/Dreambooth-Stable-Diffusion/outputs/txt2img-samples/samples/woman_ddim\" \\\n",
" -n {project_name} \\\n",
" --gpus 0, \\\n",
" --data_root \"/workspace/Dreambooth-Stable-Diffusion/training_samples\" \\\n",
" --batch_size {batch_size} \\\n",
" --class_word {class_word} \\\n",
" --save_checkpoints true"
" --class_word class_word"
]
},
{
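For reference, the {project_name}, {batch_size}, and {class_word} placeholders in this cell rely on IPython expanding brace-wrapped names inside a `!` shell command from the surrounding Python variables, while a bare name such as `class_word` is passed to the shell literally. A minimal plain-Python sketch of the same invocation, with illustrative values only:

```python
# Sketch of what the notebook's "!python main.py ..." line amounts to once
# IPython has expanded the {braced} variables. Values are illustrative only.
import subprocess

project_name = "project_name"
batch_size = 1000
class_word = "woman"

cmd = [
    "python", "main.py",
    "--base", "configs/stable-diffusion/v1-finetune_unfrozen.yaml",
    "-t",
    "--actual_resume", "model.ckpt",
    "--reg_data_root", "/workspace/Dreambooth-Stable-Diffusion/outputs/txt2img-samples/samples/woman_ddim",
    "-n", project_name,
    "--gpus", "0,",
    "--data_root", "/workspace/Dreambooth-Stable-Diffusion/training_samples",
    "--batch_size", str(batch_size),
    "--class_word", class_word,
]
subprocess.run(cmd, check=True)
```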
@@ -429,7 +428,7 @@
"source": [
"# This version should automatically prune around 10GB from the ckpt file\n",
"last_checkpoint_file = directory_paths[-1] + \"/checkpoints/last.ckpt\"\n",
"!python \"prune-ckpt.py\" --ckpt {last_checkpoint_file}"
"!python \"prune_ckpt.py\" --ckpt {last_checkpoint_file}"
],
"metadata": {
"collapsed": false
@@ -443,7 +442,7 @@
"last_checkpoint_file_pruned = directory_paths[-1] + \"/checkpoints/last-pruned.ckpt\"\n",
"training_samples = !ls training_samples\n",
"date_string = !date +\"%Y-%m-%dT%H-%M-%S\"\n",
"file_name = date_string[-1] + \"_\" + {project_name} + \"_\" + len(training_samples) + \"_training_images_\" + {batch_size} + \"_batch_size_\" + {class_word} + \"_class_word.ckpt\"\n",
"file_name = date_string[-1] + \"_\" + project_name + \"_\" + str(len(training_samples)) + \"_training_images_\" + str(batch_size) + \"_batch_size_\" + class_word + \"_class_word.ckpt\"\n",
"!mkdir trained_models\n",
"!mv {last_checkpoint_file_pruned} trained_models/{file_name}"
],
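The str(...) wrapping in the new file_name line matters because len(training_samples) and batch_size are integers, and concatenating them to strings with + would otherwise raise a TypeError. A self-contained sketch of the same construction with stand-in values (in the notebook, training_samples and date_string come from !ls and !date, which return lists of output lines):

```python
# Stand-alone illustration of the renamed-checkpoint filename; the lists below
# stand in for the IPython !ls / !date output used in the notebook.
from datetime import datetime

project_name = "project_name"
batch_size = 1000
class_word = "woman"
training_samples = ["01.png", "02.png", "03.png"]              # stand-in for `!ls training_samples`
date_string = [datetime.now().strftime("%Y-%m-%dT%H-%M-%S")]   # stand-in for `!date +"%Y-%m-%dT%H-%M-%S"`

file_name = (date_string[-1] + "_" + project_name + "_"
             + str(len(training_samples)) + "_training_images_"
             + str(batch_size) + "_batch_size_"
             + class_word + "_class_word.ckpt")
print(file_name)
# e.g. 2022-09-26T12-41-03_project_name_3_training_images_1000_batch_size_woman_class_word.ckpt
```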

main.py

@@ -1,7 +1,6 @@
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import time
import shutil
import torch
import torchvision
@@ -19,12 +18,9 @@ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
import prune_ckpt
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
from prune_ckpt import *
## Un-comment this for windows
## os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"
@@ -160,13 +156,6 @@ def get_parser(**parser_kwargs):
default=1000,
help="Number of iterations to run")
parser.add_argument(
"--save_checkpoints",
type=bool,
default=False,
help="Save checkpoint files to \"trained_models\""
)
parser.add_argument("--actual_resume",
type=str,
required=True,
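One note on the --save_checkpoints argument being dropped here: declaring it with type=bool is a well-known argparse pitfall, because bool() is applied to the raw string and any non-empty value, including "false", parses as True. A small illustration, not code from this repo:

```python
# Illustration of the argparse type=bool pitfall (not code from this repo):
# bool("false") is True, so any non-empty value turns the flag on.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_checkpoints", type=bool, default=False)

print(parser.parse_args(["--save_checkpoints", "false"]).save_checkpoints)  # True
print(parser.parse_args(["--save_checkpoints", "true"]).save_checkpoints)   # True
print(parser.parse_args([]).save_checkpoints)                               # False
```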
@@ -762,12 +751,11 @@ if __name__ == "__main__":
callbacks_cfg = OmegaConf.create()
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
print('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint':
{"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
'metrics_over_trainsteps_checkpoint': {
"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
@@ -775,8 +763,9 @@
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
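As a self-contained sketch of what this reformatted block builds: a plain dict describing a pytorch_lightning ModelCheckpoint callback, which is merged via OmegaConf with whatever callback config was supplied elsewhere. The ckptdir value and the empty callbacks_cfg below are illustrative assumptions, and only the params visible in the hunks above are included:

```python
# Sketch of the callback-config merge shown above; ckptdir and the empty
# callbacks_cfg are illustrative, the nested dict mirrors the diff.
import os
from omegaconf import OmegaConf

ckptdir = "logs/example_run/checkpoints"    # assumed checkpoint directory
default_callbacks_cfg = {}                  # other default callbacks omitted
callbacks_cfg = OmegaConf.create()          # user-supplied callbacks, empty here

default_metrics_over_trainsteps_ckpt_dict = {
    'metrics_over_trainsteps_checkpoint': {
        "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
        'params': {
            "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
            "filename": "{epoch:06}-{step:09}",
            "verbose": True,
            'every_n_train_steps': 10000,
            'save_weights_only': True,
        },
    },
}

default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
print(OmegaConf.to_yaml(callbacks_cfg))
```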
@@ -836,22 +825,8 @@ if __name__ == "__main__":
if trainer.global_rank == 0:
print("Here comes the checkpoint...")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
pruned_ckpt_path = os.path.join(ckptdir, "last-pruned.ckpt")
trainer.save_checkpoint(ckpt_path)
prune_ckpt.prune_it(ckpt_path)
# remove the 12gb checkpoint file
os.remove(ckpt_path)
# rename the 2gb checkpoint file
os.rename(pruned_ckpt_path, ckpt_path)
if opt.save_checkpoints:
dst = os.path.join("trained_models")
os.makedirs(os.path.split(dst)[0], exist_ok=True)
# Setup the checkpoint filename
checkpoint_file = os.path.join(dst, int(time.time()) + "_" + trainer.global_step + "_checkpoint.ckpt")
shutil.copyfile(ckpt_path, checkpoint_file)
def divein(*args, **kwargs):
if trainer.global_rank == 0: