Victor Hall 2023-03-04 15:09:31 -05:00
commit 2e3d044ba3
7 changed files with 124 additions and 59 deletions

View File

@@ -116,7 +116,7 @@ class EveryDreamValidator:
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
with torch.no_grad(), isolate_rng():
loss_validation_epoch = []
steps_pbar = tqdm(range(len(dataloader)), position=1)
steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
for step, batch in enumerate(dataloader):

View File

@@ -1,7 +1,6 @@
import json
import logging
import os
import random
import typing
import zipfile
import argparse
@@ -18,8 +17,7 @@ class DataResolver:
"""
self.aspects = args.aspects
self.flip_p = args.flip_p
self.seed = args.seed
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
"""
Get the list of `ImageTrainItem` for the given data root.
@@ -116,8 +114,7 @@ class DirectoryResolver(DataResolver):
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
items = []
multipliers = {}
randomizer = random.Random(self.seed)
for pathname in tqdm.tqdm(image_paths):
current_dir = os.path.dirname(pathname)

View File

@@ -28,6 +28,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
"scheduler": "dpm++",
"num_inference_steps": 15,
"show_progress_bars": true,
"generate_samples_every_n_steps": 200,
"generate_pretrain_samples": true,
"samples": [
{
"prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background",
@@ -35,7 +37,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
},
{
"prompt": "a photograph of ted bennet riding a bicycle",
"seed": -1
"seed": -1,
"aspect_ratio": 1.77778
},
{
"random_caption": true,
@@ -45,9 +48,11 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
}
```
At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process.
At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. If you want to see progress bars during sample generation, set `show_progress_bars` to `true`. To generate a batch of samples before training begins, set `generate_pretrain_samples` to `true`.
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time.
Finally, you can override the `sample_steps` set in the main configuration .json file (or CLI) by setting `generate_samples_every_n_steps`. This value is re-read every time samples are generated, so if you initially pass `--sample_steps 200` and later edit your `sample_prompts.json` to add `"generate_samples_every_n_steps": 100`, then after the next set of samples is generated you will start seeing new image samples every 100 steps instead of every 200.
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and either a `size` (both dimensions must be multiples of 64) or an `aspect_ratio` (e.g. 1.77778 for 16:9). Use `"random_caption": true` to pick a random caption from the training set each time.
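As a rough illustration (a minimal sketch, not the trainer's exact code), an `aspect_ratio` request is resolved to a concrete width and height by searching 64-aligned sizes whose pixel count is close to the default training resolution squared and picking the closest ratio match, mirroring the `get_best_size_for_aspect_ratio` helper added in this commit; the `best_size_for_aspect_ratio` name and 512px default below are illustrative:
```
# Minimal sketch: pick the 64-aligned (w, h) whose pixel count is near
# default_resolution**2 and whose w/h ratio best matches the requested aspect_ratio.
def best_size_for_aspect_ratio(aspect_ratio: float, default_resolution: int = 512) -> tuple[int, int]:
    target_pixels = default_resolution * default_resolution
    candidates = [(w, h)
                  for w in range(256, 1024, 64)
                  for h in range(256, 1024, 64)
                  if abs(w * h - target_pixels) <= 128 * 64]
    # closest aspect ratio match wins
    return min(candidates, key=lambda size: abs(1 - aspect_ratio / (size[0] / size[1])))

print(best_size_for_aspect_ratio(1.77778))  # -> (704, 384) at the 512px default
```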
## LR

View File

@@ -34,6 +34,8 @@ Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the
LR can be set in `optimizer.json` and omitted from the main CLI arg or train.json, but if you use the main CLI arg or set LR in the main train.json, that value will override the `optimizer.json` setting. This was done to make sure existing behavior does not break. To set LR in `optimizer.json`, delete `"lr": 1.3e-6` from your main train.json and omit the CLI arg.
The text encoder LR can run at a different value from the Unet LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to train the text encoder with an LR that is half that of the Unet, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `1.0`, meaning the text encoder and Unet are trained with the same LR.
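To illustrate the mechanism (a minimal sketch, not the trainer's exact code), the scale is applied through standard PyTorch per-parameter-group learning rates, matching the parameter-group layout this commit adds to the trainer; `unet` and `text_encoder` below are stand-in modules and the numeric values are only examples:
```
import torch

# Minimal sketch of the per-parameter-group LR mechanism behind text_encoder_lr_scale.
# `unet` and `text_encoder` are placeholder modules; only the group layout matters.
unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)

base_lr = 2e-6
text_encoder_lr_scale = 0.5  # value read from optimizer.json

optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters()},  # falls back to the default lr below
        {"params": text_encoder.parameters(), "lr": base_lr * text_encoder_lr_scale},
    ],
    lr=base_lr,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.010,
)

for group in optimizer.param_groups:
    print(group["lr"])  # 2e-06, then 1e-06 for the text encoder group
```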
Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider tweaking them experimental. I cannot provide advice on what might be useful to tweak here.
Note `lion` does not use epsilon.

View File

@@ -5,11 +5,13 @@
"lr": "learning rate, if null wil use CLI or main JSON config value",
"betas": "exponential decay rates for the moment estimates",
"epsilon": "value added to denominator for numerical stability, unused for lion",
"weight_decay": "weight decay (L2 penalty)"
"weight_decay": "weight decay (L2 penalty)",
"text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`."
},
"optimizer": "adamw8bit",
"lr": 1e-6,
"betas": [0.9, 0.999],
"epsilon": 1e-8,
"weight_decay": 0.010
"weight_decay": 0.010,
"text_encoder_lr_scale": 1.0
}

View File

@@ -125,12 +125,12 @@ def setup_local_logger(args):
return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
"""
logs the optimizer settings
"""
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
"""
@@ -363,7 +363,9 @@ def main(args):
else:
from tqdm.auto import tqdm
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
if args.seed == -1:
args.seed = random.randint(0, 2**30)
seed = args.seed
logging.info(f" Seed: {seed}")
set_seed(seed)
if torch.cuda.is_available():
@@ -477,16 +479,6 @@ def main(args):
else:
text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters())
else:
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
optimizer_config = None
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
@@ -514,6 +506,7 @@ def main(args):
default_lr = 1e-6
curr_lr = args.lr
text_encoder_lr_scale = 1.0
if optimizer_config is not None:
betas = optimizer_config["betas"]
@@ -524,12 +517,33 @@ def main(args):
if args.lr is not None:
curr_lr = args.lr
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
if text_encoder_lr_scale != 1.0:
logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
if curr_lr is None:
curr_lr = default_lr
logging.warning(f"No LR setting found, defaulting to {default_lr}")
curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
if text_encoder_lr_scale != 1:
logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the "
f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters())
else:
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = [{'params': unet.parameters()},
{'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
if optimizer_name:
if optimizer_name == "lion":
from lion_pytorch import Lion
@@ -556,7 +570,7 @@ def main(args):
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
image_train_items = resolve_image_train_items(args, log_folder)
@@ -609,6 +623,7 @@ def main(args):
default_resolution=args.resolution, default_seed=args.seed,
config_file_path=args.sample_prompts,
batch_size=max(1,args.batch_size//2),
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers)
"""
@@ -681,7 +696,7 @@ def main(args):
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
epoch_times = []
@@ -742,11 +757,37 @@ def main(args):
return model_pred, target
def generate_samples(global_step: int, batch):
with isolate_rng():
prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config()
if prev_sample_steps != sample_generator.sample_steps:
next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
print(f" * SampleGenerator config changed, now generating images samples every " +
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation_if_appropriate(epoch=0, global_step=0,
get_model_prediction_and_target_callable=get_model_prediction_and_target)
# the sample generator might be configured to generate samples before step 0
if sample_generator.generate_pretrain_samples:
_, batch = next(enumerate(train_dataloader))
generate_samples(global_step=0, batch=batch)
try:
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
@@ -756,7 +797,7 @@ def main(args):
images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size)
steps_pbar = tqdm(range(epoch_len), position=1)
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
for step, batch in enumerate(train_dataloader):
@@ -802,8 +843,13 @@ def main(args):
curr_lr = lr_scheduler.get_last_lr()[0]
loss_local = sum(loss_log_step) / len(loss_log_step)
loss_log_step = []
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1:
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
else:
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step)
@@ -814,21 +860,8 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if (global_step + 1) % args.sample_steps == 0:
with isolate_rng():
sample_generator.reload_config()
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
if (global_step + 1) % sample_generator.sample_steps == 0:
generate_samples(global_step=global_step, batch=batch)
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60

View File

@@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm.auto import tqdm
def clean_filename(filename):
@@ -52,6 +53,15 @@ def chunk_list(l: list, batch_size: int,
yield b[i:i + batch_size]
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
sizes = []
target_pixel_count = default_resolution * default_resolution
for w in range(256, 1024, 64):
for h in range(256, 1024, 64):
if abs((w * h) - target_pixel_count) <= 128 * 64:
sizes.append((w, h))
best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
return best_size
class SampleGenerator:
@@ -73,6 +83,7 @@ class SampleGenerator:
config_file_path: str,
batch_size: int,
default_seed: int,
default_sample_steps: int,
use_xformers: bool):
self.log_folder = log_folder
self.log_writer = log_writer
@@ -80,12 +91,15 @@ class SampleGenerator:
self.config_file_path = config_file_path
self.use_xformers = use_xformers
self.show_progress_bars = False
self.generate_pretrain_samples = False
self.default_resolution = default_resolution
self.default_seed = default_seed
self.sample_steps = default_sample_steps
self.sample_requests = None
self.reload_config()
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/")
@@ -102,8 +116,12 @@ class SampleGenerator:
logging.warning(
f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
logging.warning(
f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
self.sample_requests = self._make_random_caption_sample_requests()
f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
)
if self.sample_requests is None:
logging.warning(
f" Will generate samples from random training image captions until the problem is fixed.")
self.sample_requests = self._make_random_caption_sample_requests()
def update_random_captions(self, possible_captions: list[str]):
random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
@@ -139,18 +157,20 @@ class SampleGenerator:
self.scheduler = config.get('scheduler', self.scheduler)
self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
sample_requests_json = config.get('samples', None)
if sample_requests_json is None:
self.sample_requests = []
self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps)
sample_requests_config = config.get('samples', None)
if sample_requests_config is None:
self.sample_requests = self._make_random_caption_sample_requests()
else:
default_seed = config.get('seed', self.default_seed)
default_size = (self.default_resolution, self.default_resolution)
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
negative_prompt=p.get('negative_prompt', ''),
seed=p.get('seed', default_seed),
size=tuple(p.get('size', default_size)),
size=tuple(p.get('size', None) or
get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_json]
) for p in sample_requests_config]
if len(self.sample_requests) == 0:
self._make_random_caption_sample_requests()
@@ -159,23 +179,26 @@ class SampleGenerator:
"""
generates samples at different cfg scales and saves them to disk
"""
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
disable_progress_bars = not self.show_progress_bars
try:
font = ImageFont.truetype(font="arial.ttf", size=20)
except:
font = ImageFont.load_default()
if not self.show_progress_bars:
print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts")
sample_index = 0
with autocast():
batch: list[SampleRequest]
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
return a.size == b.size
for batch in chunk_list(self.sample_requests, self.batch_size,
compatibility_test=sample_compatibility_test):
#print("batch: ", batch)
batches = list(chunk_list(self.sample_requests, self.batch_size,
compatibility_test=sample_compatibility_test))
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
for batch in batches:
prompts = [p.prompt for p in batch]
negative_prompts = [p.negative_prompt for p in batch]
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
@@ -186,6 +209,8 @@ class SampleGenerator:
batch_images = []
for cfg in self.cfgs:
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
images = pipe(prompt=prompts,
negative_prompt=negative_prompts,
num_inference_steps=self.num_inference_steps,
@@ -247,6 +272,7 @@ class SampleGenerator:
del tfimage
del batch_images
pbar.update(1)
@torch.no_grad()
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):