chaining and more lower resolutions

This commit is contained in:
Victor Hall 2023-01-08 18:52:39 -05:00
parent c816e25773
commit 3c921dbaa2
7 changed files with 276 additions and 57 deletions

View File

@ -6,7 +6,7 @@ Welcome to v2.0 of EveryDream trainer! Now with more diffusers and even more fea
Please join us on Discord! https://discord.gg/uheqxU6sXN
If you find this tool useful, please consider subscribing to the project on Patreon or buy me a Ko-fi.
If you find this tool useful, please consider subscribing to the project on [Patreon](https://www.patreon.com/everydream) or a one-time donation at [Ko-fi](https://ko-fi.com/everydream).
## Video tutorials
@ -32,3 +32,5 @@ Behind the scenes look at how the trainer handles multiaspect and crop jitter
[Logging](doc/LOGGING.md)
[Advanced Tweaking](doc/ATWEAKING.md)
[Chaining training sessions](doc/CHAINING.md)

3
chain.bat Normal file
View File

@ -0,0 +1,3 @@
python train.py --config chain0.json
python train.py --config chain1.json
python train.py --config chain2.json

38
chain0.json Normal file
View File

@ -0,0 +1,38 @@
{
"amp": false,
"batch_size": 12,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.00,
"data_root": "R:\\everydream-trainer\\training_samples\\ff7r",
"disable_textenc_training": false,
"disable_xformers": true,
"flip_p": 0.0,
"ed1_mode": true,
"gpuid": 0,
"gradient_checkpointing": true,
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 2.5e-6,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 15,
"project_name": "myproj_ch0",
"resolution": 384,
"resume_ckpt": "sd_v1-5_vae",
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_every_n_epochs": 99,
"save_optimizer": false,
"scale_lr": false,
"seed": -1,
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": true
}

38
chain1.json Normal file
View File

@ -0,0 +1,38 @@
{
"amp": false,
"batch_size": 7,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.05,
"data_root": "R:\\everydream-trainer\\training_samples\\ff7r",
"disable_textenc_training": false,
"disable_xformers": true,
"flip_p": 0.0,
"ed1_mode": true,
"gpuid": 0,
"gradient_checkpointing": true,
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 1.0e-6,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 10,
"project_name": "myproj_ch0",
"resolution": 512,
"resume_ckpt": "findlast",
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_every_n_epochs": 5,
"save_optimizer": false,
"scale_lr": false,
"seed": -1,
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": true
}

38
chain2.json Normal file
View File

@ -0,0 +1,38 @@
{
"amp": false,
"batch_size": 2,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.08,
"data_root": "R:\\everydream-trainer\\training_samples\\ff7r",
"disable_textenc_training": true,
"disable_xformers": true,
"flip_p": 0.0,
"ed1_mode": true,
"gpuid": 0,
"gradient_checkpointing": true,
"grad_accum": 5,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 1.5e-6,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 10,
"project_name": "myproj_ch0",
"resolution": 640,
"resume_ckpt": "findlast",
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_every_n_epochs": 5,
"save_optimizer": false,
"scale_lr": false,
"seed": -1,
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": true
}

View File

@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
ASPECTS11 = [[1152,1152], # 1327104 1:1
ASPECTS_1152 = [[1152,1152], # 1327104 1:1
#[1216,1088],[1088,1216], # 1323008 1.118:1
[1280,1024],[1024,1280], # 1310720 1.25:1
[1344,960],[960,1344], # 1290240 1.4:1
@ -25,7 +25,7 @@ ASPECTS11 = [[1152,1152], # 1327104 1:1
[2304,576],[576,2304], # 1327104 4:1
]
ASPECTS10 = [[1088,1088], # 1183744 1:1
ASPECTS_1088 = [[1088,1088], # 1183744 1:1
[1152,1024],[1024,1152], # 1167360 1.267:1
[1216,896],[896,1216], # 1146880 1.429:1
[1408,832],[832,1408], # 1171456 1.692:1
@ -36,7 +36,7 @@ ASPECTS10 = [[1088,1088], # 1183744 1:1
[2304,512],[512,2304], # 1179648 4.5:1
]
ASPECTS9 = [[1024,1024], # 1048576 1:1
ASPECTS_1024 = [[1024,1024], # 1048576 1:1
#[1088,960],[960,1088], # 1044480 1.125:1
[1152,896],[896,1152], # 1032192 1.286:1
[1216,832],[832,1216], # 1011712 1.462:1
@ -47,7 +47,7 @@ ASPECTS9 = [[1024,1024], # 1048576 1:1
[2048,512],[512,2048], # 1048576 4:1
]
ASPECTS8 = [[960,960], # 921600 1:1
ASPECTS_960 = [[960,960], # 921600 1:1
[1024,896],[896,1024], # 917504 1.143:1
[1088,832],[832,1088], # 905216 1.308:1
[1152,768],[768,1152], # 884736 1.5:1
@ -59,7 +59,7 @@ ASPECTS8 = [[960,960], # 921600 1:1
[2048,448],[448,2048], # 917504 4.714:1
]
ASPECTS7 = [[896,896], # 802816 1:1
ASPECTS_896 = [[896,896], # 802816 1:1
[960,832],[832,960], # 798720 1.153:1
[1024,768],[768,1024], # 786432 1.333:1
[1088,704],[704,1088], # 765952 1.545:1
@ -69,7 +69,7 @@ ASPECTS7 = [[896,896], # 802816 1:1
[1792,448],[448,1792], # 802816 4:1
]
ASPECTS6 = [[832,832], # 692224 1:1
ASPECTS_832 = [[832,832], # 692224 1:1
[896,768],[768,896], # 688128 1.167:1
[960,704],[704,960], # 675840 1.364:1
#[960,640],[640,960], # 614400 1.5:1
@ -82,7 +82,7 @@ ASPECTS6 = [[832,832], # 692224 1:1
[1600,384],[384,1600], # 614400 4.167:1
]
ASPECTS5 = [[768,768], # 589824 1:1
ASPECTS_768 = [[768,768], # 589824 1:1
[832,704],[704,832], # 585728 1.181:1
[896,640],[640,896], # 573440 1.4:1
[960,576],[576,960], # 552960 1.6:1
@ -96,7 +96,7 @@ ASPECTS5 = [[768,768], # 589824 1:1
[1472,320],[320,1472], # 471040 4.6:1
]
ASPECTS4 = [[704,704], # 501,376 1:1
ASPECTS_704 = [[704,704], # 501,376 1:1
[768,640],[640,768], # 491,520 1.2:1
[832,576],[576,832], # 458,752 1.444:1
#[896,512],[512,896], # 458,752 1.75:1
@ -109,7 +109,7 @@ ASPECTS4 = [[704,704], # 501,376 1:1
[1280,320],[320,1280], # 409,600 4:1
]
ASPECTS3 = [[640,640], # 409600 1:1
ASPECTS_640 = [[640,640], # 409600 1:1
[704,576],[576,704], # 405504 1.25:1
[768,512],[512,768], # 393216 1.5:1
[832,448],[448,832], # 372736 1.857:1
@ -119,7 +119,7 @@ ASPECTS3 = [[640,640], # 409600 1:1
[1280,320],[320,1280], # 409600 4:1
]
ASPECTS2 = [[576,576], # 331776 1:1
ASPECTS_576 = [[576,576], # 331776 1:1
[640,512],[512,640], # 327680 1.25:1
#[640,448],[448,640], # 286720 1.4286:1
[704,448],[448,704], # 315392 1.571:1
@ -130,7 +130,7 @@ ASPECTS2 = [[576,576], # 331776 1:1
#[1280,256],[256,1280], # 327680 5:1
]
ASPECTS = [[512,512], # 262144 1:1
ASPECTS_512 = [[512,512], # 262144 1:1
[576,448],[448,576], # 258048 1.29:1
[640,384],[384,640], # 245760 1.667:1
[768,320],[320,768], # 245760 2.4:1
@ -140,14 +140,25 @@ ASPECTS = [[512,512], # 262144 1:1
[1024,256],[256,1024], # 245760 4:1
]
ASPECTS0 = [[448,448], # 200704 1:1
ASPECTS_448 = [[448,448], # 200704 1:1
[512,384],[384,512], # 196608 1.333:1
[640,320],[320,640], # 204800 2:1
[768,256],[256,768], # 196608 3:1
]
ASPECTS_384 = [[384,384], # 147456 1:1
[448,320],[320,448], # 143360 1.4:1
[512,256],[256,512], # 131072 2:1
[704,192],[192,704], # 135168 3.667:1
]
ASPECTS_256 = [[256,256], # 65536 1:1
[384,192],[192,384], # 73728 2:1
[512,128],[128,512], # 65536 4:1
]
def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
if resolution < 512:
if resolution < 256:
raise ValueError("Resolution must be at least 512")
try:
rounded_resolution = int(resolution / 64) * 64
@ -164,4 +175,18 @@ def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
raise e
def __get_all_aspects():
return [ASPECTS0, ASPECTS, ASPECTS2, ASPECTS3, ASPECTS4, ASPECTS5, ASPECTS6, ASPECTS7, ASPECTS8, ASPECTS9, ASPECTS10, ASPECTS11]
return [ASPECTS_256,
ASPECTS_384,
ASPECTS_448,
ASPECTS_512,
ASPECTS_576,
ASPECTS_640,
ASPECTS_704,
ASPECTS_768,
ASPECTS_832,
ASPECTS_896,
ASPECTS_960,
ASPECTS_1024,
ASPECTS_1088,
ASPECTS_1152
]

159
train.py
View File

@ -105,6 +105,7 @@ def setup_local_logger(args):
configures logger with file and console logging, logs args, and returns the datestamp
"""
log_path = args.logdir
if not os.path.exists(log_path):
os.makedirs(log_path)
@ -115,6 +116,7 @@ def setup_local_logger(args):
f.write(f"{json_config}")
logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log")
print(f" logging to {logfilename}")
logging.basicConfig(filename=logfilename,
level=logging.INFO,
format="%(asctime)s %(message)s",
@ -122,6 +124,7 @@ def setup_local_logger(args):
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
@ -190,30 +193,51 @@ def set_args_12gb(args):
logging.info(" Overiding adam8bit to True")
args.useadam8bit = True
def main(args):
def find_last_checkpoint(logdir):
"""
Main entry point
Finds the last checkpoint in the logdir, recursively
"""
log_time = setup_local_logger(args)
last_ckpt = None
last_date = None
for root, dirs, files in os.walk(logdir):
for file in files:
if os.path.basename(file) == "model_index.json":
curr_date = os.path.getmtime(os.path.join(root,file))
if last_date is None or curr_date > last_date:
last_date = curr_date
last_ckpt = root
assert last_ckpt, f"Could not find last checkpoint in logdir: {logdir}"
assert "errored" not in last_ckpt, f"Found last checkpoint: {last_ckpt}, but it was errored, cancelling"
print(f" {Fore.LIGHTCYAN_EX}Found last checkpoint: {last_ckpt}, resuming{Style.RESET_ALL}")
return last_ckpt
def setup_args(args):
"""
Sets defaults for missing args (possible if missing from json config)
Forces some args to be set based on others for compatibility reasons
"""
if args.resume_ckpt == "findlast":
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if args.ed1_mode and not args.disable_xformers:
args.disable_xformers = True
logging.info(" ED1 mode: Overiding disable_xformers to True")
if args.lowvram:
set_args_12gb(args)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = False
if args.ed1_mode:
args.disable_xformers = True
if not args.shuffle_tags:
args.shuffle_tags = False
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
args.ckpt_every_n_minutes = 20
@ -231,16 +255,32 @@ def main(args):
if args.cond_dropout > 0.26:
logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
total_batch_size = args.batch_size * args.grad_accum
if args.grad_accum > 1:
logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}")
total_batch_size = args.batch_size * args.grad_accum
if args.scale_lr is not None and args.scale_lr:
tmp_lr = args.lr
args.lr = args.lr * (total_batch_size**0.55)
logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.55}, new value: {args.lr}{Style.RESET_ALL}")
return args
def main(args):
"""
Main entry point
"""
log_time = setup_local_logger(args)
args = setup_args(args)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = True
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
logging.info(f"Logging to {log_folder}")
if not os.path.exists(log_folder):
@ -409,9 +449,12 @@ def main(args):
default_lr = 3e-6
curr_lr = args.lr if args.lr is not None else default_lr
# vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@ -537,15 +580,7 @@ def main(args):
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
# scaler = torch.cuda.amp.GradScaler(
# #enabled=False,
# enabled=True if args.amp else False,
# init_scale=2**1,
# growth_factor=1.000001,
# backoff_factor=0.9999999,
# growth_interval=50,
# )
#logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
def collate_fn(batch):
"""
@ -607,8 +642,22 @@ def main(args):
#loss = torch.tensor(0.0, device=device, dtype=torch.float32)
try:
if args.amp:
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
#enabled=False,
enabled=True,
init_scale=1024.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
loss_log_step = []
try:
for epoch in range(args.max_epochs):
loss_epoch = []
epoch_start_time = time.time()
steps_pbar.reset()
images_per_sec_log_step = []
@ -619,8 +668,8 @@ def main(args):
with torch.no_grad():
#with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
with autocast(enabled=args.amp):
latents = vae.encode(pixel_values, return_dict=False)
#with autocast(enabled=args.amp):
latents = vae.encode(pixel_values, return_dict=False)
del pixel_values
latents = latents[0].sample() * 0.18215
@ -650,8 +699,8 @@ def main(args):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
#with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
@ -663,7 +712,10 @@ def main(args):
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
loss.backward()
if args.amp:
scaler.scale(loss).backward()
else:
loss.backward()
if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size
@ -677,28 +729,37 @@ def main(args):
param.grad *= grad_scale
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
optimizer.step()
if args.amp:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
lr_scheduler.step()
steps_pbar.set_postfix({"gs": global_step})
loss_step = loss.detach().item()
steps_pbar.set_postfix({"loss/step": loss_step},{"gs": global_step})
steps_pbar.update(1)
global_step += 1
images_per_sec = args.batch_size / (time.time() - step_start_time)
images_per_sec_log_step.append(images_per_sec)
loss_log_step.append(loss_step)
loss_epoch.append(loss_step)
if (global_step + 1) % args.log_step == 0:
curr_lr = lr_scheduler.get_last_lr()[0]
loss_local = loss.detach().item()
logs = {"loss/step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="loss/step", scalar_value=loss_local, global_step=global_step)
loss_local = sum(loss_log_step) / len(loss_log_step)
loss_log_step = []
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step)
images_per_sec_log_step = []
#log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
if args.amp:
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
@ -732,7 +793,8 @@ def main(args):
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
del loss, batch
del batch
global_step += 1
# end of step
elapsed_epoch_time = (time.time() - epoch_start_time) / 60
@ -742,6 +804,9 @@ def main(args):
epoch_pbar.update(1)
if epoch < args.max_epochs - 1:
train_batch.shuffle(epoch_n=epoch+1)
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# end of epoch
# end of training
@ -765,6 +830,14 @@ def main(args):
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
def update_old_args(t_args):
"""
Update old args to new args to deal with json config loading and missing args for compatibility
"""
if not hasattr(t_args, "shuffle_tags"):
print(f" Config json is missing 'shuffle_tags'")
t_args.__dict__["shuffle_tags"] = False
if __name__ == "__main__":
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
@ -776,6 +849,8 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f:
t_args = argparse.Namespace()
t_args.__dict__.update(json.load(f))
update_old_args(t_args) # update args to support older configs
print(t_args.__dict__)
args = argparser.parse_args(namespace=t_args)
else:
print("No config file specified, using command line args")