# EveryDream2trainer/utils/sample_generator.py
#
# 323 lines
# 15 KiB
# Python

import json
import logging
import os.path
from dataclasses import dataclass
import random
from typing import Generator, Callable, Any
import torch
from PIL import Image, ImageDraw, ImageFont
from colorama import Fore, Style
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm.auto import tqdm
def clean_filename(filename):
    """
    Reduce a string to letters, digits and plain spaces so it can be embedded in a filename.

    Characters for which neither ``str.isalpha`` nor ``str.isdigit`` holds are dropped,
    except the ASCII space ``' '``. Trailing whitespace is trimmed from the result.
    """
    kept = (ch for ch in filename if ch == ' ' or ch.isalpha() or ch.isdigit())
    return "".join(kept).rstrip()
@dataclass
class SampleRequest:
    """One sample image to render: prompt, negative prompt, seed and (width, height)."""
    prompt: str
    negative_prompt: str
    # -1 asks the generator to pick a random seed at render time
    seed: int
    size: tuple[int, int]
    # when True the prompt is filled in from training captions by the sample generator
    wants_random_caption: bool = False

    def __str__(self):
        lines = [self.prompt]
        if len(self.negative_prompt):
            lines.append(f" negative prompt: {self.negative_prompt}")
        lines.append(f" seed: {self.seed}")
        return "\n".join(lines)
def chunk_list(l: list, batch_size: int,
               compatibility_test: Callable[[Any, Any], bool] = lambda x, y: True
               ) -> Generator[list, None, None]:
    """
    Partition ``l`` into groups of mutually compatible items, then yield each group in
    slices of at most ``batch_size`` items.

    An item joins the first existing group whose first member satisfies
    ``compatibility_test(item, first_member)``; otherwise it starts a new group.
    Groups are emitted in creation order and keep the items' original relative order.
    """
    groups: list[list] = []
    for item in l:
        for group in groups:
            if compatibility_test(item, group[0]):
                group.append(item)
                break
        else:
            groups.append([item])

    for group in groups:
        yield from (group[start:start + batch_size] for start in range(0, len(group), batch_size))
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
    """
    Pick a (width, height) with roughly ``default_resolution ** 2`` pixels whose
    width/height ratio is as close as possible to ``aspect_ratio``.

    Candidate sides are multiples of 64 from 256 to 960 inclusive. A candidate qualifies
    when its pixel count is within 128*64 of the target. Among qualifying candidates,
    the one whose ratio is closest to ``aspect_ratio`` wins; ties go to the first
    candidate in width-major order.

    If no candidate qualifies, the size is computed directly instead of raising: width
    is about R*sqrt(ar) and height about R/sqrt(ar), each rounded to the nearest
    multiple of 64 (minimum 64). This happens, for example, when default_resolution is
    1024 or more, because the fixed window cannot reach that pixel count. A
    non-positive aspect ratio is treated as 1 in this fallback.
    """
    import math

    target_pixel_count = default_resolution * default_resolution
    sizes = [(w, h)
             for w in range(256, 1024, 64)
             for h in range(256, 1024, 64)
             if abs((w * h) - target_pixel_count) <= 128 * 64]
    if sizes:
        return min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))

    # Nothing in the fixed window hits the target area (min() would raise on an empty
    # list), so derive the size from the area and the requested ratio.
    ratio = aspect_ratio if aspect_ratio > 0 else 1.0
    root = math.sqrt(ratio)
    width = max(64, round(default_resolution * root / 64) * 64)
    height = max(64, round(default_resolution / root / 64) * 64)
    return width, height
