Merge branch 'main' of https://github.com/victorchall/EveryDream2trainer
Commit 2e3d044ba3
@@ -116,7 +116,7 @@ class EveryDreamValidator:
                                  [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         with torch.no_grad(), isolate_rng():
             loss_validation_epoch = []
-            steps_pbar = tqdm(range(len(dataloader)), position=1)
+            steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
             steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")

             for step, batch in enumerate(dataloader):
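The `leave=False` / `dynamic_ncols=True` changes in this commit (here and in train.py below) keep nested progress bars from stacking up in the terminal. For context, a minimal standalone sketch of that behavior, my own illustration rather than code from the commit:

```
# Nested tqdm bars: the inner bar (position=1, leave=False) is erased once it
# finishes, so the outer bar is not pushed down by a pile of completed inner bars.
from tqdm.auto import tqdm
import time

for epoch in tqdm(range(3), position=0, leave=True, dynamic_ncols=True, desc="Epochs"):
    for step in tqdm(range(50), position=1, leave=False, desc="Steps"):
        time.sleep(0.01)  # stand-in for one training step
```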
@@ -1,7 +1,6 @@
 import json
 import logging
 import os
-import random
 import typing
 import zipfile
 import argparse

@@ -18,7 +17,6 @@ class DataResolver:
         """
         self.aspects = args.aspects
         self.flip_p = args.flip_p
-        self.seed = args.seed

     def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
         """

@@ -116,7 +114,6 @@ class DirectoryResolver(DataResolver):
         image_paths = list(DirectoryResolver.recurse_data_root(data_root))
         items = []
         multipliers = {}
-        randomizer = random.Random(self.seed)

         for pathname in tqdm.tqdm(image_paths):
             current_dir = os.path.dirname(pathname)
@@ -28,6 +28,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
     "scheduler": "dpm++",
     "num_inference_steps": 15,
     "show_progress_bars": true,
+    "generate_samples_every_n_steps": 200,
+    "generate_pretrain_samples": true,
     "samples": [
         {
             "prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background",

@@ -35,7 +37,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
         },
         {
             "prompt": "a photograph of ted bennet riding a bicycle",
-            "seed": -1
+            "seed": -1,
+            "aspect_ratio": 1.77778
         },
         {
             "random_caption": true,

@@ -45,9 +48,11 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
     }
 ```

-At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process.
+At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. If you want to see sample progress bars, set `show_progress_bars` to `true`. To generate a batch of samples before training begins, set `generate_pretrain_samples` to `true`.

-Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time.
+Finally, you can override the `sample_steps` set in the main configuration .json file (or CLI) by setting `generate_samples_every_n_steps`. This value is read every time samples are updated, so if you initially pass `--sample_steps 200` and later edit your `sample_prompts.json` file to add `"generate_samples_every_n_steps": 100`, then after the next set of samples is generated you will start seeing new sets of image samples every 100 steps instead of only every 200 steps.
+
+Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64) or an `aspect_ratio` (e.g. 1.77778 for 16:9). Use `"random_caption": true` to pick a random caption from the training set each time.

 ## LR

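For reference, the interval override is re-read whenever samples are generated; the sketch below (my own illustration, mirroring the `next_sample_step` arithmetic added to train.py further down in this commit, not repository code) shows when the next sample pass lands after the config is edited:

```
import math

def next_sample_step(global_step: int, sample_steps: int) -> int:
    # first upcoming step that satisfies (global_step + 1) % sample_steps == 0
    return math.ceil((global_step + 1) / sample_steps) * sample_steps

print(next_sample_step(global_step=399, sample_steps=200))  # 400, with --sample_steps 200
print(next_sample_step(global_step=400, sample_steps=100))  # 500, after editing the json to 100
```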
@@ -34,6 +34,8 @@ Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the
 
 LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json but if you use the main CLI arg or set it in the main train.json it will override the setting. This was done to make sure existing behavior will not break. To set LR in the `optimizer.json` make sure to delete `"lr": 1.3e-6` in your main train.json and exclude the CLI arg.
 
+The text encoder LR can run at a different value to the Unet LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to train the text encoder with an LR that is half that of the Unet, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `1.0`, meaning the text encoder and Unet are trained with the same LR.
+
 Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak. I cannot provide advice on what might be useful to tweak here.
 
 Note `lion` does not use epsilon.

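For reference, the scale is applied by giving the text encoder its own optimizer parameter group with its own `lr`; the group without an explicit `lr` falls back to the optimizer default. A simplified standalone sketch (my own illustration, loosely following the train.py change later in this commit, with toy modules standing in for the real Unet and text encoder):

```
import torch

unet_lr = 2e-6
text_encoder_lr_scale = 0.5                        # as set in optimizer.json
text_encoder_lr = unet_lr * text_encoder_lr_scale  # 1e-6

unet = torch.nn.Linear(8, 8)          # toy stand-in for the Unet
text_encoder = torch.nn.Linear(8, 8)  # toy stand-in for the text encoder

optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters()},                                  # uses the default lr below
        {"params": text_encoder.parameters(), "lr": text_encoder_lr},  # scaled LR for the text encoder
    ],
    lr=unet_lr,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.010,
)

print([group["lr"] for group in optimizer.param_groups])  # [2e-06, 1e-06]
```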
@@ -5,11 +5,13 @@
         "lr": "learning rate, if null wil use CLI or main JSON config value",
         "betas": "exponential decay rates for the moment estimates",
         "epsilon": "value added to denominator for numerical stability, unused for lion",
-        "weight_decay": "weight decay (L2 penalty)"
+        "weight_decay": "weight decay (L2 penalty)",
+        "text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`."
     },
     "optimizer": "adamw8bit",
     "lr": 1e-6,
     "betas": [0.9, 0.999],
     "epsilon": 1e-8,
-    "weight_decay": 0.010
+    "weight_decay": 0.010,
+    "text_encoder_lr_scale": 1.0
 }

train.py (95 lines changed)
@@ -125,12 +125,12 @@ def setup_local_logger(args):
 
     return datetimestamp
 
-def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
+def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
     """
     logs the optimizer settings
     """
     logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
-    logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
+    logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
 
 def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
     """
@@ -363,7 +363,9 @@ def main(args):
     else:
         from tqdm.auto import tqdm
 
-    seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
+    if args.seed == -1:
+        args.seed = random.randint(0, 2**30)
+    seed = args.seed
     logging.info(f" Seed: {seed}")
     set_seed(seed)
     if torch.cuda.is_available():
@@ -477,16 +479,6 @@ def main(args):
     else:
         text_encoder = text_encoder.to(device, dtype=torch.float32)
 
-    if args.disable_textenc_training:
-        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(unet.parameters())
-    elif args.disable_unet_training:
-        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(text_encoder.parameters())
-    else:
-        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
-
     optimizer_config = None
     optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
     if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
@@ -514,6 +506,7 @@ def main(args):
 
     default_lr = 1e-6
     curr_lr = args.lr
+    text_encoder_lr_scale = 1.0
 
     if optimizer_config is not None:
         betas = optimizer_config["betas"]
@@ -524,12 +517,33 @@ def main(args):
         if args.lr is not None:
             curr_lr = args.lr
             logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
 
+        text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
+        if text_encoder_lr_scale != 1.0:
+            logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
+
         logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
 
     if curr_lr is None:
         curr_lr = default_lr
         logging.warning(f"No LR setting found, defaulting to {default_lr}")
 
+    curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
+
+    if args.disable_textenc_training:
+        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(unet.parameters())
+    elif args.disable_unet_training:
+        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
+        if text_encoder_lr_scale != 1:
+            logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the "
+                            f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(text_encoder.parameters())
+    else:
+        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
+        params_to_train = [{'params': unet.parameters()},
+                           {'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
+
     if optimizer_name:
         if optimizer_name == "lion":
             from lion_pytorch import Lion
|
@ -556,7 +570,7 @@ def main(args):
|
||||||
amsgrad=False,
|
amsgrad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
|
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||||
|
|
||||||
image_train_items = resolve_image_train_items(args, log_folder)
|
image_train_items = resolve_image_train_items(args, log_folder)
|
||||||
|
|
||||||
|
@@ -609,6 +623,7 @@ def main(args):
                                        default_resolution=args.resolution, default_seed=args.seed,
                                        config_file_path=args.sample_prompts,
                                        batch_size=max(1,args.batch_size//2),
+                                       default_sample_steps=args.sample_steps,
                                        use_xformers=is_xformers_available() and not args.disable_xformers)
 
     """
@@ -681,7 +696,7 @@ def main(args):
     )
     logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
 
-    epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
+    epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
     epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
     epoch_times = []
 
@@ -742,11 +757,37 @@ def main(args):
 
         return model_pred, target
 
+    def generate_samples(global_step: int, batch):
+        with isolate_rng():
+            prev_sample_steps = sample_generator.sample_steps
+            sample_generator.reload_config()
+            if prev_sample_steps != sample_generator.sample_steps:
+                next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
+                print(f" * SampleGenerator config changed, now generating images samples every " +
+                      f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
+            sample_generator.update_random_captions(batch["captions"])
+            inference_pipe = sample_generator.create_inference_pipe(unet=unet,
+                                                                    text_encoder=text_encoder,
+                                                                    tokenizer=tokenizer,
+                                                                    vae=vae,
+                                                                    diffusers_scheduler_config=reference_scheduler.config
+                                                                    ).to(device)
+            sample_generator.generate_samples(inference_pipe, global_step)
+
+            del inference_pipe
+            gc.collect()
+            torch.cuda.empty_cache()
+
     # Pre-train validation to establish a starting point on the loss graph
     if validator:
         validator.do_validation_if_appropriate(epoch=0, global_step=0,
                                                get_model_prediction_and_target_callable=get_model_prediction_and_target)
 
+    # the sample generator might be configured to generate samples before step 0
+    if sample_generator.generate_pretrain_samples:
+        _, batch = next(enumerate(train_dataloader))
+        generate_samples(global_step=0, batch=batch)
+
     try:
         write_batch_schedule(args, log_folder, train_batch, epoch = 0)
 
@@ -756,7 +797,7 @@ def main(args):
             images_per_sec_log_step = []
 
             epoch_len = math.ceil(len(train_batch) / args.batch_size)
-            steps_pbar = tqdm(range(epoch_len), position=1)
+            steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
             steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
 
             for step, batch in enumerate(train_dataloader):
@@ -803,7 +844,12 @@ def main(args):
                     loss_local = sum(loss_log_step) / len(loss_log_step)
                     loss_log_step = []
                     logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
-                    log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
+                    if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1:
+                        log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
+                    else:
+                        log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
+                        curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
+                        log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
                     log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                     sum_img = sum(images_per_sec_log_step)
                     avg = sum_img / len(images_per_sec_log_step)
@@ -814,21 +860,8 @@ def main(args):
                     append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                     torch.cuda.empty_cache()
 
-                if (global_step + 1) % args.sample_steps == 0:
-                    with isolate_rng():
-                        sample_generator.reload_config()
-                        sample_generator.update_random_captions(batch["captions"])
-                        inference_pipe = sample_generator.create_inference_pipe(unet=unet,
-                                                                                text_encoder=text_encoder,
-                                                                                tokenizer=tokenizer,
-                                                                                vae=vae,
-                                                                                diffusers_scheduler_config=reference_scheduler.config
-                                                                                ).to(device)
-                        sample_generator.generate_samples(inference_pipe, global_step)
-
-                        del inference_pipe
-                        gc.collect()
-                        torch.cuda.empty_cache()
+                if (global_step + 1) % sample_generator.sample_steps == 0:
+                    generate_samples(global_step=global_step, batch=batch)
 
                 min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
 
@@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
 from torch.cuda.amp import autocast
 from torch.utils.tensorboard import SummaryWriter
 from torchvision import transforms
+from tqdm.auto import tqdm
 
 
 def clean_filename(filename):
@@ -52,6 +53,15 @@ def chunk_list(l: list, batch_size: int,
         yield b[i:i + batch_size]
 
 
+def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
+    sizes = []
+    target_pixel_count = default_resolution * default_resolution
+    for w in range(256, 1024, 64):
+        for h in range(256, 1024, 64):
+            if abs((w * h) - target_pixel_count) <= 128 * 64:
+                sizes.append((w, h))
+    best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
+    return best_size
+
+
 class SampleGenerator:
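For reference, a short walk-through (my own, not repository output) of the `get_best_size_for_aspect_ratio` helper added above, using the `"aspect_ratio": 1.77778` example from the sampling docs with a default resolution of 512; the helper is copied here so the snippet runs standalone:

```
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
    # copied from the hunk above: search 64-aligned sizes near default_resolution**2
    # pixels and pick the one whose w/h ratio is closest to the requested aspect ratio
    sizes = []
    target_pixel_count = default_resolution * default_resolution
    for w in range(256, 1024, 64):
        for h in range(256, 1024, 64):
            if abs((w * h) - target_pixel_count) <= 128 * 64:
                sizes.append((w, h))
    return min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))

print(get_best_size_for_aspect_ratio(1.77778, 512))  # expected (704, 384) for ~16:9 at a 512px pixel budget
```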
@@ -73,6 +83,7 @@ class SampleGenerator:
                  config_file_path: str,
                  batch_size: int,
                  default_seed: int,
+                 default_sample_steps: int,
                  use_xformers: bool):
         self.log_folder = log_folder
         self.log_writer = log_writer
@@ -80,12 +91,15 @@
         self.config_file_path = config_file_path
         self.use_xformers = use_xformers
         self.show_progress_bars = False
+        self.generate_pretrain_samples = False
 
         self.default_resolution = default_resolution
         self.default_seed = default_seed
+        self.sample_steps = default_sample_steps
 
+        self.sample_requests = None
         self.reload_config()
-        print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
+        print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
         if not os.path.exists(f"{log_folder}/samples/"):
             os.makedirs(f"{log_folder}/samples/")
 
@@ -102,7 +116,11 @@ class SampleGenerator:
             logging.warning(
                 f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
             logging.warning(
-                f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
-            self.sample_requests = self._make_random_caption_sample_requests()
+                f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
+            )
+            if self.sample_requests == None:
+                logging.warning(
+                    f" Will generate samples from random training image captions until the problem is fixed.")
+                self.sample_requests = self._make_random_caption_sample_requests()
 
     def update_random_captions(self, possible_captions: list[str]):
@@ -139,18 +157,20 @@ class SampleGenerator:
         self.scheduler = config.get('scheduler', self.scheduler)
         self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
         self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
-        sample_requests_json = config.get('samples', None)
-        if sample_requests_json is None:
-            self.sample_requests = []
+        self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
+        self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps)
+        sample_requests_config = config.get('samples', None)
+        if sample_requests_config is None:
+            self.sample_requests = self._make_random_caption_sample_requests()
         else:
             default_seed = config.get('seed', self.default_seed)
-            default_size = (self.default_resolution, self.default_resolution)
             self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
                                                   negative_prompt=p.get('negative_prompt', ''),
                                                   seed=p.get('seed', default_seed),
-                                                  size=tuple(p.get('size', default_size)),
+                                                  size=tuple(p.get('size', None) or
+                                                             get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
                                                   wants_random_caption=p.get('random_caption', False)
-                                                  ) for p in sample_requests_json]
+                                                  ) for p in sample_requests_config]
             if len(self.sample_requests) == 0:
                 self._make_random_caption_sample_requests()
 
@@ -159,23 +179,26 @@ class SampleGenerator:
         """
         generates samples at different cfg scales and saves them to disk
         """
-        logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
+        disable_progress_bars = not self.show_progress_bars
 
-        pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
 
         try:
             font = ImageFont.truetype(font="arial.ttf", size=20)
         except:
             font = ImageFont.load_default()
 
+        if not self.show_progress_bars:
+            print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts")
+
         sample_index = 0
         with autocast():
             batch: list[SampleRequest]
             def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
                 return a.size == b.size
-            for batch in chunk_list(self.sample_requests, self.batch_size,
-                                    compatibility_test=sample_compatibility_test):
-                #print("batch: ", batch)
+            batches = list(chunk_list(self.sample_requests, self.batch_size,
+                                      compatibility_test=sample_compatibility_test))
+            pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
+                        desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
+            for batch in batches:
                 prompts = [p.prompt for p in batch]
                 negative_prompts = [p.negative_prompt for p in batch]
                 seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
@@ -186,6 +209,8 @@ class SampleGenerator:
 
                 batch_images = []
                 for cfg in self.cfgs:
+                    pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
+                                                 desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
                     images = pipe(prompt=prompts,
                                   negative_prompt=negative_prompts,
                                   num_inference_steps=self.num_inference_steps,
@@ -247,6 +272,7 @@ class SampleGenerator:
                 del tfimage
                 del batch_images
 
+                pbar.update(1)
 
     @torch.no_grad()
     def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):