import json import logging import os.path from dataclasses import dataclass import random from typing import Generator, Callable, Any import torch from PIL import Image, ImageDraw, ImageFont from colorama import Fore, Style from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler from torch.cuda.amp import autocast from torch.utils.tensorboard import SummaryWriter from torchvision import transforms from tqdm.auto import tqdm def clean_filename(filename): """ removes all non-alphanumeric characters from a string so it is safe to use as a filename """ return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip() @dataclass class SampleRequest: prompt: str negative_prompt: str seed: int size: tuple[int,int] wants_random_caption: bool = False def __str__(self): rep = self.prompt if len(self.negative_prompt) > 0: rep += f"\n negative prompt: {self.negative_prompt}" rep += f"\n seed: {self.seed}" return rep def chunk_list(l: list, batch_size: int, compatibility_test: Callable[[Any,Any], bool]=lambda x,y: True ) -> Generator[list, None, None]: buckets = [] for item in l: compatible_bucket = next((b for b in buckets if compatibility_test(item, b[0])), None) if compatible_bucket is not None: compatible_bucket.append(item) else: buckets.append([item]) for b in buckets: for i in range(0, len(b), batch_size): yield b[i:i + batch_size] def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]: sizes = [] target_pixel_count = default_resolution * default_resolution for w in range(256, 1024, 64): for h in range(256, 1024, 64): if abs((w * h) - target_pixel_count) <= 128 * 64: sizes.append((w, h)) best_size = min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1])))) return best_size class SampleGenerator: seed: int default_resolution: int cfgs: list[float] = [7, 4, 1.01] scheduler: str = 'ddim' num_inference_steps: int = 30 random_captions = False sample_requests: [str] log_folder: str log_writer: SummaryWriter def __init__(self, log_folder: str, log_writer: SummaryWriter, default_resolution: int, config_file_path: str, batch_size: int, default_seed: int, default_sample_steps: int, use_xformers: bool): self.log_folder = log_folder self.log_writer = log_writer self.batch_size = batch_size self.config_file_path = config_file_path self.use_xformers = use_xformers self.show_progress_bars = False self.generate_pretrain_samples = False self.default_resolution = default_resolution self.default_seed = default_seed self.sample_steps = default_sample_steps self.sample_requests = None self.reload_config() print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps") if not os.path.exists(f"{log_folder}/samples/"): os.makedirs(f"{log_folder}/samples/") def reload_config(self): try: config_file_extension = os.path.splitext(self.config_file_path)[1].lower() if config_file_extension == '.txt': self._reload_sample_prompts_txt(self.config_file_path) elif config_file_extension == '.json': self._reload_config_json(self.config_file_path) else: raise ValueError(f"Unrecognized file type '{config_file_extension}' for sample config, must be .txt or .json") except Exception as e: logging.warning( f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}") logging.warning( f" Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated." ) if self.sample_requests == None: logging.warning( f" Will generate samples from random training image captions until the problem is fixed.") self.sample_requests = self._make_random_caption_sample_requests() def update_random_captions(self, possible_captions: list[str]): random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption] for i, r in enumerate(random_prompt_sample_requests): r.prompt = possible_captions[i % len(possible_captions)] def _reload_sample_prompts_txt(self, path): with open(path, 'rt') as f: self.sample_requests = [SampleRequest(prompt=line.strip(), negative_prompt='', seed=self.default_seed, size=(self.default_resolution, self.default_resolution) ) for line in f] if len(self.sample_requests) == 0: self.sample_requests = self._make_random_caption_sample_requests() def _make_random_caption_sample_requests(self): num_random_captions = min(4, self.batch_size) return [SampleRequest(prompt='', negative_prompt='', seed=self.default_seed, size=(self.default_resolution, self.default_resolution), wants_random_caption=True) for _ in range(num_random_captions)] def _reload_config_json(self, path): with open(path, 'rt') as f: config = json.load(f) # if keys are missing, keep current values self.default_resolution = config.get('resolution', self.default_resolution) self.cfgs = config.get('cfgs', self.cfgs) self.batch_size = config.get('batch_size', self.batch_size) self.scheduler = config.get('scheduler', self.scheduler) self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps) self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars) self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples) self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps) sample_requests_config = config.get('samples', None) if sample_requests_config is None: self.sample_requests = self._make_random_caption_sample_requests() else: default_seed = config.get('seed', self.default_seed) self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''), negative_prompt=p.get('negative_prompt', ''), seed=p.get('seed', default_seed), size=tuple(p.get('size', None) or get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)), wants_random_caption=p.get('random_caption', False) ) for p in sample_requests_config] if len(self.sample_requests) == 0: self._make_random_caption_sample_requests() @torch.no_grad() def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int): """ generates samples at different cfg scales and saves them to disk """ disable_progress_bars = not self.show_progress_bars try: font = ImageFont.truetype(font="arial.ttf", size=20) except: font = ImageFont.load_default() if not self.show_progress_bars: print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts") sample_index = 0 with autocast(): batch: list[SampleRequest] def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool: return a.size == b.size batches = list(chunk_list(self.sample_requests, self.batch_size, compatibility_test=sample_compatibility_test)) pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False, desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}") for batch in batches: prompts = [p.prompt for p in batch] negative_prompts = [p.negative_prompt for p in batch] seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30)) for p in batch] # all sizes in a batch are the same size = batch[0].size generators = [torch.Generator(pipe.device).manual_seed(seed) for seed in seeds] batch_images = [] for cfg in self.cfgs: pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False, desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}") images = pipe(prompt=prompts, negative_prompt=negative_prompts, num_inference_steps=self.num_inference_steps, num_images_per_prompt=1, guidance_scale=cfg, generator=generators, width=size[0], height=size[1], ).images for image in images: draw = ImageDraw.Draw(image) print_msg = f"cfg:{cfg:.1f}" l, t, r, b = draw.textbbox(xy=(0, 0), text=print_msg, font=font) text_width = r - l text_height = b - t x = float(image.width - text_width - 10) y = float(image.height - text_height - 10) draw.rectangle((x, y, image.width, image.height), fill="white") draw.text((x, y), print_msg, fill="black", font=font) batch_images.append(images) del images del generators #print("batch_images:", batch_images) width = size[0] * len(self.cfgs) height = size[1] for prompt_idx in range(len(batch)): #print(f"batch_images[:][{prompt_idx}]: {batch_images[:][prompt_idx]}") result = Image.new('RGB', (width, height)) x_offset = 0 for cfg_idx in range(len(self.cfgs)): image = batch_images[cfg_idx][prompt_idx] result.paste(image, (x_offset, 0)) x_offset += image.width prompt = prompts[prompt_idx] clean_prompt = clean_filename(prompt) result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False) with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f: f.write(str(batch[prompt_idx])) tfimage = transforms.ToTensor()(result) if batch[prompt_idx].wants_random_caption: self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step) else: self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step) sample_index += 1 del result del tfimage del batch_images pbar.update(1) @torch.no_grad() def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict): """ creates a pipeline for SD inference """ scheduler = self._create_scheduler(diffusers_scheduler_config) pipe = StableDiffusionPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=None, # save vram requires_safety_checker=None, # avoid nag feature_extractor=None, # must be None if no safety checker ) if self.use_xformers: pipe.enable_xformers_memory_efficient_attention() return pipe @torch.no_grad() def _create_scheduler(self, scheduler_config: dict): scheduler = self.scheduler if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']: print(f"unsupported scheduler '{self.scheduler}', falling back to ddim") scheduler = 'ddim' if scheduler == 'ddim': return DDIMScheduler.from_config(scheduler_config) elif scheduler == 'dpm++': return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++") elif scheduler == 'pndm': return PNDMScheduler.from_config(scheduler_config) elif scheduler == 'ddpm': return DDPMScheduler.from_config(scheduler_config) elif scheduler == 'lms': return LMSDiscreteScheduler.from_config(scheduler_config) elif scheduler == 'euler': return EulerDiscreteScheduler.from_config(scheduler_config) elif scheduler == 'euler_a': return EulerAncestralDiscreteScheduler.from_config(scheduler_config) elif scheduler == 'kdpm2': return KDPM2AncestralDiscreteScheduler.from_config(scheduler_config) else: raise ValueError(f"unknown scheduler '{scheduler}'")