This commit is contained in:
Victor Hall 2023-03-04 15:09:31 -05:00
commit 2e3d044ba3
7 changed files with 124 additions and 59 deletions

View File

@ -116,7 +116,7 @@ class EveryDreamValidator:
[Any, Any], tuple[torch.Tensor, torch.Tensor]]): [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
with torch.no_grad(), isolate_rng(): with torch.no_grad(), isolate_rng():
loss_validation_epoch = [] loss_validation_epoch = []
steps_pbar = tqdm(range(len(dataloader)), position=1) steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}") steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
for step, batch in enumerate(dataloader): for step, batch in enumerate(dataloader):

View File

@ -1,7 +1,6 @@
import json import json
import logging import logging
import os import os
import random
import typing import typing
import zipfile import zipfile
import argparse import argparse
@ -18,7 +17,6 @@ class DataResolver:
""" """
self.aspects = args.aspects self.aspects = args.aspects
self.flip_p = args.flip_p self.flip_p = args.flip_p
self.seed = args.seed
def image_train_items(self, data_root: str) -> list[ImageTrainItem]: def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
""" """
@ -116,7 +114,6 @@ class DirectoryResolver(DataResolver):
image_paths = list(DirectoryResolver.recurse_data_root(data_root)) image_paths = list(DirectoryResolver.recurse_data_root(data_root))
items = [] items = []
multipliers = {} multipliers = {}
randomizer = random.Random(self.seed)
for pathname in tqdm.tqdm(image_paths): for pathname in tqdm.tqdm(image_paths):
current_dir = os.path.dirname(pathname) current_dir = os.path.dirname(pathname)

View File

@ -28,6 +28,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
"scheduler": "dpm++", "scheduler": "dpm++",
"num_inference_steps": 15, "num_inference_steps": 15,
"show_progress_bars": true, "show_progress_bars": true,
"generate_samples_every_n_steps": 200,
"generate_pretrain_samples": true,
"samples": [ "samples": [
{ {
"prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background", "prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background",
@ -35,7 +37,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
}, },
{ {
"prompt": "a photograph of ted bennet riding a bicycle", "prompt": "a photograph of ted bennet riding a bicycle",
"seed": -1 "seed": -1,
"aspect_ratio": 1.77778
}, },
{ {
"random_caption": true, "random_caption": true,
@ -45,9 +48,11 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
} }
``` ```
At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process. At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. If you want to see sample progress bars you can set `show_progress_bars` to `true`. To generate a batch of samples before training begins, set `generate_pretrain_samples` to true.
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time. Finally, you can override the `sample_steps` set in the main configuration .json file (or CLI) by setting `generate_samples_every_n_steps`. This value is read every time samples are updated, so if you initially pass `--sample_steps 200` and then later on you edit your `sample_prompts.json` file to add `"generate_samples_every_n_steps": 100`, after the next set of samples is generated you will start seeing new sets of image samples every 100 steps instead of only every 200 steps.
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64) or `aspect_ratio` (eg 1.77778 for 16:9). Use `"random_caption": true` to pick a random caption from the training set each time.
## LR ## LR

View File

@ -34,6 +34,8 @@ Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the
LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json but if you use the main CLI arg or set it in the main train.json it will override the setting. This was done to make sure existing behavior will not break. To set LR in the `optimizer.json` make sure to delete `"lr": 1.3e-6` in your main train.json and exclude the CLI arg. LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json but if you use the main CLI arg or set it in the main train.json it will override the setting. This was done to make sure existing behavior will not break. To set LR in the `optimizer.json` make sure to delete `"lr": 1.3e-6` in your main train.json and exclude the CLI arg.
The text encoder LR can run at a different value to the Unet LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to train the text encoder with an LR that is half that of the Unet, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `1.0`, meaning the text encoder and Unet are trained with the same LR.
Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak. I cannot provide advice on what might be useful to tweak here. Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak. I cannot provide advice on what might be useful to tweak here.
Note `lion` does not use epsilon. Note `lion` does not use epsilon.

View File

@ -5,11 +5,13 @@
"lr": "learning rate, if null wil use CLI or main JSON config value", "lr": "learning rate, if null wil use CLI or main JSON config value",
"betas": "exponential decay rates for the moment estimates", "betas": "exponential decay rates for the moment estimates",
"epsilon": "value added to denominator for numerical stability, unused for lion", "epsilon": "value added to denominator for numerical stability, unused for lion",
"weight_decay": "weight decay (L2 penalty)" "weight_decay": "weight decay (L2 penalty)",
"text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`."
}, },
"optimizer": "adamw8bit", "optimizer": "adamw8bit",
"lr": 1e-6, "lr": 1e-6,
"betas": [0.9, 0.999], "betas": [0.9, 0.999],
"epsilon": 1e-8, "epsilon": 1e-8,
"weight_decay": 0.010 "weight_decay": 0.010,
"text_encoder_lr_scale": 1.0
} }

View File

