Squashed commit of the following:

commit 86fa1363852850e87be11e5a277b71435f6a3451
Author: Damian Stewart <d@damianstewart.com>
Date:   Sat Feb 18 14:43:57 2023 +0100

    cleanup, add back random caption support

commit f9a10842b47b9a5d51d53de8d56cb7089a1eeeb2
Author: Damian Stewart <d@damianstewart.com>
Date:   Sat Feb 18 13:52:22 2023 +0100

    misc fixes and documentation

commit 46167806892258fef509f14e9d83ceab08725cd6
Author: Damian Stewart <d@damianstewart.com>
Date:   Sat Feb 18 12:11:18 2023 +0100

    works

commit 390bcdf4d8165315e2f84404c62b410c7b674c84
Author: Damian Stewart <d@damianstewart.com>
Date:   Sat Feb 18 10:12:14 2023 +0100

    SampleGenerator code in place (untested)

commit 022724fa7a435371081fd489ee7e5dbfc2df37ec
Author: Damian Stewart <d@damianstewart.com>
Date:   Sat Feb 18 10:17:05 2023 +0100

    cleanup and new approach (untested)

commit 4ac81f0924146a7ac3c46f4a4382e7dceaaac47c
Author: Damian Stewart <d@damianstewart.com>
Date:   Fri Jan 27 17:26:12 2023 +0100

    fix 'classmethod is not callable' error

commit c875933096464a867a5c3cfbf9592605f201f79e
Author: Damian Stewart <d@damianstewart.com>
Date:   Fri Jan 27 17:10:03 2023 +0100

    fix prompts log crash

commit 2771d52485191388dfa5b3b8892ed7327d874ed6
Author: Damian Stewart <d@damianstewart.com>
Date:   Fri Jan 27 14:38:39 2023 +0100

    fix circular import

commit 8452272b02fe64a2345fba067a55e51c52debd98
Author: Damian Stewart <d@damianstewart.com>
Date:   Fri Jan 27 14:33:26 2023 +0100

    refactor sample generation (untested)
This commit is contained in:
Damian Stewart 2023-02-18 15:51:50 +01:00
parent 6e7e7f9e1f
commit e97f0816db
3 changed files with 349 additions and 131 deletions


@@ -6,15 +6,48 @@ Everydream2 uses the Tensorboard library to log performance metrics. (more opti
You should launch tensorboard while your training is running and watch along.
```
tensorboard --logdir logs
tensorboard --logdir logs --samples_per_plugin images=100
```
## Sample images
By default, the trainer produces sample images from `sample_prompts.txt` with a fixed seed every `sample_steps` steps. These are saved in the logs directory and can also be viewed in tensorboard if you prefer. If you have a ton of them, the slider bar in tensorboard may not select them all (unless you launch tensorboard with the `--samples_per_plugin` argument as shown above), but the actual files are still stored in your logs for review.
Samples are produced at CFG scales of 1, 4, and 7. You may find this very useful to see how your model is progressing.
If your `sample_prompts.txt` is empty, the trainer will instead generate samples from prompts taken from the last batch of your training data, up to 4 sets of samples.
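For example, a minimal `sample_prompts.txt` might contain one prompt per line (these prompts are purely illustrative):

```
a photograph of ted bennet riding a bicycle
ted bennet sitting on a sofa with a kitchen in the background
```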
### More control
In place of `sample_prompts.txt` you can provide a `sample_prompts.json` file, which offers more control over sample generation. Here is an example `sample_prompts.json` file:
```json
{
"batch_size": 3,
"seed": 555,
"cfgs": [7, 4],
"scheduler": "dpm++",
"num_inference_steps": 15,
"show_progress_bars": true,
"samples": [
{
"prompt": "ted bennet and a man sitting on a sofa with a kitchen in the background",
"negative_prompt": "distorted, deformed"
},
{
"prompt": "a photograph of ted bennet riding a bicycle",
"seed": -1
},
{
"random_caption": true,
"size": [640, 384]
}
]
}
```
At the top level you can set a `batch_size` (subject to VRAM limits), a default `seed`, a default `resolution`, and the `cfgs` to generate with, as well as a `scheduler` and `num_inference_steps` to control the quality of the samples. Available schedulers are `ddim` (the default) and `dpm++`. Finally, you can set `show_progress_bars` to `true` if you want to see progress bars during the sample generation process.
Individual samples are defined under the `samples` key. Each sample can have a `prompt`, a `negative_prompt`, a `seed` (use `-1` to pick a different random seed each time), and a `size` (a `[width, height]` pair whose dimensions must be multiples of 64). Use `"random_caption": true` to pick a random caption from the training set each time. Samples that share the same `size` are batched together for generation.
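To use the JSON file, point the trainer's `--sample_prompts` argument at it in place of the default `sample_prompts.txt` (your other training arguments are unchanged):

```
python train.py --sample_prompts sample_prompts.json
```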
## LR

train.py (152 lines changed)

@@ -31,17 +31,16 @@ import importlib
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from colorama import Fore, Style
import numpy as np
import itertools
import torch
import datetime
import json
from PIL import Image, ImageDraw, ImageFont
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, \
    DPMSolverMultistepScheduler
#from diffusers.models import AttentionBlock
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
@@ -62,16 +61,11 @@ if torch.cuda.is_available():
from utils.gpu import GPU
import data.aspects as aspects
import data.resolver as resolver
from utils.sample_generator import SampleGenerator
_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9
def clean_filename(filename):
"""
removes all non-alphanumeric characters from a string so it is safe to use as a filename
"""
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
def get_hf_ckpt_cache_path(ckpt_path):
return os.path.join("ckpt_cache", os.path.basename(ckpt_path))
@@ -425,108 +419,6 @@ def main(args):
# logging.info(f" Saving optimizer state to {save_path}")
# self.save_optimizer(self.ctx.optimizer, optimizer_path)
@torch.no_grad()
def __create_inference_pipe(unet, text_encoder, tokenizer, scheduler, vae):
"""
creates a pipeline for SD inference
"""
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
        feature_extractor=None, # must be None if no safety checker
)
return pipe
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):
"""
generates a single sample at a given cfg scale and saves it to disk
"""
with torch.no_grad(), autocast():
image = pipe(prompt,
num_inference_steps=30,
num_images_per_prompt=1,
guidance_scale=cfg,
generator=gen,
height=resolution,
width=resolution,
).images[0]
draw = ImageDraw.Draw(image)
try:
font = ImageFont.truetype(font="arial.ttf", size=20)
    except Exception:
font = ImageFont.load_default()
print_msg = f"cfg:{cfg:.1f}"
l, t, r, b = draw.textbbox(xy=(0,0), text=print_msg, font=font)
text_width = r - l
text_height = b - t
x = float(image.width - text_width - 10)
y = float(image.height - text_height - 10)
draw.rectangle((x, y, image.width, image.height), fill="white")
draw.text((x, y), print_msg, fill="black", font=font)
del draw, font
return image
def __generate_test_samples(pipe, prompts, gs, log_writer, log_folder, random_captions=False, resolution=512):
"""
generates samples at different cfg scales and saves them to disk
"""
logging.info(f"Generating samples gs:{gs}, for {prompts}")
pipe.set_progress_bar_config(disable=True)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
gen = torch.Generator(device=device).manual_seed(seed)
i = 0
for prompt in prompts:
if prompt is None or len(prompt) < 2:
#logging.warning("empty prompt in sample prompts, check your prompts file")
continue
images = []
for cfg in [7.0, 4.0, 1.01]:
image = __generate_sample(pipe, prompt, cfg, resolution=resolution, gen=gen)
images.append(image)
width = 0
height = 0
for image in images:
width += image.width
height = max(height, image.height)
result = Image.new('RGB', (width, height))
x_offset = 0
for image in images:
result.paste(image, (x_offset, 0))
x_offset += image.width
clean_prompt = clean_filename(prompt)
result.save(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{log_folder}/samples/gs{gs:05}-{i}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(prompt)
f.write(f"\n seed: {seed}")
tfimage = transforms.ToTensor()(result)
if random_captions:
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
else:
log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
i += 1
del result
del tfimage
del images
try:
# check for a local file
@@ -545,7 +437,7 @@ def main(args):
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
sample_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
except Exception as e:
@@ -679,6 +571,11 @@ def main(args):
log_args(log_writer, args)
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
default_resolution=args.resolution, default_seed=args.seed,
config_file_path=args.sample_prompts,
batch_size=args.batch_size,
use_xformers=is_xformers_available() and not args.disable_xformers)
"""
Train the model
@@ -831,6 +728,7 @@ def main(args):
# # discard the grads, just want to pin memory
# optimizer.zero_grad(set_to_none=True)
write_batch_schedule(args, log_folder, train_batch, 0)
for epoch in range(args.max_epochs):
@@ -900,19 +798,18 @@ def main(args):
torch.cuda.empty_cache()
if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device)
with torch.no_grad():
sample_prompts = read_sample_prompts(args.sample_prompts)
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
__generate_test_samples(pipe=pipe, prompts=sample_prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, resolution=args.resolution)
else:
max_prompts = min(4,len(batch["captions"]))
prompts=batch["captions"][:max_prompts]
__generate_test_samples(pipe=pipe, prompts=prompts, log_writer=log_writer, log_folder=log_folder, gs=global_step, random_captions=True, resolution=args.resolution)
sample_generator.reload_config()
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del pipe
del inference_pipe
gc.collect()
torch.cuda.empty_cache()
@@ -966,8 +863,9 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
logging.error(f"{Fore.LIGHTYELLOW_EX} ^^ NO not doing that.{Style.RESET_ALL}")
#save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
#__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir, yaml, args.save_full_precision)
raise ex
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
@@ -1019,7 +917,7 @@ if __name__ == "__main__":
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt)")
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")

utils/sample_generator.py (new file, 287 lines)

@@ -0,0 +1,287 @@
import json
import logging
import os.path
import traceback
from dataclasses import dataclass
import random
from typing import Generator, Callable, Any
import torch
from PIL import Image, ImageDraw, ImageFont
from colorama import Fore, Style
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
def clean_filename(filename):
"""
removes all non-alphanumeric characters from a string so it is safe to use as a filename
"""
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
@dataclass
class SampleRequest:
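    """A single requested sample image, parsed from sample_prompts.txt or sample_prompts.json."""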
prompt: str
negative_prompt: str
seed: int
size: tuple[int,int]
wants_random_caption: bool = False
def __str__(self):
rep = self.prompt
if len(self.negative_prompt) > 0:
rep += "\n negative prompt: {self.negative_prompt}"
rep += f"\n seed: {self.seed}"
return rep
def chunk_list(l: list, batch_size: int,
compatibility_test: Callable[[Any,Any], bool]=lambda x,y: True
) -> Generator[list, None, None]:
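    """
    Groups items into buckets of mutually compatible items (per compatibility_test),
    then yields each bucket in chunks of at most batch_size.

    e.g. chunk_list([1, 1, 2], 2, lambda a, b: a == b) yields [1, 1], then [2].
    """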
buckets = []
for item in l:
compatible_bucket = next((b for b in buckets if compatibility_test(item, b[0])), None)
if compatible_bucket is not None:
compatible_bucket.append(item)
else:
buckets.append([item])
for b in buckets:
for i in range(0, len(b), batch_size):
yield b[i:i + batch_size]
class SampleGenerator:
seed: int
default_resolution: int
cfgs: list[float] = [7, 4, 1.01]
scheduler: str = 'ddim'
num_inference_steps: int = 30
random_captions = False
    sample_requests: list[SampleRequest]
log_folder: str
log_writer: SummaryWriter
def __init__(self,
log_folder: str,
log_writer: SummaryWriter,
default_resolution: int,
config_file_path: str,
batch_size: int,
default_seed: int,
use_xformers: bool):
self.log_folder = log_folder
self.log_writer = log_writer
self.batch_size = batch_size
self.config_file_path = config_file_path
self.use_xformers = use_xformers
self.show_progress_bars = False
self.default_resolution = default_resolution
self.default_seed = default_seed
self.reload_config()
print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, using scheduler '{self.scheduler}', {self.num_inference_steps} steps")
if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/")
def reload_config(self):
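        """Re-reads sample settings from the config file; the trainer calls this before each sampling pass, so edits take effect without restarting training."""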
try:
config_file_extension = os.path.splitext(self.config_file_path)[1].lower()
if config_file_extension == '.txt':
self._reload_sample_prompts_txt(self.config_file_path)
elif config_file_extension == '.json':
self._reload_config_json(self.config_file_path)
else:
raise ValueError(f"Unrecognized file type '{config_file_extension}' for sample config, must be .txt or .json")
except Exception as e:
logging.warning(
f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
logging.warning(
f" Using random caption samples until the problem is fixed. If you edit {self.config_file_path} to fix the problem, it will be automatically reloaded next time samples are due to be generated.")
self.sample_requests = self._make_random_caption_sample_requests()
def update_random_captions(self, possible_captions: list[str]):
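        """Fills in prompts for requests flagged wants_random_caption, cycling through the given captions (the trainer passes in the current batch's captions)."""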
random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
for i, r in enumerate(random_prompt_sample_requests):
r.prompt = possible_captions[i % len(possible_captions)]
def _reload_sample_prompts_txt(self, path):
with open(path, 'rt') as f:
self.sample_requests = [SampleRequest(prompt=line.strip(),
negative_prompt='',
seed=self.default_seed,
size=(self.default_resolution, self.default_resolution)
) for line in f]
if len(self.sample_requests) == 0:
self.sample_requests = self._make_random_caption_sample_requests()
def _make_random_caption_sample_requests(self):
num_random_captions = 4
return [SampleRequest(prompt='',
negative_prompt='',
seed=self.default_seed,
size=(self.default_resolution, self.default_resolution),
wants_random_caption=True)
for _ in range(num_random_captions)]
def _reload_config_json(self, path):
with open(path, 'rt') as f:
config = json.load(f)
# if keys are missing, keep current values
self.default_resolution = config.get('resolution', self.default_resolution)
self.cfgs = config.get('cfgs', self.cfgs)
self.batch_size = config.get('batch_size', self.batch_size)
self.scheduler = config.get('scheduler', self.scheduler)
self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
sample_requests_json = config.get('samples', None)
if sample_requests_json is None:
self.sample_requests = self._make_random_caption_sample_requests()
else:
default_seed = config.get('seed', self.default_seed)
default_size = (self.default_resolution, self.default_resolution)
self.sample_requests = [SampleRequest(prompt=p.get('prompt', ''),
negative_prompt=p.get('negative_prompt', ''),
seed=p.get('seed', default_seed),
size=tuple(p.get('size', default_size)),
wants_random_caption=p.get('random_caption', False)
) for p in sample_requests_json]
@torch.no_grad()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
"""
generates samples at different cfg scales and saves them to disk
"""
if len(self.sample_requests) == 0:
raise NotImplementedError("todo: implement random captions")
#max_prompts = min(4, len(batch["captions"]))
#sample_requests = batch["captions"][:max_prompts]
logging.info(f"Generating samples gs:{global_step}, for {[p.prompt for p in self.sample_requests]}")
pipe.set_progress_bar_config(disable=(not self.show_progress_bars))
try:
font = ImageFont.truetype(font="arial.ttf", size=20)
        except Exception:
font = ImageFont.load_default()
sample_index = 0
with autocast():
batch: list[SampleRequest]
def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
return a.size == b.size
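            # the pipeline generates a single width/height per call, so only
            # requests with matching sizes may share a batch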
for batch in chunk_list(self.sample_requests, self.batch_size,
compatibility_test=sample_compatibility_test):
#print("batch: ", batch)
prompts = [p.prompt for p in batch]
negative_prompts = [p.negative_prompt for p in batch]
seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
for p in batch]
# all sizes in a batch are the same
size = batch[0].size
generators = [torch.Generator(pipe.device).manual_seed(seed) for seed in seeds]
batch_images = []
for cfg in self.cfgs:
images = pipe(prompt=prompts,
negative_prompt=negative_prompts,
num_inference_steps=self.num_inference_steps,
num_images_per_prompt=1,
guidance_scale=cfg,
generator=generators,
width=size[0],
height=size[1],
).images
for image in images:
draw = ImageDraw.Draw(image)
print_msg = f"cfg:{cfg:.1f}"
l, t, r, b = draw.textbbox(xy=(0, 0), text=print_msg, font=font)
text_width = r - l
text_height = b - t
x = float(image.width - text_width - 10)
y = float(image.height - text_height - 10)
draw.rectangle((x, y, image.width, image.height), fill="white")
draw.text((x, y), print_msg, fill="black", font=font)
batch_images.append(images)
del images
del generators
#print("batch_images:", batch_images)
width = size[0] * len(self.cfgs)
height = size[1]
for prompt_idx in range(len(batch)):
#print(f"batch_images[:][{prompt_idx}]: {batch_images[:][prompt_idx]}")
result = Image.new('RGB', (width, height))
x_offset = 0
for cfg_idx in range(len(self.cfgs)):
image = batch_images[cfg_idx][prompt_idx]
result.paste(image, (x_offset, 0))
x_offset += image.width
prompt = prompts[prompt_idx]
clean_prompt = clean_filename(prompt)
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(str(batch[prompt_idx]))
tfimage = transforms.ToTensor()(result)
if batch[prompt_idx].wants_random_caption:
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
else:
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
sample_index += 1
del result
del tfimage
del batch_images
@torch.no_grad()
def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
"""
creates a pipeline for SD inference
"""
scheduler = self._create_scheduler(diffusers_scheduler_config)
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
if self.use_xformers:
pipe.enable_xformers_memory_efficient_attention()
return pipe
@torch.no_grad()
def _create_scheduler(self, scheduler_config: dict):
scheduler = self.scheduler
if scheduler not in ['ddim', 'dpm++']:
print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
scheduler = 'ddim'
if scheduler == 'ddim':
return DDIMScheduler.from_config(scheduler_config)
elif scheduler == 'dpm++':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++")
else:
raise ValueError(f"unknown scheduler '{scheduler}'")