Merge pull request #75 from damian0815/sample_generation_refactor_redo
Refactor sample generation and introduce sample_prompts.json
This commit is contained in:
commit
c071230a49
|
@ -6,15 +6,48 @@ Everydream2 uses the Tensorboard library to log performance metrics. (more opti
|
|||
|
||||
You should launch tensorboard while your training is running and watch along.
|
||||
|
||||
tensorboard --logdir logs
|
||||
tensorboard --logdir logs --samples_per_plugin images=100
|
||||
|
||||
## Sample images
|
||||
|
||||
The trainer produces sample images from sample_prompts.txt with a fixed seed every so many steps as defined by your sample_steps argument. These are saved in the logs directory and can be viewed in tensorboard as well if you prefer. If you have a ton of them, the slider bar in tensorboard may not select them all, but the actual files are still stored in your logs as well for review.
|
||||
By default, the trainer produces sample images from `sample_prompts.txt` with a fixed seed every so many steps as defined by your `sample_steps` argument. These are saved in the logs directory and can be viewed in tensorboard as well if you prefer. If you have a ton of them, the slider bar in tensorboard may not select them all (unless you launch tensorboard with the `--samples_per_plugin` argument as shown above), but the actual files are still stored in your logs as well for review.
|
||||
|
||||
Samples are produced at CFG scales of 1, 4, and 7. You may find this very useful to see how your model is progressing.
|
||||
Samples are produced at CFG scales of 1, 4, and 7. You may find this very useful to see how your model is progressing.
|
||||
|
||||
If your sample_prompts.txt is empty, the trainer will generate from prompts from the last batch of your training data, up to 4 sets of samples.
|
||||
If your `sample_prompts.txt` is empty, the trainer will generate from prompts from the last batch of your training data, up to 4 sets of samples.
|
||||
|
||||
### More control
|
||||
|
||||
In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, which offers more control over sample generation. Here is an example `sample_prompts.json` file:
|
||||
|
||||
```json
|
||||
{
|
||||
"batch_size": 3,
|
||||
"seed": 555,
|
||||
"cfgs": [7, 4],
|
||||
"scheduler": "dpm++",
|
||||
"num_inference_steps": 15,
|
||||
"show_progress_bars": true,
|
||||
"samples": [
|
||||
{
|
||||
"prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background",
|
||||
"negative_prompt": "distorted, deformed"
|
||||
},
|
||||
{
|
||||
"prompt": "a photograph of ted bennet riding a bicycle",
|
||||
"seed": -1
|
||||
},
|
||||
{
|
||||
"random_caption": true,
|
||||
"size": [640, 384]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
At the top you can set a `batch_size` (subject to VRAM limits), a default `seed` and `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process.
|
||||
|
||||
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time.
|
||||
|
||||
## LR
|
||||
|
||||
|
|
156
train.py
156
train.py
|
@ -31,17 +31,16 @@ import importlib
|
|||
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from colorama import Fore, Style, Cursor
|
||||
from colorama import Fore, Style
|
||||
import numpy as np
|
||||
import itertools
|
||||
import torch
|
||||
import datetime
|
||||
import json
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler
|
||||
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, \
|
||||
DPMSolverMultistepScheduler
|
||||
#from diffusers.models import AttentionBlock
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
@ -58,20 +57,17 @@ from data.every_dream_validation import EveryDreamValidator
|
|||
from data.image_train_item import ImageTrainItem
|
||||
from utils.huggingface_downloader import try_download_model_from_hf
|
||||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
import data.aspects as aspects
|
||||
import data.resolver as resolver
|
||||
from utils.sample_generator import SampleGenerator
|
||||
|
||||
_SIGTERM_EXIT_CODE = 130
|
||||
_VERY_LARGE_NUMBER = 1e9
|
||||
|
||||
def clean_filename(filename):
|
||||
"""
|
||||
removes all non-alphanumeric characters from a string so it is safe to use as a filename
|
||||
"""
|
||||
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
|
||||
|
||||
def get_hf_ckpt_cache_path(ckpt_path):
|
||||
return os.path.join("ckpt_cache", os.path.basename(ckpt_path))
|
||||
|
||||
|
@ -425,108 +421,6 @@ def main(args):
|
|||
# logging.info(f" Saving optimizer state to {save_path}")
|
||||
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
|
||||
|
||||
@torch.no_grad()
|
||||
def __create_inference_pipe(unet, text_encoder, tokenizer, scheduler, vae):
|
||||
"""
|
||||
creates a pipeline for SD inference
|
||||
"""
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
|
||||
return pipe
|
||||
|
||||
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):
|
||||
"""
|
||||
generates a single sample at a given cfg scale and saves it to disk
|
||||
"""
|
||||
with torch.no_grad(), autocast():
|
||||
image = pipe(prompt,
|
||||
num_inference_steps=30,
|
||||
num_images_per_prompt=1,
|
||||
guidance_scale=cfg,
|
||||
generator=gen,
|
||||
height=resolution,
|
||||
width=resolution,
|
||||
).images[0]
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
try:
|
||||
font = ImageFont.truetype(font="arial.ttf", size=20)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
print_msg = f"cfg:{cfg:.1f}"
|
||||
|
||||
l, t, r, b = draw.textbbox(xy=(0,0), text=print_msg, font=font)
|
||||
text_width = r - l
|
||||
text_height = b - t
|
||||
|
||||
x = float(image.width - text_width - 10)
|
||||
y = float(image.height - text_height - 10)
|
||||
|
||||
draw.rectangle((x, y, image.width, image.height), fill="white")
|
||||
draw.text((x, y), print_msg, fill="black", font=font)
|
||||
del draw, font
|
||||
return image
|
||||
|
||||
def __generate_test_samples(pipe, prompts, gs, log_writer, log_folder, random_captions=False, resolution=512):
|
||||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
logging.info(f"Generating samples gs:{gs}, for {prompts}")
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
|
||||
gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
i = 0
|
||||
for prompt in prompts:
|
||||
if prompt is None or len(prompt) < 2:
|
||||
#logging.warning("empty prompt in sample prompts, check your prompts file")
|
||||
continue
|
||||
images = []
|
||||
for cfg in [7.0, 4.0, 1.01]:
|
||||
image = __generate_sample(pipe, prompt, cfg, resolution=resolution, gen=gen)
|
||||
images.append(image)
|
||||
|
||||
width = 0
|
||||
height = 0
|
||||
for image in images:
|
||||
width += image.width
|
||||
height = max(height, image.height)
|
||||
|
||||
result = Image.new('RGB', (width, height))
|
||||
|
||||
x_offset = 0
|
||||
for image in images:
|
||||
result.paste(image, (x_offset, 0))
|
||||
x_offset += image.width
|
||||
|
||||
clean_prompt = clean_filename(prompt)
|
||||
|
||||
result.save(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
f.write(prompt)
|
||||
f.write(f"\n seed: {seed}")
|
||||
|
||||
tfimage = transforms.ToTensor()(result)
|
||||
if random_captions:
|
||||
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
|
||||
else:
|
||||
log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
|
||||
i += 1
|
||||
|
||||
del result
|
||||
del tfimage
|
||||
del images
|
||||
|
||||
try:
|
||||
|
||||
# check for a local file
|
||||
|
@ -545,7 +439,7 @@ def main(args):
|
|||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||
sample_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
|
||||
except Exception as e:
|
||||
|
@ -679,6 +573,11 @@ def main(args):
|
|||
|
||||
log_args(log_writer, args)
|
||||
|
||||
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
|
||||
default_resolution=args.resolution, default_seed=args.seed,
|
||||
config_file_path=args.sample_prompts,
|
||||
batch_size=max(1,args.batch_size//2),
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers)
|
||||
|
||||
"""
|
||||
Train the model
|
||||
|
@ -832,6 +731,7 @@ def main(args):
|
|||
# # discard the grads, just want to pin memory
|
||||
# optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
|
||||
write_batch_schedule(args, log_folder, train_batch, 0)
|
||||
|
||||
for epoch in range(args.max_epochs):
|
||||
|
@ -901,19 +801,18 @@ def main(args):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
if (global_step + 1) % args.sample_steps == 0:
|
||||
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
|
||||
pipe = pipe.to(device)
|
||||
with isolate_rng():
|
||||
sample_generator.reload_config()
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=reference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
with torch.no_grad():
|
||||
sample_prompts = read_sample_prompts(args.sample_prompts)
|
||||
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
|
||||
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
|
||||
else:
|
||||
max_prompts = min(4,len(batch["captions"]))
|
||||
prompts=batch["captions"][:max_prompts]
|
||||
__generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution)
|
||||
|
||||
del pipe
|
||||
del inference_pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -967,8 +866,9 @@ def main(args):
|
|||
|
||||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX} ^^ NO not doing that.{Style.RESET_ALL}")
|
||||
#save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
#__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
|
||||
raise ex
|
||||
|
||||
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
|
||||
|
@ -1020,7 +920,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
||||
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
||||
argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
|
||||
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
|
||||
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt)")
|
||||
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
||||
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
||||
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
|
||||
|
|
|
@ -0,0 +1,284 @@
|
|||
import json
|
||||
import logging
|
||||
import os.path
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
from typing import Generator, Callable, Any
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from colorama import Fore, Style
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def clean_filename(filename):
|
||||
"""
|
||||
removes all non-alphanumeric characters from a string so it is safe to use as a filename
|
||||
"""
|
||||
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
|
||||
|
||||
@dataclass
|
||||
class SampleRequest:
|
||||
prompt: str
|
||||
negative_prompt: str
|
||||
seed: int
|
||||
size: tuple[int,int]
|
||||
wants_random_caption: bool = False
|
||||
|
||||
def __str__(self):
|
||||
rep = self.prompt
|
||||
if len(self.negative_prompt) > 0:
|
||||
rep += "\n negative prompt: {self.negative_prompt}"
|
||||
rep += f"\n seed: {self.seed}"
|
||||
return rep
|
||||
|
||||
|
||||
def chunk_list(l: list, batch_size: int,
|
||||
compatibility_test: Callable[[Any,Any], bool]=lambda x,y: True
|
||||
) -> Generator[list, None, None]:
|
||||
buckets = []
|
||||
for item in l:
|
||||
compatible_bucket = next((b for b in buckets if compatibility_test(item, b[0])), None)
|
||||
if compatible_bucket is not None:
|
||||
compatible_bucket.append(item)
|
||||
else:
|
||||
buckets.append([item])
|
||||
|
||||
for b in buckets:
|
||||
for i in range(0, len(b), batch_size):
|
||||
yield b[i:i + batch_size]
|
||||
|
||||
|
||||
|
||||
|
||||
class SampleGenerator:
|
||||
seed: int
|
||||
default_resolution: int
|
||||
cfgs: list[float] = [7, 4, 1.01]
|
||||
scheduler: str = 'ddim'
|
||||
num_inference_steps: int = 30
|
||||
random_captions = False
|
||||
|
||||
sample_requests: [str]
|
||||
log_folder: str
|
||||
log_writer: SummaryWriter
|
||||
|
||||
def __init__(self,
|
||||
log_folder: str,
|
||||
log_writer: SummaryWriter,
|
||||
default_resolution: int,
|
||||
config_file_path: str,
|
||||
batch_size: int,
|
||||
default_seed: int,
|
||||
use_xformers: bool):
|
||||
self.log_folder = log_folder
|
||||
self.log_writer = log_writer
|
||||
self.batch_size = batch_size
|
||||
self.config_file_path = config_file_path
|
||||
self.use_xformers = use_xformers
|
||||
self.show_progress_bars = False
|
||||
|
||||
self.default_resolution = default_resolution
|
||||
self.default_seed = default_seed
|
||||
|
||||
self.reload_config()
|
||||
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
|
||||
if not os.path.exists(f"{log_folder}/samples/"):
|
||||
os.makedirs(f"{log_folder}/samples/")
|
||||
|
||||
def reload_config(self):
|
||||
try:
|
||||
config_file_extension = os.path.splitext(self.config_file_path)[1].lower()
|
||||
if config_file_extension == '.txt':
|
||||
self._reload_sample_prompts_txt(self.config_file_path)
|
||||
elif config_file_extension == '.json':
|
||||
self._reload_config_json(self.config_file_path)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized file type '{config_file_extension}' for sample config, must be .txt or .json")
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
|
||||
logging.warning(
|
||||
f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
|
||||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
|
||||
def update_random_captions(self, possible_captions: list[str]):
|
||||
random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
|
||||
for i, r in enumerate(random_prompt_sample_requests):
|
||||
r.prompt = possible_captions[i % len(possible_captions)]
|
||||
|
||||
def _reload_sample_prompts_txt(self, path):
|
||||
with open(path, 'rt') as f:
|
||||
self.sample_requests = [SampleRequest(prompt=line.strip(),
|
||||
negative_prompt='',
|
||||
seed=self.default_seed,
|
||||
size=(self.default_resolution, self.default_resolution)
|
||||
) for line in f]
|
||||
if len(self.sample_requests) == 0:
|
||||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
|
||||
def _make_random_caption_sample_requests(self):
|
||||
num_random_captions = min(4, self.batch_size)
|
||||
return [SampleRequest(prompt='',
|
||||
negative_prompt='',
|
||||
seed=self.default_seed,
|
||||
size=(self.default_resolution, self.default_resolution),
|
||||
wants_random_caption=True)
|
||||
for _ in range(num_random_captions)]
|
||||
|
||||
def _reload_config_json(self, path):
|
||||
with open(path, 'rt') as f:
|
||||
config = json.load(f)
|
||||
# if keys are missing, keep current values
|
||||
self.default_resolution = config.get('resolution', self.default_resolution)
|
||||
self.cfgs = config.get('cfgs', self.cfgs)
|
||||
self.batch_size = config.get('batch_size', self.batch_size)
|
||||
self.scheduler = config.get('scheduler', self.scheduler)
|
||||
self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
|
||||
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
|
||||
sample_requests_json = config.get('samples', None)
|
||||
if sample_requests_json is None:
|
||||
self.sample_requests = []
|
||||
else:
|
||||
default_seed = config.get('seed', self.default_seed)
|
||||
default_size = (self.default_resolution, self.default_resolution)
|
||||
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
|
||||
negative_prompt=p.get('negative_prompt', ''),
|
||||
seed=p.get('seed', default_seed),
|
||||
size=tuple(p.get('size', default_size)),
|
||||
wants_random_caption=p.get('random_caption', False)
|
||||
) for p in sample_requests_json]
|
||||
if len(self.sample_requests) == 0:
|
||||
self._make_random_caption_sample_requests()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
||||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
|
||||
|
||||
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype(font="arial.ttf", size=20)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
sample_index = 0
|
||||
with autocast():
|
||||
batch: list[SampleRequest]
|
||||
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
|
||||
return a.size == b.size
|
||||
for batch in chunk_list(self.sample_requests, self.batch_size,
|
||||
compatibility_test=sample_compatibility_test):
|
||||
#print("batch: ", batch)
|
||||
prompts = [p.prompt for p in batch]
|
||||
negative_prompts = [p.negative_prompt for p in batch]
|
||||
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
|
||||
for p in batch]
|
||||
# all sizes in a batch are the same
|
||||
size = batch[0].size
|
||||
generators = [torch.Generator(pipe.device).manual_seed(seed) for seed in seeds]
|
||||
|
||||
batch_images = []
|
||||
for cfg in self.cfgs:
|
||||
images = pipe(prompt=prompts,
|
||||
negative_prompt=negative_prompts,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
num_images_per_prompt=1,
|
||||
guidance_scale=cfg,
|
||||
generator=generators,
|
||||
width=size[0],
|
||||
height=size[1],
|
||||
).images
|
||||
|
||||
for image in images:
|
||||
draw = ImageDraw.Draw(image)
|
||||
print_msg = f"cfg:{cfg:.1f}"
|
||||
|
||||
l, t, r, b = draw.textbbox(xy=(0, 0), text=print_msg, font=font)
|
||||
text_width = r - l
|
||||
text_height = b - t
|
||||
|
||||
x = float(image.width - text_width - 10)
|
||||
y = float(image.height - text_height - 10)
|
||||
|
||||
draw.rectangle((x, y, image.width, image.height), fill="white")
|
||||
draw.text((x, y), print_msg, fill="black", font=font)
|
||||
|
||||
batch_images.append(images)
|
||||
del images
|
||||
|
||||
del generators
|
||||
#print("batch_images:", batch_images)
|
||||
|
||||
width = size[0] * len(self.cfgs)
|
||||
height = size[1]
|
||||
|
||||
for prompt_idx in range(len(batch)):
|
||||
#print(f"batch_images[:][{prompt_idx}]: {batch_images[:][prompt_idx]}")
|
||||
result = Image.new('RGB', (width, height))
|
||||
x_offset = 0
|
||||
|
||||
for cfg_idx in range(len(self.cfgs)):
|
||||
image = batch_images[cfg_idx][prompt_idx]
|
||||
result.paste(image, (x_offset, 0))
|
||||
x_offset += image.width
|
||||
|
||||
prompt = prompts[prompt_idx]
|
||||
clean_prompt = clean_filename(prompt)
|
||||
|
||||
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
f.write(str(batch[prompt_idx]))
|
||||
|
||||
tfimage = transforms.ToTensor()(result)
|
||||
if batch[prompt_idx].wants_random_caption:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
|
||||
else:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||
sample_index += 1
|
||||
|
||||
del result
|
||||
del tfimage
|
||||
del batch_images
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
|
||||
"""
|
||||
creates a pipeline for SD inference
|
||||
"""
|
||||
scheduler = self._create_scheduler(diffusers_scheduler_config)
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be None if no safety checker
|
||||
)
|
||||
if self.use_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
return pipe
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _create_scheduler(self, scheduler_config: dict):
|
||||
scheduler = self.scheduler
|
||||
if scheduler not in ['ddim', 'dpm++']:
|
||||
print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
|
||||
scheduler = 'ddim'
|
||||
|
||||
if scheduler == 'ddim':
|
||||
return DDIMScheduler.from_config(scheduler_config)
|
||||
elif scheduler == 'dpm++':
|
||||
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++")
|
||||
else:
|
||||
raise ValueError(f"unknown scheduler '{scheduler}'")
|
Loading…
Reference in New Issue