@ -125,12 +125,12 @@ def setup_local_logger(args):
return datetimestamp return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr): def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
""" """
logs the optimizer settings logs the optimizer settings
""" """
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
def save_optimizer(optimizer: torch.optim.Optimizer, path: str): def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
""" """
@ -363,7 +363,9 @@ def main(args):
else: else:
from tqdm.auto import tqdm from tqdm.auto import tqdm
seed = args.seed if args.seed != -1 else random.randint(0, 2**30) if args.seed == -1:
args.seed = random.randint(0, 2**30)
seed = args.seed
logging.info(f" Seed: {seed}") logging.info(f" Seed: {seed}")
set_seed(seed) set_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -477,16 +479,6 @@ def main(args):
else: else:
text_encoder = text_encoder.to(device, dtype=torch.float32) text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters())
else:
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
optimizer_config = None optimizer_config = None
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json" optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)): if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
@ -514,6 +506,7 @@ def main(args):
default_lr = 1e-6 default_lr = 1e-6
curr_lr = args.lr curr_lr = args.lr
text_encoder_lr_scale = 1.0
if optimizer_config is not None: if optimizer_config is not None:
betas = optimizer_config["betas"] betas = optimizer_config["betas"]
@ -524,12 +517,33 @@ def main(args):
if args.lr is not None: if args.lr is not None:
curr_lr = args.lr curr_lr = args.lr
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}") logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
if text_encoder_lr_scale != 1.0:
logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *") logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
if curr_lr is None: if curr_lr is None:
curr_lr = default_lr curr_lr = default_lr
logging.warning(f"No LR setting found, defaulting to {default_lr}") logging.warning(f"No LR setting found, defaulting to {default_lr}")
curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters())
elif args.disable_unet_training:
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
if text_encoder_lr_scale != 1:
logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the "
f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}")
params_to_train = itertools.chain(text_encoder.parameters())
else:
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = [{'params': unet.parameters()},
{'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
if optimizer_name: if optimizer_name:
if optimizer_name == "lion": if optimizer_name == "lion":
from lion_pytorch import Lion from lion_pytorch import Lion
@ -556,7 +570,7 @@ def main(args):
amsgrad=False, amsgrad=False,
) )
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr) log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
image_train_items = resolve_image_train_items(args, log_folder) image_train_items = resolve_image_train_items(args, log_folder)
@ -609,6 +623,7 @@ def main(args):
default_resolution=args.resolution, default_seed=args.seed, default_resolution=args.resolution, default_seed=args.seed,
config_file_path=args.sample_prompts, config_file_path=args.sample_prompts,
batch_size=max(1,args.batch_size//2), batch_size=max(1,args.batch_size//2),
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers) use_xformers=is_xformers_available() and not args.disable_xformers)
""" """
@ -681,7 +696,7 @@ def main(args):
) )
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)") logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True) epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}") epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
epoch_times = [] epoch_times = []
@ -742,11 +757,37 @@ def main(args):
return model_pred, target return model_pred, target
def generate_samples(global_step: int, batch):
with isolate_rng():
prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config()
if prev_sample_steps != sample_generator.sample_steps:
next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
print(f" * SampleGenerator config changed, now generating images samples every " +
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
# Pre-train validation to establish a starting point on the loss graph # Pre-train validation to establish a starting point on the loss graph
if validator: if validator:
validator.do_validation_if_appropriate(epoch=0, global_step=0, validator.do_validation_if_appropriate(epoch=0, global_step=0,
get_model_prediction_and_target_callable=get_model_prediction_and_target) get_model_prediction_and_target_callable=get_model_prediction_and_target)
# the sample generator might be configured to generate samples before step 0
if sample_generator.generate_pretrain_samples:
_, batch = next(enumerate(train_dataloader))
generate_samples(global_step=0, batch=batch)
try: try:
write_batch_schedule(args, log_folder, train_batch, epoch = 0) write_batch_schedule(args, log_folder, train_batch, epoch = 0)
@ -756,7 +797,7 @@ def main(args):
images_per_sec_log_step = [] images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size) epoch_len = math.ceil(len(train_batch) / args.batch_size)
steps_pbar = tqdm(range(epoch_len), position=1) steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}") steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
@ -803,7 +844,12 @@ def main(args):
loss_local = sum(loss_log_step) / len(loss_log_step) loss_local = sum(loss_log_step) / len(loss_log_step)
loss_log_step = [] loss_log_step = []
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec} logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1:
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step) log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
else:
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step) log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
sum_img = sum(images_per_sec_log_step) sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step) avg = sum_img / len(images_per_sec_log_step)
@ -814,21 +860,8 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs) append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache() torch.cuda.empty_cache()
if (global_step + 1) % args.sample_steps == 0: if (global_step + 1) % sample_generator.sample_steps == 0:
with isolate_rng(): generate_samples(global_step=global_step, batch=batch)
sample_generator.reload_config()
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60 min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60

