Minor updates
commit 67154ad9b5 (parent 2671369079)
@@ -385,21 +385,20 @@
 "source": [
 "# START THE TRAINING\n",
 "project_name = \"project_name\"\n",
-"batch_size = 3000\n",
-"class_word = \"person\" # << match this word to the class word from regularization images above\n",
+"batch_size = 1000\n",
+"class_word = \"woman\" # << match this word to the class word from regularization images above\n",
 "\n",
 "!rm -rf training_samples/.ipynb_checkpoints\n",
 "!python \"main.py\" \\\n",
 " --base configs/stable-diffusion/v1-finetune_unfrozen.yaml \\\n",
 " -t \\\n",
 " --actual_resume \"model.ckpt\" \\\n",
-" --reg_data_root \"/workspace/Dreambooth-Stable-Diffusion/outputs/txt2img-samples/samples/\" + {dataset} \\\n",
+" --reg_data_root \"/workspace/Dreambooth-Stable-Diffusion/outputs/txt2img-samples/samples/woman_ddim\" \\\n",
 " -n {project_name} \\\n",
 " --gpus 0, \\\n",
 " --data_root \"/workspace/Dreambooth-Stable-Diffusion/training_samples\" \\\n",
 " --batch_size {batch_size} \\\n",
-" --class_word {class_word} \\\n",
-" --save_checkpoints true"
+" --class_word class_word"
 ]
 },
 {
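Note on the --class_word change: inside a notebook "!" command, IPython substitutes any {...} expression with the value of the corresponding Python variable before the shell runs, so dropping the braces passes the literal token class_word rather than the value assigned above. A minimal illustration of the interpolation rule, using echo as a stand-in for main.py:

class_word = "woman"
!echo --class_word {class_word}   # shell receives: --class_word woman
!echo --class_word class_word     # shell receives the literal token: --class_word class_word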
@@ -429,7 +428,7 @@
 "source": [
 "# This version should automatically prune around 10GB from the ckpt file\n",
 "last_checkpoint_file = directory_paths[-1] + \"/checkpoints/last.ckpt\"\n",
-"!python \"prune-ckpt.py\" --ckpt {last_checkpoint_file}"
+"!python \"prune_ckpt.py\" --ckpt {last_checkpoint_file}"
 ],
 "metadata": {
 "collapsed": false
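For context, a checkpoint "prune" of this kind usually means dropping optimizer and trainer state and keeping only the model weights, which is where the roughly 10GB saving comes from. A minimal sketch of that idea, assuming a Lightning-style checkpoint with a "state_dict" key; the actual prune_ckpt.py may do more (e.g. half-precision conversion):

import torch

def prune_checkpoint(in_path, out_path):
    ckpt = torch.load(in_path, map_location="cpu")
    # Optimizer moments account for most of a full training checkpoint's size;
    # keep only the model weights.
    torch.save({"state_dict": ckpt["state_dict"]}, out_path)

prune_checkpoint("last.ckpt", "last-pruned.ckpt")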
@@ -443,7 +442,7 @@
 "last_checkpoint_file_pruned = directory_paths[-1] + \"/checkpoints/last-pruned.ckpt\"\n",
 "training_samples = !ls training_samples\n",
 "date_string = !date +\"%Y-%m-%dT%H-%M-%S\"\n",
-"file_name = date_string[-1] + \"_\" + {project_name} + \"_\" + len(training_samples) + \"_training_images_\" + {batch_size} + \"_batch_size_\" + {class_word} + \"_class_word.ckpt\"\n",
+"file_name = date_string[-1] + \"_\" + project_name + \"_\" + str(len(training_samples)) + \"_training_images_\" + str(batch_size) + \"_batch_size_\" + class_word + \"_class_word.ckpt\"\n",
 "!mkdir trained_models\n",
 "!mv {last_checkpoint_file_pruned} trained_models/{file_name}"
 ],
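The file_name fix addresses two separate bugs in the old line: {project_name} in a plain Python expression builds a one-element set (brace interpolation only applies inside "!" shell commands), and concatenating str with the int results of len(...) and batch_size raises TypeError. The new line works; an f-string is an equivalent, arguably cleaner form, assuming the same variables as in the cell above:

file_name = f"{date_string[-1]}_{project_name}_{len(training_samples)}_training_images_{batch_size}_batch_size_{class_word}_class_word.ckpt"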
main.py (37 changes)
@@ -1,7 +1,6 @@
 import argparse, os, sys, datetime, glob, importlib, csv
 import numpy as np
 import time
-import shutil
 import torch
 
 import torchvision
@@ -19,12 +18,9 @@ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateM
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities import rank_zero_info
 
+import prune_ckpt
 from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
 
-from prune_ckpt import *
-
-## Un-comment this for windows
-## os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"
 
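The import swap matters because the training-end hook below calls prune_ckpt.prune_it(...) with the module prefix; that attribute access needs the module object bound by "import prune_ckpt", which a wildcard import alone does not provide. The same distinction, shown with a standard-library module:

import json          # binds the module: json.dumps(...) works
from json import *   # binds dumps(...) etc., but the name "json" itself stays undefined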
@@ -160,13 +156,6 @@ def get_parser(**parser_kwargs):
         default=1000,
         help="Number of iterations to run")
 
-    parser.add_argument(
-        "--save_checkpoints",
-        type=bool,
-        default=False,
-        help="Save checkpoint files to \"trained_models\""
-    )
-
     parser.add_argument("--actual_resume",
         type=str,
         required=True,
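Removing --save_checkpoints also retires an argparse pitfall: type=bool applies bool() to the raw string, and any non-empty string (including "false") is truthy, so the flag was effectively always on once passed. Had the flag been kept, the conventional form would be a store-true action, e.g.:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_checkpoints", action="store_true",
                    help="Save checkpoint files to \"trained_models\"")
print(parser.parse_args(["--save_checkpoints"]).save_checkpoints)  # True
print(parser.parse_args([]).save_checkpoints)                      # False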
@@ -762,12 +751,11 @@ if __name__ == "__main__":
         callbacks_cfg = OmegaConf.create()
 
         if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
-            print(
-                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
+            print('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
             default_metrics_over_trainsteps_ckpt_dict = {
-                'metrics_over_trainsteps_checkpoint':
-                    {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
-                     'params': {
+                'metrics_over_trainsteps_checkpoint': {
+                    "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
+                    'params': {
                         "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                         "filename": "{epoch:06}-{step:09}",
                         "verbose": True,
@@ -775,8 +763,9 @@ if __name__ == "__main__":
                         'every_n_train_steps': 10000,
                         'save_weights_only': True
-                     }
+                    }
+                }
             }
 
             default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
 
         callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
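The reshaped dict is consumed by instantiate_from_config, so the block is roughly equivalent to constructing the callback directly. A sketch of what it resolves to, assuming pytorch_lightning is installed and with a placeholder standing in for main.py's ckptdir:

import os
from pytorch_lightning.callbacks import ModelCheckpoint

ckptdir = "logs/example/checkpoints"  # placeholder for main.py's ckptdir
callback = ModelCheckpoint(
    dirpath=os.path.join(ckptdir, "trainstep_checkpoints"),
    filename="{epoch:06}-{step:09}",
    verbose=True,
    every_n_train_steps=10000,   # keep a checkpoint every 10k steps...
    save_weights_only=True,      # ...without optimizer state
)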
@@ -836,22 +825,8 @@ if __name__ == "__main__":
             if trainer.global_rank == 0:
                 print("Here comes the checkpoint...")
                 ckpt_path = os.path.join(ckptdir, "last.ckpt")
-                pruned_ckpt_path = os.path.join(ckptdir, "last-pruned.ckpt")
                 trainer.save_checkpoint(ckpt_path)
                 prune_ckpt.prune_it(ckpt_path)
-
-                # remove the 12gb checkpoint file
-                os.remove(ckpt_path)
-
-                # rename the 2gb checkpoint file
-                os.rename(pruned_ckpt_path, ckpt_path)
-
-                if opt.save_checkpoints:
-                    dst = os.path.join("trained_models")
-                    os.makedirs(os.path.split(dst)[0], exist_ok=True)
-                    # Setup the checkpoint filename
-                    checkpoint_file = os.path.join(dst, int(time.time()) + "_" + trainer.global_step + "_checkpoint.ckpt")
-                    shutil.copyfile(ckpt_path, checkpoint_file)
 
         def divein(*args, **kwargs):
             if trainer.global_rank == 0:
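Worth noting: the deleted copy step could never have run as written, since int(time.time()) and trainer.global_step are both ints, and concatenating them with "_" raises TypeError. Had it been kept, a working form would look like the following, with a placeholder standing in for trainer.global_step:

import os, time

global_step = 1000  # placeholder for trainer.global_step
checkpoint_file = os.path.join(
    "trained_models", f"{int(time.time())}_{global_step}_checkpoint.ckpt")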