Merge pull request #103 from damian0815/feat_text_encoder_LR
Separate text encoder LR; logging & pbar improvements; misc sample generator tweaks
This commit is contained in:
commit
32585a74b2
|
@ -105,7 +105,7 @@ class EveryDreamValidator:
|
|||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
with torch.no_grad(), isolate_rng():
|
||||
loss_validation_epoch = []
|
||||
steps_pbar = tqdm(range(len(dataloader)), position=1)
|
||||
steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
|
||||
|
||||
for step, batch in enumerate(dataloader):
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import typing
|
||||
import zipfile
|
||||
import argparse
|
||||
|
@ -18,8 +17,7 @@ class DataResolver:
|
|||
"""
|
||||
self.aspects = args.aspects
|
||||
self.flip_p = args.flip_p
|
||||
self.seed = args.seed
|
||||
|
||||
|
||||
def image_train_items(self, data_root: str) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Get the list of `ImageTrainItem` for the given data root.
|
||||
|
@ -116,8 +114,7 @@ class DirectoryResolver(DataResolver):
|
|||
image_paths = list(DirectoryResolver.recurse_data_root(data_root))
|
||||
items = []
|
||||
multipliers = {}
|
||||
randomizer = random.Random(self.seed)
|
||||
|
||||
|
||||
for pathname in tqdm.tqdm(image_paths):
|
||||
current_dir = os.path.dirname(pathname)
|
||||
|
||||
|
|
|
@ -28,6 +28,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
|
|||
"scheduler": "dpm++",
|
||||
"num_inference_steps": 15,
|
||||
"show_progress_bars": true,
|
||||
"generate_samples_every_n_steps": 200,
|
||||
"generate_pretrain_samples": true,
|
||||
"samples": [
|
||||
{
|
||||
"prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background",
|
||||
|
@ -35,7 +37,8 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
|
|||
},
|
||||
{
|
||||
"prompt": "a photograph of ted bennet riding a bicycle",
|
||||
"seed": -1
|
||||
"seed": -1,
|
||||
"aspect_ratio": 1.77778
|
||||
},
|
||||
{
|
||||
"random_caption": true,
|
||||
|
@ -45,9 +48,11 @@ In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, w
|
|||
}
|
||||
```
|
||||
|
||||
At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process.
|
||||
At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. If you want to see sample progress bars you can set `show_progress_bars` to `true`. To generate a batch of samples before training begins, set `generate_pretrain_samples` to true.
|
||||
|
||||
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time.
|
||||
Finally, you can override the `sample_steps` set in the main configuration .json file (or CLI) by setting `generate_samples_every_n_steps`. This value is read every time samples are updated, so if you initially pass `--sample_steps 200` and then later on you edit your `sample_prompts.json` file to add `"generate_samples_every_n_steps": 100`, after the next set of samples is generated you will start seeing new sets of image samples every 100 steps instead of only every 200 steps.
|
||||
|
||||
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64) or `aspect_ratio` (eg 1.77778 for 16:9). Use `"random_caption": true` to pick a random caption from the training set each time.
|
||||
|
||||
## LR
|
||||
|
||||
|
|
|
@ -34,6 +34,8 @@ Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the
|
|||
|
||||
LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json but if you use the main CLI arg or set it in the main train.json it will override the setting. This was done to make sure existing behavior will not break. To set LR in the `optimizer.json` make sure to delete `"lr": 1.3e-6` in your main train.json and exclude the CLI arg.
|
||||
|
||||
The text encoder LR can run at a different value to the Unet LR. This may help prevent over-fitting, especially if you're training from SD2 checkpoints. To set the text encoder LR, add a value for `text_encoder_lr_scale` to `optimizer.json`. For example, to train the text encoder with an LR that is half that of the Unet, add `"text_encoder_lr_scale": 0.5` to `optimizer.json`. The default value is `1.0`, meaning the text encoder and Unet are trained with the same LR.
|
||||
|
||||
Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider those experimental to tweak. I cannot provide advice on what might be useful to tweak here.
|
||||
|
||||
Note `lion` does not use epsilon.
|
|
@ -5,11 +5,13 @@
|
|||
"lr": "learning rate, if null wil use CLI or main JSON config value",
|
||||
"betas": "exponential decay rates for the moment estimates",
|
||||
"epsilon": "value added to denominator for numerical stability, unused for lion",
|
||||
"weight_decay": "weight decay (L2 penalty)"
|
||||
"weight_decay": "weight decay (L2 penalty)",
|
||||
"text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`."
|
||||
},
|
||||
"optimizer": "adamw8bit",
|
||||
"lr": 1e-6,
|
||||
"betas": [0.9, 0.999],
|
||||
"epsilon": 1e-8,
|
||||
"weight_decay": 0.010
|
||||
"weight_decay": 0.010,
|
||||
"text_encoder_lr_scale": 1.0
|
||||
}
|
||||
|
|
99
train.py
99
train.py
|
@ -125,12 +125,12 @@ def setup_local_logger(args):
|
|||
|
||||
return datetimestamp
|
||||
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
|
||||
"""
|
||||
logs the optimizer settings
|
||||
"""
|
||||
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
|
||||
|
||||
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
|
@ -363,7 +363,9 @@ def main(args):
|
|||
else:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
||||
if args.seed == -1:
|
||||
args.seed = random.randint(0, 2**30)
|
||||
seed = args.seed
|
||||
logging.info(f" Seed: {seed}")
|
||||
set_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
|
@ -477,16 +479,6 @@ def main(args):
|
|||
else:
|
||||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
||||
|
||||
if args.disable_textenc_training:
|
||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters())
|
||||
elif args.disable_unet_training:
|
||||
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(text_encoder.parameters())
|
||||
else:
|
||||
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
|
||||
optimizer_config = None
|
||||
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
|
||||
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
|
||||
|
@ -514,6 +506,7 @@ def main(args):
|
|||
|
||||
default_lr = 1e-6
|
||||
curr_lr = args.lr
|
||||
text_encoder_lr_scale = 1.0
|
||||
|
||||
if optimizer_config is not None:
|
||||
betas = optimizer_config["betas"]
|
||||
|
@ -524,12 +517,33 @@ def main(args):
|
|||
if args.lr is not None:
|
||||
curr_lr = args.lr
|
||||
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
|
||||
|
||||
text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
|
||||
if text_encoder_lr_scale != 1.0:
|
||||
logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
|
||||
|
||||
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
|
||||
|
||||
if curr_lr is None:
|
||||
curr_lr = default_lr
|
||||
logging.warning(f"No LR setting found, defaulting to {default_lr}")
|
||||
|
||||
curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
|
||||
|
||||
if args.disable_textenc_training:
|
||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters())
|
||||
elif args.disable_unet_training:
|
||||
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
|
||||
if text_encoder_lr_scale != 1:
|
||||
logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the "
|
||||
f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(text_encoder.parameters())
|
||||
else:
|
||||
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||
params_to_train = [{'params': unet.parameters()},
|
||||
{'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
|
||||
|
||||
if optimizer_name:
|
||||
if optimizer_name == "lion":
|
||||
from lion_pytorch import Lion
|
||||
|
@ -556,7 +570,7 @@ def main(args):
|
|||
amsgrad=False,
|
||||
)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||
|
||||
image_train_items = resolve_image_train_items(args, log_folder)
|
||||
|
||||
|
@ -609,6 +623,7 @@ def main(args):
|
|||
default_resolution=args.resolution, default_seed=args.seed,
|
||||
config_file_path=args.sample_prompts,
|
||||
batch_size=max(1,args.batch_size//2),
|
||||
default_sample_steps=args.sample_steps,
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers)
|
||||
|
||||
"""
|
||||
|
@ -681,7 +696,7 @@ def main(args):
|
|||
)
|
||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
||||
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
|
||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||
epoch_times = []
|
||||
|
||||
|
@ -742,11 +757,37 @@ def main(args):
|
|||
|
||||
return model_pred, target
|
||||
|
||||
def generate_samples(global_step: int, batch):
|
||||
with isolate_rng():
|
||||
prev_sample_steps = sample_generator.sample_steps
|
||||
sample_generator.reload_config()
|
||||
if prev_sample_steps != sample_generator.sample_steps:
|
||||
next_sample_step = math.ceil((global_step + 1) / sample_generator.sample_steps) * sample_generator.sample_steps
|
||||
print(f" * SampleGenerator config changed, now generating images samples every " +
|
||||
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=reference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
del inference_pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
validator.do_validation_if_appropriate(epoch=0, global_step=0,
|
||||
get_model_prediction_and_target_callable=get_model_prediction_and_target)
|
||||
|
||||
# the sample generator might be configured to generate samples before step 0
|
||||
if sample_generator.generate_pretrain_samples:
|
||||
_, batch = next(enumerate(train_dataloader))
|
||||
generate_samples(global_step=0, batch=batch)
|
||||
|
||||
try:
|
||||
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
||||
|
||||
|
@ -756,7 +797,7 @@ def main(args):
|
|||
images_per_sec_log_step = []
|
||||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
steps_pbar = tqdm(range(epoch_len), position=1)
|
||||
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
@ -802,8 +843,13 @@ def main(args):
|
|||
curr_lr = lr_scheduler.get_last_lr()[0]
|
||||
loss_local = sum(loss_log_step) / len(loss_log_step)
|
||||
loss_log_step = []
|
||||
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
|
||||
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
|
||||
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
|
||||
if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1:
|
||||
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
|
||||
else:
|
||||
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
|
||||
curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
|
||||
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
|
||||
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
|
||||
sum_img = sum(images_per_sec_log_step)
|
||||
avg = sum_img / len(images_per_sec_log_step)
|
||||
|
@ -814,21 +860,8 @@ def main(args):
|
|||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if (global_step + 1) % args.sample_steps == 0:
|
||||
with isolate_rng():
|
||||
sample_generator.reload_config()
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=reference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
del inference_pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if (global_step + 1) % sample_generator.sample_steps == 0:
|
||||
generate_samples(global_step=global_step, batch=batch)
|
||||
|
||||
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistep
|
|||
from torch.cuda.amp import autocast
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def clean_filename(filename):
|
||||
|
@ -52,6 +53,15 @@ def chunk_list(l: list, batch_size: int,
|
|||
yield b[i:i + batch_size]
|
||||
|
||||
|
||||
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
|
||||
sizes = []
|
||||
target_pixel_count = default_resolution * default_resolution
|
||||
for w in range(256, 1024, 64):
|
||||
for h in range(256, 1024, 64):
|
||||
if abs((w * h) - target_pixel_count) <= 128 * 64:
|
||||
sizes.append((w, h))
|
||||
best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
|
||||
return best_size
|
||||
|
||||
|
||||
class SampleGenerator:
|
||||
|
@ -73,6 +83,7 @@ class SampleGenerator:
|
|||
config_file_path: str,
|
||||
batch_size: int,
|
||||
default_seed: int,
|
||||
default_sample_steps: int,
|
||||
use_xformers: bool):
|
||||
self.log_folder = log_folder
|
||||
self.log_writer = log_writer
|
||||
|
@ -80,12 +91,15 @@ class SampleGenerator:
|
|||
self.config_file_path = config_file_path
|
||||
self.use_xformers = use_xformers
|
||||
self.show_progress_bars = False
|
||||
self.generate_pretrain_samples = False
|
||||
|
||||
self.default_resolution = default_resolution
|
||||
self.default_seed = default_seed
|
||||
self.sample_steps = default_sample_steps
|
||||
|
||||
self.sample_requests = None
|
||||
self.reload_config()
|
||||
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
|
||||
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
|
||||
if not os.path.exists(f"{log_folder}/samples/"):
|
||||
os.makedirs(f"{log_folder}/samples/")
|
||||
|
||||
|
@ -102,8 +116,12 @@ class SampleGenerator:
|
|||
logging.warning(
|
||||
f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
|
||||
logging.warning(
|
||||
f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
|
||||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
|
||||
)
|
||||
if self.sample_requests == None:
|
||||
logging.warning(
|
||||
f" Will generate samples from random training image captions until the problem is fixed.")
|
||||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
|
||||
def update_random_captions(self, possible_captions: list[str]):
|
||||
random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
|
||||
|
@ -139,18 +157,20 @@ class SampleGenerator:
|
|||
self.scheduler = config.get('scheduler', self.scheduler)
|
||||
self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
|
||||
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
|
||||
sample_requests_json = config.get('samples', None)
|
||||
if sample_requests_json is None:
|
||||
self.sample_requests = []
|
||||
self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
|
||||
self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps)
|
||||
sample_requests_config = config.get('samples', None)
|
||||
if sample_requests_config is None:
|
||||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
else:
|
||||
default_seed = config.get('seed', self.default_seed)
|
||||
default_size = (self.default_resolution, self.default_resolution)
|
||||
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
|
||||
negative_prompt=p.get('negative_prompt', ''),
|
||||
seed=p.get('seed', default_seed),
|
||||
size=tuple(p.get('size', default_size)),
|
||||
size=tuple(p.get('size', None) or
|
||||
get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
|
||||
wants_random_caption=p.get('random_caption', False)
|
||||
) for p in sample_requests_json]
|
||||
) for p in sample_requests_config]
|
||||
if len(self.sample_requests) == 0:
|
||||
self._make_random_caption_sample_requests()
|
||||
|
||||
|
@ -159,23 +179,26 @@ class SampleGenerator:
|
|||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
|
||||
|
||||
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
|
||||
disable_progress_bars = not self.show_progress_bars
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype(font="arial.ttf", size=20)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
if not self.show_progress_bars:
|
||||
print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts")
|
||||
|
||||
sample_index = 0
|
||||
with autocast():
|
||||
batch: list[SampleRequest]
|
||||
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
|
||||
return a.size == b.size
|
||||
for batch in chunk_list(self.sample_requests, self.batch_size,
|
||||
compatibility_test=sample_compatibility_test):
|
||||
#print("batch: ", batch)
|
||||
batches = list(chunk_list(self.sample_requests, self.batch_size,
|
||||
compatibility_test=sample_compatibility_test))
|
||||
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
|
||||
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
||||
for batch in batches:
|
||||
prompts = [p.prompt for p in batch]
|
||||
negative_prompts = [p.negative_prompt for p in batch]
|
||||
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
||||
|
@ -186,6 +209,8 @@ class SampleGenerator:
|
|||
|
||||
batch_images = []
|
||||
for cfg in self.cfgs:
|
||||
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
|
||||
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
|
||||
images = pipe(prompt=prompts,
|
||||
negative_prompt=negative_prompts,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
|
@ -247,6 +272,7 @@ class SampleGenerator:
|
|||
del tfimage
|
||||
del batch_images
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
|
||||
|
|
Loading…
Reference in New Issue