fix issues and improve sample generator

Damian Stewart 2023-03-02 13:03:50 +01:00
parent c82664b3f3
commit 8100e42159
2 changed files with 48 additions and 31 deletions

View File

@@ -125,12 +125,12 @@ def setup_local_logger(args):
return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
"""
logs the optimizer settings
"""
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
"""
@@ -526,7 +526,7 @@ def main(args):
curr_lr = default_lr
logging.warning(f"No LR setting found, defaulting to {default_lr}")
text_encoder_lr = curr_lr * text_encoder_lr_scale
curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@@ -537,7 +537,7 @@ def main(args):
else:
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = [{'params': unet.parameters()},
{'params': text_encoder.parameters(), 'lr': text_encoder_lr}]
{'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
if optimizer_name:
if optimizer_name == "lion":
@@ -565,7 +565,7 @@ def main(args):
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
image_train_items = resolve_image_train_items(args, log_folder)
@@ -618,6 +618,7 @@ def main(args):
default_resolution=args.resolution, default_seed=args.seed,
config_file_path=args.sample_prompts,
batch_size=max(1,args.batch_size//2),
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers)
"""
@@ -751,11 +752,32 @@ def main(args):
return model_pred, target
def generate_samples(global_step: int, batch):
with isolate_rng():
sample_generator.reload_config()
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation_if_appropriate(epoch=0, global_step=0,
get_model_prediction_and_target_callable=get_model_prediction_and_target)
# the sample generator might be configured to generate samples before step 0
if sample_generator.generate_pretrain_samples:
_, batch = next(enumerate(train_dataloader))
generate_samples(global_step=0, batch=batch)
try:
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
@@ -813,13 +835,11 @@ def main(args):
loss_log_step = []
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
if args.disable_textenc_training or args.disable_unet_training:
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step)
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
else:
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
log_writer.add_scalars(main_tag="hyperparamater/lr", tag_scalar_dict={
'unet': curr_lr,
'text encoder': curr_text_encoder_lr
}, global_step = global_step)
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step)
@@ -830,21 +850,8 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if (global_step + 1) % args.sample_steps == 0:
with isolate_rng():
sample_generator.reload_config()
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
if (global_step + 1) % sample_generator.sample_steps == 0:
generate_samples(global_step=global_step, batch=batch)
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
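For readers skimming the trainer diff above, here is a hedged, condensed sketch of the two sampling triggers it introduces: a pretrain pass gated on generate_pretrain_samples, and an in-loop cadence driven by the generator's sample_steps (now read from the sample config, with --sample_steps only supplying the default). The function and parameter names below are illustrative and not part of the codebase.

def run_sampling_schedule(generator, total_steps, emit_samples):
    # Hypothetical driver: `generator` stands in for the SampleGenerator above,
    # `emit_samples` plays the role of the generate_samples(global_step, batch)
    # closure (batch handling omitted for brevity).
    if generator.generate_pretrain_samples:
        emit_samples(global_step=0)  # samples before the first optimizer step
    for global_step in range(total_steps):
        if (global_step + 1) % generator.sample_steps == 0:
            emit_samples(global_step=global_step)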

View File

@@ -73,6 +73,7 @@ class SampleGenerator:
config_file_path: str,
batch_size: int,
default_seed: int,
default_sample_steps: int,
use_xformers: bool):
self.log_folder = log_folder
self.log_writer = log_writer
@@ -80,10 +81,13 @@ class SampleGenerator:
self.config_file_path = config_file_path
self.use_xformers = use_xformers
self.show_progress_bars = False
self.generate_pretrain_samples = False
self.default_resolution = default_resolution
self.default_seed = default_seed
self.sample_steps = default_sample_steps
self.sample_requests = None
self.reload_config()
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
if not os.path.exists(f"{log_folder}/samples/"):
@@ -102,8 +106,12 @@ class SampleGenerator:
logging.warning(
f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
logging.warning(
f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
self.sample_requests = self._make_random_caption_sample_requests()
f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
)
if self.sample_requests is None:
logging.warning(
f" Will generate samples from random training image captions until the problem is fixed.")
self.sample_requests = self._make_random_caption_sample_requests()
def update_random_captions(self, possible_captions: list[str]):
random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
@@ -139,9 +147,11 @@ class SampleGenerator:
self.scheduler = config.get('scheduler', self.scheduler)
self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
sample_requests_json = config.get('samples', None)
if sample_requests_json is None:
self.sample_requests = []
self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
self.sample_steps = config.get('sample_steps', self.sample_steps)
sample_requests_config = config.get('samples', None)
if sample_requests_config is None:
self.sample_requests = self._make_random_caption_sample_requests()
else:
default_seed = config.get('seed', self.default_seed)
default_size = (self.default_resolution, self.default_resolution)
@@ -150,7 +160,7 @@ class SampleGenerator:
seed=p.get('seed', default_seed),
size=tuple(p.get('size', default_size)),
wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_json]
) for p in sample_requests_config]
if len(self.sample_requests) == 0:
self.sample_requests = self._make_random_caption_sample_requests()
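To show how the new keys fit together, here is a hedged sketch of a sample config that exercises everything reload_config reads above; generate_pretrain_samples and sample_steps are the keys added by this commit. The JSON file format, the output path, and the per-sample 'prompt' field are assumptions not confirmed by this diff, and all concrete values are illustrative.

import json

sample_config = {
    "generate_pretrain_samples": True,  # new in this commit: generate samples before step 0
    "sample_steps": 250,                # new in this commit: cadence, overriding the default passed from --sample_steps
    "scheduler": "ddim",                # illustrative value
    "num_inference_steps": 30,
    "show_progress_bars": False,
    "seed": 555,
    "samples": [
        {"prompt": "a photo of a cat", "size": [512, 512]},  # 'prompt' key assumed, not shown in this hunk
        {"random_caption": True}                             # filled from training captions via update_random_captions
    ]
}

# Write to an illustrative path; the trainer reads whatever --sample_prompts points at.
with open("sample_prompts.json", "w") as f:
    json.dump(sample_config, f, indent=2)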