View File

@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm
def clean_filename(filename): def clean_filename(filename):
@ -52,6 +53,15 @@ def chunk_list(l: list, batch_size: int,
yield b[i:i + batch_size] yield b[i:i + batch_size]
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
sizes = []
target_pixel_count = default_resolution * default_resolution
for w in range(256, 1024, 64):
for h in range(256, 1024, 64):
if abs((w * h) - target_pixel_count) <= 128 * 64:
sizes.append((w, h))
best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
return best_size
class SampleGenerator: class SampleGenerator:
@ -73,6 +83,7 @@ class SampleGenerator:
config_file_path: str, config_file_path: str,
batch_size: int, batch_size: int,
default_seed: int, default_seed: int,
default_sample_steps: int,
use_xformers: bool): use_xformers: bool):
self.log_folder = log_folder self.log_folder = log_folder
self.log_writer = log_writer self.log_writer = log_writer
@ -80,12 +91,15 @@ class SampleGenerator:
self.config_file_path = config_file_path self.config_file_path = config_file_path
self.use_xformers = use_xformers self.use_xformers = use_xformers
self.show_progress_bars = False self.show_progress_bars = False
self.generate_pretrain_samples = False
self.default_resolution = default_resolution self.default_resolution = default_resolution
self.default_seed = default_seed self.default_seed = default_seed
self.sample_steps = default_sample_steps
self.sample_requests = None
self.reload_config() self.reload_config()
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps") print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
if not os.path.exists(f"{log_folder}/samples/"): if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/") os.makedirs(f"{log_folder}/samples/")
@ -102,7 +116,11 @@ class SampleGenerator:
logging.warning( logging.warning(
f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}") f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
logging.warning( logging.warning(
f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.") f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
)
if self.sample_requests == None:
logging.warning(
f" Will generate samples from random training image captions until the problem is fixed.")
self.sample_requests = self._make_random_caption_sample_requests() self.sample_requests = self._make_random_caption_sample_requests()
def update_random_captions(self, possible_captions: list[str]): def update_random_captions(self, possible_captions: list[str]):
@ -139,18 +157,20 @@ class SampleGenerator:
self.scheduler = config.get('scheduler', self.scheduler) self.scheduler = config.get('scheduler', self.scheduler)
self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps) self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars) self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
sample_requests_json = config.get('samples', None) self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
if sample_requests_json is None: self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps)
self.sample_requests = [] sample_requests_config = config.get('samples', None)
if sample_requests_config is None:
self.sample_requests = self._make_random_caption_sample_requests()
else: else:
default_seed = config.get('seed', self.default_seed) default_seed = config.get('seed', self.default_seed)
default_size = (self.default_resolution, self.default_resolution)
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''), self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
negative_prompt=p.get('negative_prompt', ''), negative_prompt=p.get('negative_prompt', ''),
seed=p.get('seed', default_seed), seed=p.get('seed', default_seed),
size=tuple(p.get('size', default_size)), size=tuple(p.get('size', None) or
get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
wants_random_caption=p.get('random_caption', False) wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_json] ) for p in sample_requests_config]
if len(self.sample_requests) == 0: if len(self.sample_requests) == 0:
self._make_random_caption_sample_requests() self._make_random_caption_sample_requests()
@ -159,23 +179,26 @@ class SampleGenerator:
""" """
generates samples at different cfg scales and saves them to disk generates samples at different cfg scales and saves them to disk
""" """
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}") disable_progress_bars = not self.show_progress_bars
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
try: try:
font = ImageFont.truetype(font="arial.ttf", size=20) font = ImageFont.truetype(font="arial.ttf", size=20)
except: except:
font = ImageFont.load_default() font = ImageFont.load_default()
if not self.show_progress_bars:
print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts")
sample_index = 0 sample_index = 0
with autocast(): with autocast():
batch: list[SampleRequest] batch: list[SampleRequest]
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool: def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
return a.size == b.size return a.size == b.size
for batch in chunk_list(self.sample_requests, self.batch_size, batches = list(chunk_list(self.sample_requests, self.batch_size,
compatibility_test=sample_compatibility_test): compatibility_test=sample_compatibility_test))
#print("batch: ", batch) pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
for batch in batches:
prompts = [p.prompt for p in batch] prompts = [p.prompt for p in batch]
negative_prompts = [p.negative_prompt for p in batch] negative_prompts = [p.negative_prompt for p in batch]
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
@ -186,6 +209,8 @@ class SampleGenerator:
batch_images = [] batch_images = []
for cfg in self.cfgs: for cfg in self.cfgs:
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
images = pipe(prompt=prompts, images = pipe(prompt=prompts,
negative_prompt=negative_prompts, negative_prompt=negative_prompts,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
@ -247,6 +272,7 @@ class SampleGenerator:
del tfimage del tfimage
del batch_images del batch_images
pbar.update(1)
@torch.no_grad() @torch.no_grad()
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict): def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):