1. Improved EMA support: sample generation selectable via arguments for the EMA and/or non-EMA model, checkpoints and diffusers models saved for both, and ema_decay_target implemented.
2. enable_zero_terminal_snr separated from zero_frequency_noise_ratio.
This commit is contained in:
parent
23df727a1f
commit
5bcf9407f0
309
train.py
309
train.py
|
@ -358,17 +358,20 @@ def log_args(log_writer, args):
|
|||
log_writer.add_text("config", arglog)
|
||||
|
||||
def update_ema(model, ema_model, decay, default_device):
|
||||
# TODO: handle Unet/TE not trained
|
||||
|
||||
with torch.no_grad():
|
||||
original_model_on_same_device = model
|
||||
original_model_on_proper_device = model
|
||||
need_to_delete_original = False
|
||||
if args.ema_decay_device != default_device:
|
||||
original_model_on_other_device = deepcopy(model)
|
||||
original_model_on_same_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
|
||||
original_model_on_proper_device = original_model_on_other_device.to(args.ema_decay_device, dtype=model.dtype)
|
||||
del original_model_on_other_device
|
||||
need_to_delete_original = True
|
||||
|
||||
params = dict(original_model_on_same_device.named_parameters())
|
||||
# original_model_on_proper_device = type(model)().to(args.ema_decay_device, dtype=model.dtype)
|
||||
# original_model_on_proper_device.load_state_dict(model.state_dict())
|
||||
|
||||
params = dict(original_model_on_proper_device.named_parameters())
|
||||
ema_params = dict(ema_model.named_parameters())
|
||||
|
||||
for name in ema_params:
|
||||
|
@ -376,7 +379,7 @@ def update_ema(model, ema_model, decay, default_device):
|
|||
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
|
||||
|
||||
if need_to_delete_original:
|
||||
del(original_model_on_same_device)
|
||||
del(original_model_on_proper_device)
|
||||
|
||||
def compute_snr(timesteps, noise_scheduler):
|
||||
"""
|
||||
|
@ -415,6 +418,15 @@ def compute_snr(timesteps, noise_scheduler):
|
|||
snr[snr == 0] = minimal_value
|
||||
return snr
|
||||
|
||||
def load_train_json_from_file(args):
    """
    Merge options from the JSON file named by ``args.config`` into ``args``.

    Loading is best-effort: a failure is reported on stdout instead of being
    raised, so the caller carries on with the settings it already has.

    :param args: namespace-like object with a ``config`` attribute (path of a
        JSON file, or None). Its ``__dict__`` is updated in place with the
        key/value pairs of the JSON object.
    """
    if args.config is None:
        # Nothing to load; without this guard open(None) fails and the user
        # sees a misleading "Error on loading ... from None." message.
        print("No config file specified, using command line args")
        return

    print(f"Loading training config from {args.config}.")
    try:
        with open(args.config, 'rt') as f:
            read_json = json.load(f)

        # Only a JSON object maps onto option names; reject anything else
        # before touching args so a bad file cannot half-apply.
        if not isinstance(read_json, dict):
            raise TypeError(f"expected a JSON object, got {type(read_json).__name__}")

        args.__dict__.update(read_json)
    except Exception as config_read:
        # Deliberately broad: an unreadable config must not abort training,
        # but the cause is shown so the problem can be diagnosed.
        print(f"Error on loading training config from {args.config}: {config_read}")
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
|
@ -452,40 +464,22 @@ def main(args):
|
|||
@torch.no_grad()
|
||||
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
|
||||
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
||||
|
||||
nonlocal unet
|
||||
nonlocal text_encoder
|
||||
nonlocal unet_ema
|
||||
nonlocal text_encoder_ema
|
||||
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
global global_step
|
||||
if global_step is None or global_step == 0:
|
||||
logging.warning(" No model to save, something likely blew up on startup, not saving")
|
||||
return
|
||||
logging.info(f" * Saving diffusers model to {save_path}")
|
||||
if args.ema_decay_rate != None:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder_ema,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet_ema,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
else:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
pipeline.save_pretrained(save_path)
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
|
||||
if save_ckpt:
|
||||
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
|
||||
nonlocal save_ckpt_dir
|
||||
nonlocal save_full_precision
|
||||
#nonlocal save_path
|
||||
nonlocal yaml_name
|
||||
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
|
@ -495,17 +489,73 @@ def main(args):
|
|||
half = not save_full_precision
|
||||
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
|
||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
|
||||
global global_step
|
||||
|
||||
if global_step is None or global_step == 0:
|
||||
logging.warning(" No model to save, something likely blew up on startup, not saving")
|
||||
return
|
||||
|
||||
|
||||
if args.ema_decay_rate != None:
|
||||
|
||||
|
||||
pipeline_ema = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder_ema,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet_ema,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
|
||||
diffusers_model_path = save_path + "_ema"
|
||||
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
|
||||
pipeline_ema.save_pretrained(diffusers_model_path)
|
||||
|
||||
if save_ckpt:
|
||||
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
|
||||
|
||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
|
||||
|
||||
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
|
||||
diffusers_model_path = save_path
|
||||
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
|
||||
pipeline.save_pretrained(diffusers_model_path)
|
||||
|
||||
if save_ckpt:
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
|
||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
|
||||
|
||||
|
||||
if save_optimizer_flag:
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
ed_optimizer.save(save_path)
|
||||
|
||||
|
||||
|
||||
optimizer_state_path = None
|
||||
try:
|
||||
# check for a local file
|
||||
|
@ -530,8 +580,8 @@ def main(args):
|
|||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
if args.zero_frequency_noise_ratio == -1.0:
|
||||
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
|
||||
if args.enable_zero_terminal_snr:
|
||||
# Use zero terminal SNR
|
||||
from utils.unet_utils import enforce_zero_terminal_snr
|
||||
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
|
||||
|
@ -576,18 +626,28 @@ def main(args):
|
|||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
||||
|
||||
|
||||
# TODO: handle Unet/TE not trained
|
||||
if args.ema_decay_rate != None:
|
||||
if args.ema_decay_device == device:
|
||||
unet_ema = deepcopy(unet)
|
||||
text_encoder_ema = deepcopy(text_encoder)
|
||||
else:
|
||||
unet_ema_first = deepcopy(unet)
|
||||
text_encoder_ema_first = deepcopy(text_encoder)
|
||||
unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
|
||||
del unet_ema_first
|
||||
del text_encoder_ema_first
|
||||
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
|
||||
|
||||
if use_ema_dacay_training:
|
||||
logging.info(f"EMA decay enabled, creating EMA model.")
|
||||
|
||||
with torch.no_grad():
|
||||
if args.ema_decay_device == device:
|
||||
unet_ema = deepcopy(unet)
|
||||
text_encoder_ema = deepcopy(text_encoder)
|
||||
else:
|
||||
# unet_ema = type(unet)().to(args.ema_decay_device, dtype=unet.dtype)
|
||||
# unet_ema.load_state_dict(unet.state_dict())
|
||||
#
|
||||
# text_encoder_ema = type(unet)().to(args.ema_decay_device, dtype=text_encoder.dtype)
|
||||
# text_encoder_ema.load_state_dict(text_encoder.state_dict())
|
||||
|
||||
unet_ema_first = deepcopy(unet)
|
||||
text_encoder_ema_first = deepcopy(text_encoder)
|
||||
unet_ema = unet_ema_first.to(args.ema_decay_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema_first.to(args.ema_decay_device, dtype=text_encoder.dtype)
|
||||
del unet_ema_first
|
||||
del text_encoder_ema_first
|
||||
|
||||
|
||||
try:
|
||||
|
@ -661,6 +721,20 @@ def main(args):
|
|||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
|
||||
if use_ema_dacay_training:
|
||||
if args.ema_decay_target != None:
|
||||
# ema_decay_target is a value in range (0,1). ema_decay_rate is derived from the equation: decay_rate^(total_steps/decay_interval)=decay_target. Setting ema_decay_target enables the feature and overrides ema_decay_rate.
|
||||
total_number_of_steps: float = epoch_len * args.max_epochs
|
||||
total_number_of_ema_update: float = total_number_of_steps / args.ema_decay_interval
|
||||
args.ema_decay_rate = args.ema_decay_target ** (1 / total_number_of_ema_update)
|
||||
|
||||
logging.info(f"ema_decay_target is {args.ema_decay_target}, calculated ema_decay_rate will be: {args.ema_decay_rate}.")
|
||||
|
||||
logging.info(
|
||||
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_decay_interval: {args.ema_decay_interval}, ema_decay_device: {args.ema_decay_device}.")
|
||||
|
||||
|
||||
ed_optimizer = EveryDreamOptimizer(args,
|
||||
optimizer_config,
|
||||
text_encoder,
|
||||
|
@ -676,7 +750,7 @@ def main(args):
|
|||
default_sample_steps=args.sample_steps,
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers,
|
||||
use_penultimate_clip_layer=(args.clip_skip >= 2),
|
||||
guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
|
||||
guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
|
||||
)
|
||||
|
||||
"""
|
||||
|
@ -766,12 +840,13 @@ def main(args):
|
|||
del pixel_values
|
||||
latents = latents[0].sample() * 0.18215
|
||||
|
||||
if zero_frequency_noise_ratio > 0.0:
|
||||
if zero_frequency_noise_ratio != None:
|
||||
if zero_frequency_noise_ratio < 0:
|
||||
zero_frequency_noise_ratio = 0
|
||||
|
||||
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
|
||||
noise = torch.randn_like(latents) + zero_frequency_noise
|
||||
else:
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
bsz = latents.shape[0]
|
||||
|
||||
|
@ -826,6 +901,11 @@ def main(args):
|
|||
return model_pred, target
|
||||
|
||||
def generate_samples(global_step: int, batch):
|
||||
nonlocal unet
|
||||
nonlocal text_encoder
|
||||
nonlocal unet_ema
|
||||
nonlocal text_encoder_ema
|
||||
|
||||
with isolate_rng():
|
||||
prev_sample_steps = sample_generator.sample_steps
|
||||
sample_generator.reload_config()
|
||||
|
@ -834,20 +914,71 @@ def main(args):
|
|||
print(f" * SampleGenerator config changed, now generating images samples every " +
|
||||
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=inference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
del inference_pipe
|
||||
gc.collect()
|
||||
models_info = []
|
||||
|
||||
if (args.ema_decay_rate is None) or (args.ema_decay_sample_raw_training is not None):
|
||||
models_info.append({"is_ema": False, "swap_required": False})
|
||||
|
||||
if (args.ema_decay_rate is not None) and (args.ema_decay_sample_ema_model is not None):
|
||||
models_info.append({"is_ema": True, "swap_required": args.ema_decay_device != device})
|
||||
|
||||
for model_info in models_info:
|
||||
|
||||
extra_info: str = ""
|
||||
|
||||
if model_info["is_ema"]:
|
||||
current_unet, current_text_encoder = unet_ema, text_encoder_ema
|
||||
extra_info = "ema_"
|
||||
else:
|
||||
current_unet, current_text_encoder = unet, text_encoder
|
||||
|
||||
if model_info["swap_required"]:
|
||||
with torch.no_grad():
|
||||
unet_unloaded = unet.to(args.ema_decay_device)
|
||||
del unet
|
||||
text_encoder_unloaded = text_encoder.to(args.ema_decay_device)
|
||||
del text_encoder
|
||||
|
||||
current_unet = unet_ema.to(device)
|
||||
del unet_ema
|
||||
current_text_encoder = text_encoder_ema.to(device)
|
||||
del text_encoder_ema
|
||||
gc.collect()
|
||||
|
||||
|
||||
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=current_unet,
|
||||
text_encoder=current_text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=inference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info)
|
||||
|
||||
# Cleanup
|
||||
del inference_pipe
|
||||
|
||||
if model_info["swap_required"]:
|
||||
with torch.no_grad():
|
||||
unet = unet_unloaded.to(device)
|
||||
del unet_unloaded
|
||||
text_encoder = text_encoder_unloaded.to(device)
|
||||
del text_encoder_unloaded
|
||||
|
||||
unet_ema = current_unet.to(args.ema_decay_device)
|
||||
del current_unet
|
||||
text_encoder_ema = current_text_encoder.to(args.ema_decay_device)
|
||||
del current_text_encoder
|
||||
|
||||
gc.collect()
|
||||
|
||||
def make_save_path(epoch, global_step, prepend=""):
    """
    Build the checkpoint save path for the given epoch and global step.

    The path lives under ``{log_folder}/ckpts/`` and is named from the
    optional ``prepend`` string, the project name, the epoch (zero-padded to
    2 digits) and the global step (zero-padded to 5 digits).
    """
    checkpoint_path = f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}"
    # os.path.join receives a single component here, so there is nothing to join
    return os.path.join(checkpoint_path)
