diff --git a/doc/LOGGING.md b/doc/LOGGING.md
index fba41a0..441b673 100644
--- a/doc/LOGGING.md
+++ b/doc/LOGGING.md
@@ -6,15 +6,48 @@ Everydream2 uses the Tensorboard library to log performance metrics. (more opti
 
 You should launch tensorboard while your training is running and watch along.
 
-    tensorboard --logdir logs
+    tensorboard --logdir logs --samples_per_plugin images=100
 
 ## Sample images
 
-The trainer produces sample images from sample_prompts.txt with a fixed seed every so many steps as defined by your sample_steps argument. These are saved in the logs directory and can be viewed in tensorboard as well if you prefer. If you have a ton of them, the slider bar in tensorboard may not select them all, but the actual files are still stored in your logs as well for review.
+By default, the trainer produces sample images from `sample_prompts.txt` with a fixed seed every `sample_steps` steps. These are saved in the logs directory and can also be viewed in tensorboard. If you have a ton of them, the slider bar in tensorboard may not select them all unless you launch tensorboard with the `--samples_per_plugin` argument shown above, but the actual files are always stored in your logs folder for review.
 
-Samples are produced at CFG scales of 1, 4, and 7. You may find this very useful to see how your model is progressing. 
+Samples are produced at CFG scales of 1, 4, and 7. You may find this very useful to see how your model is progressing.
 
-If your sample_prompts.txt is empty, the trainer will generate from prompts from the last batch of your training data, up to 4 sets of samples.
+If your `sample_prompts.txt` is empty, the trainer will generate samples from prompts drawn from the last batch of your training data, up to 4 sets of samples.
+
+### More control
+
+In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, which offers more control over sample generation. Here is an example `sample_prompts.json` file:
+
+```json
+{
+  "batch_size": 3,
+  "seed": 555,
+  "cfgs": [7, 4],
+  "scheduler": "dpm++",
+  "num_inference_steps": 15,
+  "show_progress_bars": true,
+  "samples": [
+    {
+      "prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background",
+      "negative_prompt": "distorted, deformed"
+    },
+    {
+      "prompt": "a photograph of ted bennet riding a bicycle",
+      "seed": -1
+    },
+    {
+      "random_caption": true,
+      "size": [640, 384]
+    }
+  ]
+}
+```
+
+At the top level you can set a `batch_size` (subject to VRAM limits), a default `seed`, `resolution`, and list of `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, set `show_progress_bars` to `true` if you want to see progress bars during sample generation.
+
+Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (width and height must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time.
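+
+To use the JSON config, point the existing `--sample_prompts` argument at the file when you launch training, for example (all other arguments omitted here):
+
+    python train.py --sample_prompts sample_prompts.json ...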
 
 ## LR
diff --git a/train.py b/train.py
index 56f02ec..ca60266 100644
--- a/train.py
+++ b/train.py
@@ -31,17 +31,16 @@ import importlib
 
 import torch.nn.functional as F
 from torch.cuda.amp import autocast, GradScaler
-import torchvision.transforms as transforms
-from colorama import Fore, Style, Cursor
+from colorama import Fore, Style
 import numpy as np
 import itertools
 import torch
 import datetime
 import json
-from PIL import Image, ImageDraw, ImageFont
-from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler
+from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, \
+    DPMSolverMultistepScheduler
 #from diffusers.models import AttentionBlock
 from diffusers.optimization import get_scheduler
 from diffusers.utils.import_utils import is_xformers_available
@@ -62,16 +61,11 @@ if torch.cuda.is_available():
     from utils.gpu import GPU
 import data.aspects as aspects
 import data.resolver as resolver
+from utils.sample_generator import SampleGenerator
 
 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9
 
-def clean_filename(filename):
-    """
-    removes all non-alphanumeric characters from a string so it is safe to use as a filename
-    """
-    return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
-
 def get_hf_ckpt_cache_path(ckpt_path):
     return os.path.join("ckpt_cache", os.path.basename(ckpt_path))
 
@@ -425,108 +419,6 @@ def main(args):
 #             logging.info(f" Saving optimizer state to {save_path}")
 #             self.save_optimizer(self.ctx.optimizer, optimizer_path)
 
-    @torch.no_grad()
-    def __create_inference_pipe(unet, text_encoder, tokenizer, scheduler, vae):
-        """
-        creates a pipeline for SD inference
-        """
-        pipe = StableDiffusionPipeline(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None, # save vram
-            requires_safety_checker=None, # avoid nag
-            feature_extractor=None, # must be none of no safety checker
-        )
-
-        return pipe
-
-    def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):
-        """
-        generates a single sample at a given cfg scale and saves it to disk
-        """
-        with torch.no_grad(), autocast():
-            image = pipe(prompt,
-                         num_inference_steps=30,
-                         num_images_per_prompt=1,
-                         guidance_scale=cfg,
-                         generator=gen,
-                         height=resolution,
-                         width=resolution,
-                         ).images[0]
-
-            draw = ImageDraw.Draw(image)
-            try:
-                font = ImageFont.truetype(font="arial.ttf", size=20)
-            except:
-                font = ImageFont.load_default()
-            print_msg = f"cfg:{cfg:.1f}"
-
-            l, t, r, b = draw.textbbox(xy=(0,0), text=print_msg, font=font)
-            text_width = r - l
-            text_height = b - t
-
-            x = float(image.width - text_width - 10)
-            y = float(image.height - text_height - 10)
-
-            draw.rectangle((x, y, image.width, image.height), fill="white")
-            draw.text((x, y), print_msg, fill="black", font=font)
-            del draw, font
-        return image
-
-    def __generate_test_samples(pipe, prompts, gs, log_writer, log_folder, random_captions=False, resolution=512):
-        """
-        generates samples at different cfg scales and saves them to disk
-        """
-        logging.info(f"Generating samples gs:{gs}, for {prompts}")
-        pipe.set_progress_bar_config(disable=True)
-
-        seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
-        gen = torch.Generator(device=device).manual_seed(seed)
-
-        i = 0
-        for prompt in prompts:
-            if prompt is None or len(prompt) < 2:
-                #logging.warning("empty prompt in sample prompts, check your prompts file")
-                continue
-            images = []
-            for cfg in [7.0, 4.0, 1.01]:
-                image = __generate_sample(pipe, prompt, cfg, resolution=resolution, gen=gen)
-                images.append(image)
-
-            width = 0
-            height = 0
-            for image in images:
-                width += image.width
-                height = max(height, image.height)
-
-            result = Image.new('RGB', (width, height))
-
-            x_offset = 0
-            for image in images:
-                result.paste(image, (x_offset, 0))
-                x_offset += image.width
-
-            clean_prompt = clean_filename(prompt)
-
-            result.save(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
-            with open(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
-                f.write(prompt)
-                f.write(f"\n seed: {seed}")
-
-            tfimage = transforms.ToTensor()(result)
-            if random_captions:
-                log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
-            else:
-                log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
-            i += 1
-
-            del result
-            del tfimage
-            del images
-
     try:
         # check for a local file
@@ -545,7 +437,7 @@ def main(args):
             text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
             vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
             unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
-            sample_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
+            reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
             noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
             tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
         except Exception as e:
@@ -679,6 +571,11 @@ def main(args):
 
     log_args(log_writer, args)
 
+    sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
+                                       default_resolution=args.resolution, default_seed=args.seed,
+                                       config_file_path=args.sample_prompts,
+                                       batch_size=args.batch_size,
+                                       use_xformers=is_xformers_available() and not args.disable_xformers)
 
     """
     Train the model
@@ -831,6 +728,7 @@ def main(args):
     #     # discard the grads, just want to pin memory
     #     optimizer.zero_grad(set_to_none=True)
 
+    write_batch_schedule(args, log_folder, train_batch, 0)
 
     for epoch in range(args.max_epochs):
@@ -900,19 +798,18 @@ def main(args):
                     torch.cuda.empty_cache()
 
                 if (global_step + 1) % args.sample_steps == 0:
-                    pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
-                    pipe = pipe.to(device)
-                    with torch.no_grad():
-                        sample_prompts = read_sample_prompts(args.sample_prompts)
-                        if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
-                            __generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
-                        else:
-                            max_prompts = min(4,len(batch["captions"]))
-                            prompts=batch["captions"][:max_prompts]
-                            __generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution)
+                    sample_generator.reload_config()
+                    sample_generator.update_random_captions(batch["captions"])
+                    inference_pipe = sample_generator.create_inference_pipe(unet=unet,
+                                                                            text_encoder=text_encoder,
+                                                                            tokenizer=tokenizer,
+                                                                            vae=vae,
+                                                                            diffusers_scheduler_config=reference_scheduler.config
+                                                                            ).to(device)
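+
+                    # reload_config() above re-reads the prompts file on every sampling
+                    # pass, so edits made while training runs take effect here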
+                    sample_generator.generate_samples(inference_pipe, global_step)
 
-                    del pipe
+                    del inference_pipe
                     gc.collect()
                     torch.cuda.empty_cache()
 
@@ -966,8 +863,9 @@ def main(args):
 
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
-        save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
+        logging.error(f"{Fore.LIGHTYELLOW_EX} ^^ NO not doing that.{Style.RESET_ALL}")
+        #save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
+        #__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
         raise ex
 
     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@@ -1019,7 +917,7 @@ if __name__ == "__main__":
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
     argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
     argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
-    argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
+    argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt)")
     argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
     argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
     argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
diff --git a/utils/sample_generator.py b/utils/sample_generator.py
new file mode 100644
index 0000000..449a86a
--- /dev/null
+++ b/utils/sample_generator.py
@@ -0,0 +1,287 @@
+import json
+import logging
+import os.path
+import traceback
+from dataclasses import dataclass
+import random
+from typing import Generator, Callable, Any
+
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from colorama import Fore, Style
+from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler
+from torch.cuda.amp import autocast
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+
+
+def clean_filename(filename):
+    """
+    removes all non-alphanumeric characters from a string so it is safe to use as a filename
+    """
+    return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
+
+
+@dataclass
+class SampleRequest:
+    prompt: str
+    negative_prompt: str
+    seed: int
+    size: tuple[int,int]
+    wants_random_caption: bool = False
+
+    def __str__(self):
+        rep = self.prompt
+        if len(self.negative_prompt) > 0:
+            rep += f"\n negative prompt: {self.negative_prompt}"
+        rep += f"\n seed: {self.seed}"
+        return rep
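+
+# Note: chunk_list below first buckets items by pairwise compatibility (for
+# sampling, requests with identical image sizes, since one batch must share a
+# single size), then yields each bucket in chunks of at most batch_size.
+# For example, sizes [512, 512, 640, 512] with batch_size=2 yield the batches
+# [512, 512], then [512], then [640].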
+
+def chunk_list(l: list,
+               batch_size: int,
+               compatibility_test: Callable[[Any,Any], bool]=lambda x,y: True
+               ) -> Generator[list, None, None]:
+    buckets = []
+    for item in l:
+        compatible_bucket = next((b for b in buckets if compatibility_test(item, b[0])), None)
+        if compatible_bucket is not None:
+            compatible_bucket.append(item)
+        else:
+            buckets.append([item])
+
+    for b in buckets:
+        for i in range(0, len(b), batch_size):
+            yield b[i:i + batch_size]
+
+
+class SampleGenerator:
+    seed: int
+    default_resolution: int
+    cfgs: list[float] = [7, 4, 1.01]
+    scheduler: str = 'ddim'
+    num_inference_steps: int = 30
+    random_captions = False
+
+    sample_requests: list[SampleRequest]
+    log_folder: str
+    log_writer: SummaryWriter
+
+    def __init__(self,
+                 log_folder: str,
+                 log_writer: SummaryWriter,
+                 default_resolution: int,
+                 config_file_path: str,
+                 batch_size: int,
+                 default_seed: int,
+                 use_xformers: bool):
+        self.log_folder = log_folder
+        self.log_writer = log_writer
+        self.batch_size = batch_size
+        self.config_file_path = config_file_path
+        self.use_xformers = use_xformers
+        self.show_progress_bars = False
+
+        self.default_resolution = default_resolution
+        self.default_seed = default_seed
+
+        self.reload_config()
+        print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
+        if not os.path.exists(f"{log_folder}/samples/"):
+            os.makedirs(f"{log_folder}/samples/")
+
+    def reload_config(self):
+        try:
+            config_file_extension = os.path.splitext(self.config_file_path)[1].lower()
+            if config_file_extension == '.txt':
+                self._reload_sample_prompts_txt(self.config_file_path)
+            elif config_file_extension == '.json':
+                self._reload_config_json(self.config_file_path)
+            else:
+                raise ValueError(f"Unrecognized file type '{config_file_extension}' for sample config, must be .txt or .json")
+        except Exception as e:
+            logging.warning(
+                f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
+            logging.warning(
+                f"   Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
+            self.sample_requests = self._make_random_caption_sample_requests()
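+
+    # fills any requests flagged wants_random_caption with captions drawn from
+    # the most recent training batch, cycling if there are more requests than captions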
+    def update_random_captions(self, possible_captions: list[str]):
+        random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
+        for i, r in enumerate(random_prompt_sample_requests):
+            r.prompt = possible_captions[i % len(possible_captions)]
+
+    def _reload_sample_prompts_txt(self, path):
+        with open(path, 'rt') as f:
+            self.sample_requests = [SampleRequest(prompt=line.strip(),
+                                                  negative_prompt='',
+                                                  seed=self.default_seed,
+                                                  size=(self.default_resolution, self.default_resolution)
+                                                  ) for line in f if len(line.strip()) > 0]
+        if len(self.sample_requests) == 0:
+            self.sample_requests = self._make_random_caption_sample_requests()
+
+    def _make_random_caption_sample_requests(self):
+        num_random_captions = 4
+        return [SampleRequest(prompt='',
+                              negative_prompt='',
+                              seed=self.default_seed,
+                              size=(self.default_resolution, self.default_resolution),
+                              wants_random_caption=True)
+                for _ in range(num_random_captions)]
+
+    def _reload_config_json(self, path):
+        with open(path, 'rt') as f:
+            config = json.load(f)
+            # if keys are missing, keep current values
+            self.default_resolution = config.get('resolution', self.default_resolution)
+            self.cfgs = config.get('cfgs', self.cfgs)
+            self.batch_size = config.get('batch_size', self.batch_size)
+            self.scheduler = config.get('scheduler', self.scheduler)
+            self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
+            self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
+            sample_requests_json = config.get('samples', None)
+            if sample_requests_json is None:
+                self.sample_requests = self._make_random_caption_sample_requests()
+            else:
+                default_seed = config.get('seed', self.default_seed)
+                default_size = (self.default_resolution, self.default_resolution)
+                self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
+                                                      negative_prompt=p.get('negative_prompt', ''),
+                                                      seed=p.get('seed', default_seed),
+                                                      size=tuple(p.get('size', default_size)),
+                                                      wants_random_caption=p.get('random_caption', False)
+                                                      ) for p in sample_requests_json]
+
+    @torch.no_grad()
+    def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
+        """
+        generates samples at different cfg scales and saves them to disk
+        """
+        if len(self.sample_requests) == 0:
+            raise NotImplementedError("todo: implement random captions")
+            #max_prompts = min(4, len(batch["captions"]))
+            #sample_requests = batch["captions"][:max_prompts]
+        logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
+
+        pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
+
+        try:
+            font = ImageFont.truetype(font="arial.ttf", size=20)
+        except:
+            font = ImageFont.load_default()
+
+        sample_index = 0
+        with autocast():
+            batch: list[SampleRequest]
+            def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
+                return a.size == b.size
+            for batch in chunk_list(self.sample_requests, self.batch_size,
+                                    compatibility_test=sample_compatibility_test):
+                #print("batch: ", batch)
+                prompts = [p.prompt for p in batch]
+                negative_prompts = [p.negative_prompt for p in batch]
+                seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
+                         for p in batch]
+                # all sizes in a batch are the same
+                size = batch[0].size
+                generators = [torch.Generator(pipe.device).manual_seed(seed) for seed in seeds]
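+
+                # one full denoising pass per CFG scale; the images for each scale
+                # are collected so they can be pasted side by side per prompt below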
+                batch_images = []
+                for cfg in self.cfgs:
+                    images = pipe(prompt=prompts,
+                                  negative_prompt=negative_prompts,
+                                  num_inference_steps=self.num_inference_steps,
+                                  num_images_per_prompt=1,
+                                  guidance_scale=cfg,
+                                  generator=generators,
+                                  width=size[0],
+                                  height=size[1],
+                                  ).images
+
+                    for image in images:
+                        draw = ImageDraw.Draw(image)
+                        print_msg = f"cfg:{cfg:.1f}"
+
+                        l, t, r, b = draw.textbbox(xy=(0, 0), text=print_msg, font=font)
+                        text_width = r - l
+                        text_height = b - t
+
+                        x = float(image.width - text_width - 10)
+                        y = float(image.height - text_height - 10)
+
+                        draw.rectangle((x, y, image.width, image.height), fill="white")
+                        draw.text((x, y), print_msg, fill="black", font=font)
+
+                    batch_images.append(images)
+                    del images
+
+                del generators
+                #print("batch_images:", batch_images)
+
+                width = size[0] * len(self.cfgs)
+                height = size[1]
+
+                for prompt_idx in range(len(batch)):
+                    #print(f"batch_images[:][{prompt_idx}]: {batch_images[:][prompt_idx]}")
+                    result = Image.new('RGB', (width, height))
+                    x_offset = 0
+
+                    for cfg_idx in range(len(self.cfgs)):
+                        image = batch_images[cfg_idx][prompt_idx]
+                        result.paste(image, (x_offset, 0))
+                        x_offset += image.width
+
+                    prompt = prompts[prompt_idx]
+                    clean_prompt = clean_filename(prompt)
+
+                    result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
+                    with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
+                        f.write(str(batch[prompt_idx]))
+
+                    tfimage = transforms.ToTensor()(result)
+                    if batch[prompt_idx].wants_random_caption:
+                        self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
+                    else:
+                        self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
+                    sample_index += 1
+
+                    del result
+                    del tfimage
+                del batch_images
+
+
+    @torch.no_grad()
+    def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
+        """
+        creates a pipeline for SD inference
+        """
+        scheduler = self._create_scheduler(diffusers_scheduler_config)
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,  # save vram
+            requires_safety_checker=None,  # avoid nag
+            feature_extractor=None,  # must be None if no safety checker
+        )
+        if self.use_xformers:
+            pipe.enable_xformers_memory_efficient_attention()
+        return pipe
+
+
+    @torch.no_grad()
+    def _create_scheduler(self, scheduler_config: dict):
+        scheduler = self.scheduler
+        if scheduler not in ['ddim', 'dpm++']:
+            logging.warning(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
+            scheduler = 'ddim'
+
+        if scheduler == 'ddim':
+            return DDIMScheduler.from_config(scheduler_config)
+        elif scheduler == 'dpm++':
+            return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++")
+        else:
+            raise ValueError(f"unknown scheduler '{scheduler}'")