class SampleGenerator:
    """
    Renders sample images from a diffusion pipeline during training.

    Sample prompts come from a .txt file (one prompt per line) or a .json config file.
    reload_config() re-reads that file, so edits made while training runs can be picked
    up. Each sample is rendered once per guidance scale in ``cfgs``. The renders are
    pasted side by side into one JPEG under ``<log_folder>/samples/`` and also logged
    to the SummaryWriter.
    """
    seed: int
    default_resolution: int
    # guidance scales rendered for every sample, left to right in the output image
    cfgs: list[float] = [7, 4, 1.01]
    # name understood by _create_scheduler
    scheduler: str = 'ddim'
    num_inference_steps: int = 30
    random_captions = False
    sample_requests: list[SampleRequest]
    log_folder: str
    log_writer: SummaryWriter

    def __init__(self,
                 log_folder: str,
                 log_writer: SummaryWriter,
                 default_resolution: int,
                 config_file_path: str,
                 batch_size: int,
                 default_seed: int,
                 default_sample_steps: int,
                 use_xformers: bool):
        """
        :param log_folder: root folder; sample images and their .txt descriptions are written to ``<log_folder>/samples/``
        :param log_writer: TensorBoard writer that receives every sample image
        :param default_resolution: side length of square samples when no size is configured
        :param config_file_path: path to a .txt or .json sample config
        :param batch_size: maximum number of prompts per pipeline call
        :param default_seed: seed used for samples that do not specify one
        :param default_sample_steps: initial ``sample_steps`` value (json key ``generate_samples_every_n_steps`` overrides it)
        :param use_xformers: enable xformers attention on pipelines built by create_inference_pipe
        """
        self.log_folder = log_folder
        self.log_writer = log_writer
        self.batch_size = batch_size
        self.config_file_path = config_file_path
        self.use_xformers = use_xformers
        self.show_progress_bars = False
        self.generate_pretrain_samples = False
        self.default_resolution = default_resolution
        self.default_seed = default_seed
        self.sample_steps = default_sample_steps
        # None marks "nothing loaded yet" so reload_config knows to install a fallback
        self.sample_requests = None
        self.reload_config()
        print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
        os.makedirs(f"{log_folder}/samples/", exist_ok=True)

    def reload_config(self):
        """
        Re-read the sample config at ``self.config_file_path`` (.txt or .json).

        Any error is logged instead of raised, and the previously loaded requests are
        kept. If nothing has been loaded yet, random-caption sample requests are
        installed as a fallback.
        """
        try:
            config_file_extension = os.path.splitext(self.config_file_path)[1].lower()
            if config_file_extension == '.txt':
                self._reload_sample_prompts_txt(self.config_file_path)
            elif config_file_extension == '.json':
                self._reload_config_json(self.config_file_path)
            else:
                raise ValueError(f"Unrecognized file type '{config_file_extension}' for sample config, must be .txt or .json")
        except Exception as e:
            logging.warning(
                f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
            logging.warning(
                f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
            )
            if self.sample_requests is None:
                logging.warning(
                    f" Will generate samples from random training image captions until the problem is fixed.")
                self.sample_requests = self._make_random_caption_sample_requests()

    def update_random_captions(self, possible_captions: list[str]):
        """
        Give every request flagged ``wants_random_caption`` a prompt from ``possible_captions``.

        Captions are assigned in order and cycle when there are more such requests than
        captions. If ``possible_captions`` is empty, prompts are left unchanged
        (previously this raised ZeroDivisionError).
        """
        if not possible_captions:
            return
        random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
        for i, r in enumerate(random_prompt_sample_requests):
            r.prompt = possible_captions[i % len(possible_captions)]

    def _reload_sample_prompts_txt(self, path):
        """
        Load one square, default-seeded sample request per non-blank line of a text file.

        Uses random-caption requests if the file contains no prompts.
        """
        # utf-8-sig also accepts files saved with a byte-order mark
        with open(path, 'rt', encoding='utf-8-sig') as f:
            prompts = [line.strip() for line in f]
        self.sample_requests = [SampleRequest(prompt=prompt,
                                              negative_prompt='',
                                              seed=self.default_seed,
                                              size=(self.default_resolution, self.default_resolution))
                                for prompt in prompts if prompt]  # blank lines are not prompts
        if len(self.sample_requests) == 0:
            self.sample_requests = self._make_random_caption_sample_requests()

    def _make_random_caption_sample_requests(self):
        """
        Build ``min(4, batch_size)`` square, default-seeded requests with empty prompts.

        The prompts are meant to be filled in by update_random_captions.
        """
        num_random_captions = min(4, self.batch_size)
        return [SampleRequest(prompt='',
                              negative_prompt='',
                              seed=self.default_seed,
                              size=(self.default_resolution, self.default_resolution),
                              wants_random_caption=True)
                for _ in range(num_random_captions)]

    def _reload_config_json(self, path):
        """
        Load generator settings and sample requests from a JSON config.

        Missing top-level keys keep their current values. A missing or empty
        ``samples`` list falls back to random-caption requests. Each sample may set
        ``prompt``, ``negative_prompt``, ``seed`` (default: top-level ``seed``), ``size``
        or ``aspect_ratio``, and ``random_caption``.
        """
        with open(path, 'rt', encoding='utf-8-sig') as f:
            config = json.load(f)
        # if keys are missing, keep current values
        self.default_resolution = config.get('resolution', self.default_resolution)
        self.cfgs = config.get('cfgs', self.cfgs)
        self.batch_size = config.get('batch_size', self.batch_size)
        self.scheduler = config.get('scheduler', self.scheduler)
        self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
        self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
        self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
        self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps)
        sample_requests_config = config.get('samples', None)
        if sample_requests_config is None:
            self.sample_requests = self._make_random_caption_sample_requests()
        else:
            default_seed = config.get('seed', self.default_seed)
            self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
                                                  negative_prompt=p.get('negative_prompt', ''),
                                                  seed=p.get('seed', default_seed),
                                                  size=tuple(p.get('size', None) or
                                                             get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
                                                  wants_random_caption=p.get('random_caption', False)
                                                  ) for p in sample_requests_config]
            if len(self.sample_requests) == 0:
                # previously the fallback was computed but discarded, leaving zero samples
                self.sample_requests = self._make_random_caption_sample_requests()

    @staticmethod
    def _draw_cfg_label(image, cfg: float, font) -> None:
        """Stamp ``cfg:<scale>`` in black on a white box in the bottom-right corner of ``image``, in place."""
        draw = ImageDraw.Draw(image)
        print_msg = f"cfg:{cfg:.1f}"
        l, t, r, b = draw.textbbox(xy=(0, 0), text=print_msg, font=font)
        text_width = r - l
        text_height = b - t
        x = float(image.width - text_width - 10)
        y = float(image.height - text_height - 10)
        draw.rectangle((x, y, image.width, image.height), fill="white")
        draw.text((x, y), print_msg, fill="black", font=font)

    @torch.no_grad()
    def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
        """
        generates samples at different cfg scales and saves them to disk

        Requests are grouped into batches of at most ``batch_size`` prompts with equal
        output size. Each batch is rendered once per scale in ``self.cfgs``. Every
        render uses generators freshly seeded with the same per-prompt seed, so the
        columns share their random noise.

        For each prompt, the renders are pasted left to right into one image and saved
        as ``<log_folder>/samples/gs<step>-<index>-<prompt>.jpg``. A matching ``.txt``
        file records the request and the seed actually used, and the image is logged
        to ``self.log_writer``.

        :param pipe: inference pipeline, e.g. one built by create_inference_pipe
        :param global_step: training step; used in file names and as the TensorBoard step
        """
        import dataclasses

        disable_progress_bars = not self.show_progress_bars

        try:
            font = ImageFont.truetype(font="arial.ttf", size=20)
        except OSError:
            # arial.ttf is not present everywhere; PIL's built-in font always is
            font = ImageFont.load_default()

        if not self.show_progress_bars:
            print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts")

        sample_index = 0
        with autocast():
            def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
                # only equally sized requests can share one pipeline call
                return a.size == b.size

            batches = list(chunk_list(self.sample_requests, self.batch_size,
                                      compatibility_test=sample_compatibility_test))
            pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
                        desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
            try:
                for batch in batches:
                    prompts = [p.prompt for p in batch]
                    negative_prompts = [p.negative_prompt for p in batch]
                    # seed -1 means random; resolve it once so all cfg columns share it
                    seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
                             for p in batch]
                    # all sizes in a batch are the same
                    size = batch[0].size

                    batch_images = []
                    for cfg in self.cfgs:
                        # Re-seed per cfg: reusing generators would continue their RNG state,
                        # so later cfg columns would not start from the recorded seed.
                        generators = [torch.Generator(pipe.device).manual_seed(seed) for seed in seeds]
                        pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
                                                     desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
                        images = pipe(prompt=prompts,
                                      negative_prompt=negative_prompts,
                                      num_inference_steps=self.num_inference_steps,
                                      num_images_per_prompt=1,
                                      guidance_scale=cfg,
                                      generator=generators,
                                      width=size[0],
                                      height=size[1],
                                      ).images
                        for image in images:
                            self._draw_cfg_label(image, cfg, font)
                        batch_images.append(images)
                        del images
                        del generators

                    width = size[0] * len(self.cfgs)
                    height = size[1]
                    for prompt_idx, request in enumerate(batch):
                        result = Image.new('RGB', (width, height))
                        x_offset = 0
                        for cfg_images in batch_images:
                            image = cfg_images[prompt_idx]
                            result.paste(image, (x_offset, 0))
                            x_offset += image.width

                        clean_prompt = clean_filename(prompts[prompt_idx])[:100]
                        base_path = f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt}"
                        result.save(f"{base_path}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
                        with open(f"{base_path}.txt", "w", encoding='utf-8') as f:
                            # record the seed actually used (differs from request.seed when that is -1)
                            f.write(str(dataclasses.replace(request, seed=seeds[prompt_idx])))

                        tfimage = transforms.ToTensor()(result)
                        if request.wants_random_caption:
                            self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
                        else:
                            self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt}", img_tensor=tfimage, global_step=global_step)
                        sample_index += 1

                        del result
                        del tfimage
                    del batch_images
                    pbar.update(1)
            finally:
                pbar.close()

    @torch.no_grad()
    def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
        """
        creates a pipeline for SD inference

        The pipeline is built from the given components, with the scheduler chosen by
        ``self.scheduler`` and the safety checker disabled. xformers attention is
        enabled when the generator was created with ``use_xformers``.
        """
        scheduler = self._create_scheduler(diffusers_scheduler_config)
        pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,  # save vram
            requires_safety_checker=False,  # avoid nag
            feature_extractor=None,  # must be None if no safety checker
        )
        if self.use_xformers:
            pipe.enable_xformers_memory_efficient_attention()
        return pipe

    @torch.no_grad()
    def _create_scheduler(self, scheduler_config: dict):
        """
        Build the diffusers scheduler named by ``self.scheduler`` from ``scheduler_config``.

        Unsupported names print a message and fall back to DDIM.
        """
        factories = {
            'ddim': lambda: DDIMScheduler.from_config(scheduler_config),
            'dpm++': lambda: DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++"),
            'pndm': lambda: PNDMScheduler.from_config(scheduler_config),
            'ddpm': lambda: DDPMScheduler.from_config(scheduler_config),
            'lms': lambda: LMSDiscreteScheduler.from_config(scheduler_config),
            'euler': lambda: EulerDiscreteScheduler.from_config(scheduler_config),
            'euler_a': lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
            'kdpm2': lambda: KDPM2AncestralDiscreteScheduler.from_config(scheduler_config),
        }
        scheduler = self.scheduler
        if scheduler not in factories:
            print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
            scheduler = 'ddim'
        return factories[scheduler]()