fix issues and improve sample generator

This commit is contained in:
parent c82664b3f3
commit 8100e42159

train.py (57 lines changed)
@@ -125,12 +125,12 @@ def setup_local_logger(args):
     return datetimestamp
 
-def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
+def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
     """
     logs the optimizer settings
     """
     logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
-    logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
+    logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
 
 def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
     """
@@ -526,7 +526,7 @@ def main(args):
         curr_lr = default_lr
         logging.warning(f"No LR setting found, defaulting to {default_lr}")
 
-    text_encoder_lr = curr_lr * text_encoder_lr_scale
+    curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
 
     if args.disable_textenc_training:
         logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@@ -537,7 +537,7 @@ def main(args):
     else:
         logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
         params_to_train = [{'params': unet.parameters()},
-                           {'params': text_encoder.parameters(), 'lr': text_encoder_lr}]
+                           {'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
 
     if optimizer_name:
         if optimizer_name == "lion":
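Note: the renamed curr_text_encoder_lr rides on torch's per-parameter-group
learning rates, letting one optimizer step both models at different rates. A
minimal sketch of that pattern (plain torch; the Linear stand-ins and the 0.5
scale factor are invented for illustration):

    import torch

    unet = torch.nn.Linear(4, 4)          # stand-in for the real unet
    text_encoder = torch.nn.Linear(4, 4)  # stand-in for the real text encoder

    curr_lr = 1e-6
    curr_text_encoder_lr = curr_lr * 0.5  # example text_encoder_lr_scale

    params_to_train = [{'params': unet.parameters()},  # no 'lr' key: uses the optimizer default
                       {'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
    optimizer = torch.optim.AdamW(params_to_train, lr=curr_lr)

    # group 0 = unet, group 1 = text encoder, matching get_last_lr() ordering
    print([g['lr'] for g in optimizer.param_groups])  # [1e-06, 5e-07]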
@@ -565,7 +565,7 @@ def main(args):
             amsgrad=False,
         )
 
-    log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
+    log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
 
     image_train_items = resolve_image_train_items(args, log_folder)
@@ -618,6 +618,7 @@ def main(args):
                                        default_resolution=args.resolution, default_seed=args.seed,
                                        config_file_path=args.sample_prompts,
                                        batch_size=max(1,args.batch_size//2),
+                                       default_sample_steps=args.sample_steps,
                                        use_xformers=is_xformers_available() and not args.disable_xformers)
 
     """
@@ -751,11 +752,32 @@ def main(args):
 
         return model_pred, target
 
+    def generate_samples(global_step: int, batch):
+        with isolate_rng():
+            sample_generator.reload_config()
+            sample_generator.update_random_captions(batch["captions"])
+            inference_pipe = sample_generator.create_inference_pipe(unet=unet,
+                                                                    text_encoder=text_encoder,
+                                                                    tokenizer=tokenizer,
+                                                                    vae=vae,
+                                                                    diffusers_scheduler_config=reference_scheduler.config
+                                                                    ).to(device)
+            sample_generator.generate_samples(inference_pipe, global_step)
+
+            del inference_pipe
+            gc.collect()
+            torch.cuda.empty_cache()
+
     # Pre-train validation to establish a starting point on the loss graph
     if validator:
         validator.do_validation_if_appropriate(epoch=0, global_step=0,
                                                get_model_prediction_and_target_callable=get_model_prediction_and_target)
 
+    # the sample generator might be configured to generate samples before step 0
+    if sample_generator.generate_pretrain_samples:
+        _, batch = next(enumerate(train_dataloader))
+        generate_samples(global_step=0, batch=batch)
+
     try:
         write_batch_schedule(args, log_folder, train_batch, epoch = 0)
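The new generate_samples() helper runs under isolate_rng() so that image
generation cannot disturb training's random stream. A self-contained sketch of
that guarantee (assumes pytorch_lightning provides isolate_rng, as the diff's
usage suggests; the exact import path varies by version):

    import torch
    from pytorch_lightning.utilities.seed import isolate_rng

    torch.manual_seed(42)
    with isolate_rng():
        _ = torch.rand(1000)     # stands in for the noise draws made while sampling
    after_block = torch.rand(1)  # first draw after sampling finishes

    torch.manual_seed(42)
    reference = torch.rand(1)    # what training would have drawn with no sampling at all
    assert torch.equal(after_block, reference)  # RNG state was fully restored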
@@ -813,13 +835,11 @@ def main(args):
                     loss_log_step = []
                     logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
                     if args.disable_textenc_training or args.disable_unet_training:
-                        log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step)
+                        log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
                     else:
+                        log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
                         curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
-                        log_writer.add_scalars(main_tag="hyperparamater/lr", tag_scalar_dict={
-                            'unet': curr_lr,
-                            'text encoder': curr_text_encoder_lr
-                        }, global_step = global_step)
+                        log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
                     log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                     sum_img = sum(images_per_sec_log_step)
                     avg = sum_img / len(images_per_sec_log_step)
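This hunk swaps one add_scalars() call for two add_scalar() calls. A likely
motivation (an inference, not stated in the commit): SummaryWriter's
add_scalars() writes each tag through a separate child writer, producing one
event file per tag, while plain add_scalar() keeps all curves in the main log.
Tag strings below mirror the diff, including its "hyperparamater" spelling;
the log dir and LR values are invented:

    from torch.utils.tensorboard import SummaryWriter

    log_writer = SummaryWriter("logs/lr_demo")
    lr_pairs = [(1e-6, 5e-7), (9e-7, 4.5e-7)]  # (unet lr, text encoder lr) per step
    for global_step, (curr_lr, curr_text_encoder_lr) in enumerate(lr_pairs):
        log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
        log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
    log_writer.close()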
@@ -830,21 +850,8 @@ def main(args):
                 append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                 torch.cuda.empty_cache()
 
-            if (global_step + 1) % args.sample_steps == 0:
-                with isolate_rng():
-                    sample_generator.reload_config()
-                    sample_generator.update_random_captions(batch["captions"])
-                    inference_pipe = sample_generator.create_inference_pipe(unet=unet,
-                                                                            text_encoder=text_encoder,
-                                                                            tokenizer=tokenizer,
-                                                                            vae=vae,
-                                                                            diffusers_scheduler_config=reference_scheduler.config
-                                                                            ).to(device)
-                    sample_generator.generate_samples(inference_pipe, global_step)
-
-                    del inference_pipe
-                    gc.collect()
-                    torch.cuda.empty_cache()
+            if (global_step + 1) % sample_generator.sample_steps == 0:
+                generate_samples(global_step=global_step, batch=batch)
 
             min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
 
@@ -73,6 +73,7 @@ class SampleGenerator:
                  config_file_path: str,
                  batch_size: int,
                  default_seed: int,
+                 default_sample_steps: int,
                  use_xformers: bool):
         self.log_folder = log_folder
         self.log_writer = log_writer
@@ -80,10 +81,13 @@ class SampleGenerator:
         self.config_file_path = config_file_path
         self.use_xformers = use_xformers
         self.show_progress_bars = False
+        self.generate_pretrain_samples = False
 
         self.default_resolution = default_resolution
         self.default_seed = default_seed
+        self.sample_steps = default_sample_steps
 
+        self.sample_requests = None
         self.reload_config()
         print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
         if not os.path.exists(f"{log_folder}/samples/"):
@@ -102,8 +106,12 @@ class SampleGenerator:
             logging.warning(
                 f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
             logging.warning(
-                f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
-            self.sample_requests = self._make_random_caption_sample_requests()
+                f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
+            )
+            if self.sample_requests == None:
+                logging.warning(
+                    f" Will generate samples from random training image captions until the problem is fixed.")
+                self.sample_requests = self._make_random_caption_sample_requests()
 
     def update_random_captions(self, possible_captions: list[str]):
         random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
@@ -139,9 +147,11 @@ class SampleGenerator:
             self.scheduler = config.get('scheduler', self.scheduler)
             self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
             self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
-            sample_requests_json = config.get('samples', None)
-            if sample_requests_json is None:
-                self.sample_requests = []
+            self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
+            self.sample_steps = config.get('sample_steps', self.sample_steps)
+            sample_requests_config = config.get('samples', None)
+            if sample_requests_config is None:
+                self.sample_requests = self._make_random_caption_sample_requests()
             else:
                 default_seed = config.get('seed', self.default_seed)
                 default_size = (self.default_resolution, self.default_resolution)
@@ -150,7 +160,7 @@ class SampleGenerator:
                     seed=p.get('seed', default_seed),
                     size=tuple(p.get('size', default_size)),
                     wants_random_caption=p.get('random_caption', False)
-                ) for p in sample_requests_json]
+                ) for p in sample_requests_config]
             if len(self.sample_requests) == 0:
                 self._make_random_caption_sample_requests()
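Taken together, reload_config() now honors two new keys, generate_pretrain_samples
and sample_steps, alongside the existing ones. A hypothetical config in the shape
this code reads (key names come from the diff; every value, the filename, and the
"prompt" key itself are assumptions for illustration):

    import json

    example_config = {
        "scheduler": "ddim",                # assumed scheduler name
        "num_inference_steps": 30,
        "show_progress_bars": False,
        "generate_pretrain_samples": True,  # new: also sample before step 0
        "sample_steps": 300,                # new: overrides the CLI sample cadence
        "seed": 555,
        "samples": [
            {"prompt": "an example prompt",  # "prompt" key assumed, not shown in the diff
             "seed": -1,
             "size": [512, 512]},
            {"random_caption": True}         # caption pulled from a training batch
        ]
    }

    # hypothetical filename; train.py takes the real path from args.sample_prompts
    with open("sample_prompts.json", "w") as f:
        json.dump(example_config, f, indent=2)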