|
||||
|
||||
|
||||
|
||||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
validator.do_validation(global_step=0,
|
||||
|
@ -874,14 +1005,7 @@ def main(args):
|
|||
for epoch in range(args.max_epochs):
|
||||
|
||||
if args.load_settings_every_epoch:
|
||||
try:
|
||||
print(f"Loading training config from {args.config}.")
|
||||
with open(args.config, 'rt') as f:
|
||||
read_json = json.load(f)
|
||||
|
||||
args.__dict__.update(read_json)
|
||||
except Exception as config_read:
|
||||
print(f"Error on loading training config from {args.config}.")
|
||||
load_train_json_from_file(args)
|
||||
|
||||
|
||||
plugin_runner.run_on_epoch_start(epoch=epoch,
|
||||
|
@ -923,11 +1047,14 @@ def main(args):
|
|||
|
||||
if args.ema_decay_rate != None:
|
||||
if ((global_step + 1) % args.ema_decay_interval) == 0:
|
||||
debug_start_time = time.time()
|
||||
# TODO: handle Unet/TE not trained
|
||||
# TODO: Remove time measurement
|
||||
update_ema(unet, unet_ema, args.ema_decay_rate, device)
|
||||
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, device)
|
||||
debug_start_time = time.time() # TODO: Remove time measurement when debug done
|
||||
|
||||
if args.disable_unet_training != True:
|
||||
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device)
|
||||
|
||||
if args.disable_textenc_training != True:
|
||||
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device)
|
||||
|
||||
debug_end_time = time.time()
|
||||
debug_elapsed_time = debug_end_time - debug_start_time
|
||||
print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.")
|
||||
|
@ -1053,14 +1180,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||
args, argv = argparser.parse_known_args()
|
||||
|
||||
if args.config is not None:
|
||||
print(f"Loading training config from {args.config}.")
|
||||
with open(args.config, 'rt') as f:
|
||||
args.__dict__.update(json.load(f))
|
||||
if len(argv) > 0:
|
||||
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
|
||||
else:
|
||||
print("No config file specified, using command line args")
|
||||
load_train_json_from_file(args)
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
|
@ -1109,15 +1229,16 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
|
||||
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be update by (1 - ema_decay_rate) this value every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
|
||||
argparser.add_argument("--ema_decay_target", type=float, default=None, help="Value in range (0,1). decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
|
||||
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between updating EMA decay. EMA model will be update on every global_steps modulo decay_interval.")
|
||||
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using CPU is taking up to 4 seconds vs fraction of a second on CUDA, using CUDA will reserve around 3.2GB Vram for model.")
|
||||
argparser.add_argument("--ema_decay_sample_raw_model", action="store_true", default=False, help="Will show samples from training model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
|
||||
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_model")
|
||||
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. recommended values: 5, 1, 20. More info: https://arxiv.org/abs/2303.09556")
|
||||
argparser.add_argument("--load_settings_every_epoch", action="store_true", help="Will load 'train.json' at start of every epoch.")
|
||||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
|
||||
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
|
||||
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_decay_rate) from training, and the ema_decay_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
|
||||
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide decay_target.")
|
||||
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. EMA model will be update on every global_steps modulo decay_interval.")
|
||||
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' RAM will be used.")
|
||||
argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
|
||||
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_training")
|
||||
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
|
||||
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
|
||||
|
||||
# load CLI args to overwrite existing config args
|
||||
args = argparser.parse_args(args=argv, namespace=args)
|
||||
|
|
|
@ -181,7 +181,7 @@ class SampleGenerator:
|
|||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""):
|
||||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
|
@ -269,15 +269,15 @@ class SampleGenerator:
|
|||
prompt = prompts[prompt_idx]
|
||||
clean_prompt = clean_filename(prompt)
|
||||
|
||||
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
result.save(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{self.log_folder}/samples/{extra_info}gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
f.write(str(batch[prompt_idx]))
|
||||
|
||||
tfimage = transforms.ToTensor()(result)
|
||||
if batch[prompt_idx].wants_random_caption:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
|
||||
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
|
||||
else:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||
self.log_writer.add_image(tag=f"{extra_info}sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||
sample_index += 1
|
||||
|
||||
del result
|
||||
|
|
Loading…
Reference in New Issue