2023-02-18 07:51:50 -07:00
import json
import logging
import os . path
from dataclasses import dataclass
import random
from typing import Generator , Callable , Any
import torch
from PIL import Image , ImageDraw , ImageFont
from colorama import Fore , Style
2023-02-27 17:14:23 -07:00
from diffusers import StableDiffusionPipeline , DDIMScheduler , DPMSolverMultistepScheduler , DDPMScheduler , PNDMScheduler , EulerDiscreteScheduler , EulerAncestralDiscreteScheduler , LMSDiscreteScheduler , KDPM2AncestralDiscreteScheduler
2023-02-18 07:51:50 -07:00
from torch . cuda . amp import autocast
from torch . utils . tensorboard import SummaryWriter
from torchvision import transforms
2023-03-02 10:29:28 -07:00
from tqdm . auto import tqdm
2023-02-18 07:51:50 -07:00
def clean_filename(filename):
    """
    reduces a string to characters that are safe to use in a filename

    Keeps only letters, digits and spaces (as judged by ``str.isalpha`` and
    ``str.isdigit``); every other character is dropped. Trailing whitespace
    is stripped from the result, leading spaces are kept.

    :param filename: the raw string, e.g. a prompt
    :return: the filtered string (may be empty)
    """
    kept_chars = [c for c in filename if c.isalpha() or c.isdigit() or c == ' ']
    # join with the EMPTY string: the kept characters must stay adjacent
    return "".join(kept_chars).rstrip()
@dataclass
class SampleRequest:
    """
    describes one sample image to generate

    prompt / negative_prompt: text passed to the inference pipeline
    seed: seed for the random generator
    size: (width, height) of the generated image
    wants_random_caption: when True, the prompt is expected to be filled in
        later from a pool of captions
    """
    prompt: str
    negative_prompt: str
    seed: int
    size: tuple[int, int]
    wants_random_caption: bool = False

    def __str__(self):
        """
        returns the prompt, followed by the negative prompt (only if one is
        set) and the seed, each on its own line
        """
        rep = self.prompt
        # truthiness check: skips both '' and None
        if self.negative_prompt:
            rep += f"\n negative prompt: {self.negative_prompt}"
        rep += f"\n seed: {self.seed}"
        return rep
def chunk_list(l: list, batch_size: int,
               compatibility_test: Callable[[Any, Any], bool] = lambda x, y: True
               ) -> Generator[list, None, None]:
    """
    splits a list into batches of at most batch_size mutually compatible items

    Items are first grouped: each item joins the first existing group whose
    first member passes ``compatibility_test(item, first_member)``, or starts
    a new group if none does. Every group is then yielded in slices of at
    most ``batch_size`` items, groups in creation order.

    :param l: the items to batch
    :param batch_size: maximum number of items per yielded batch
    :param compatibility_test: predicate deciding whether two items may share
        a batch; by default everything is compatible
    :return: generator of batches (lists)
    """
    groups = []
    for element in l:
        for group in groups:
            # a group is represented by its first member
            if compatibility_test(element, group[0]):
                group.append(element)
                break
        else:
            # no existing group accepted the element
            groups.append([element])

    for group in groups:
        yield from (group[start:start + batch_size]
                    for start in range(0, len(group), batch_size))
2023-03-02 14:16:21 -07:00
def get_best_size_for_aspect_ratio(aspect_ratio, default_resolution) -> tuple[int, int]:
    """
    picks the (width, height) that best matches aspect_ratio while keeping the
    pixel count close to default_resolution squared

    Candidate sizes have both sides in 256, 320, ..., 960 (steps of 64) and a
    pixel count within 128*64 of ``default_resolution ** 2``. Among those, the
    one whose width/height ratio is closest to ``aspect_ratio`` (relative
    error) wins; ties go to the first candidate found.

    :param aspect_ratio: desired width / height
    :param default_resolution: side length of the reference square image
    :return: the chosen (width, height); if no candidate is within the pixel
        tolerance (e.g. default_resolution is 1024), the square
        (default_resolution, default_resolution) is returned instead of raising
    """
    target_pixel_count = default_resolution * default_resolution
    max_pixel_deviation = 128 * 64
    side_lengths = range(256, 1024, 64)

    sizes = [(w, h)
             for w in side_lengths
             for h in side_lengths
             if abs((w * h) - target_pixel_count) <= max_pixel_deviation]

    if not sizes:
        # min() on an empty list would raise ValueError; fall back to the
        # square size used everywhere else as the default
        return (default_resolution, default_resolution)

    return min(sizes, key=lambda s: abs(1 - (aspect_ratio / (s[0] / s[1]))))
2023-02-18 07:51:50 -07:00
class SampleGenerator:
    """
    generates sample images from a set of prompts, saves them to
    <log_folder>/samples/ and logs them to the summary writer

    The prompts and generation settings come from a config file (.txt with one
    prompt per line, or .json); see reload_config().
    """
    seed: int
    default_resolution: int
    # guidance scales to render for every prompt; one image column per entry
    cfgs: list[float] = [7, 4, 1.01]
    scheduler: str = 'ddim'
    num_inference_steps: int = 30
    random_captions = False

    sample_requests: list[SampleRequest]
    log_folder: str
    log_writer: SummaryWriter

    def __init__(self,
                 log_folder: str,
                 log_writer: SummaryWriter,
                 default_resolution: int,
                 config_file_path: str,
                 batch_size: int,
                 default_seed: int,
                 default_sample_steps: int,
                 use_xformers: bool):
        """
        stores the settings, loads the sample config and makes sure the
        samples output folder exists

        :param log_folder: base folder; samples go to <log_folder>/samples/
        :param log_writer: summary writer that receives the sample images
        :param default_resolution: side length used for square default sizes
        :param config_file_path: path to the .txt or .json sample config
        :param batch_size: number of prompts rendered per pipeline call
        :param default_seed: seed for requests that do not specify one
        :param default_sample_steps: stored as sample_steps (reported as
            "generating samples every N training steps")
        :param use_xformers: enable xformers attention on the inference pipe
        """
        self.log_folder = log_folder
        self.log_writer = log_writer
        self.batch_size = batch_size
        self.config_file_path = config_file_path
        self.use_xformers = use_xformers
        self.show_progress_bars = False
        self.generate_pretrain_samples = False

        self.default_resolution = default_resolution
        self.default_seed = default_seed
        self.sample_steps = default_sample_steps

        # None means "no config has been loaded successfully yet"
        self.sample_requests = None
        self.reload_config()
        print(f" * SampleGenerator initialized with {len(self.sample_requests)} prompts, generating samples every {self.sample_steps} training steps, using scheduler '{self.scheduler}' with {self.num_inference_steps} inference steps")
        # exist_ok avoids the check-then-create race
        os.makedirs(f"{log_folder}/samples/", exist_ok=True)

    def reload_config(self):
        """
        (re)loads the sample config from config_file_path

        Never raises for a bad config: any error is logged as a warning and
        the previously loaded requests are kept. If nothing has been loaded
        yet, random-caption requests are used instead.
        """
        try:
            config_file_extension = os.path.splitext(self.config_file_path)[1].lower()
            if config_file_extension == '.txt':
                self._reload_sample_prompts_txt(self.config_file_path)
            elif config_file_extension == '.json':
                self._reload_config_json(self.config_file_path)
            else:
                raise ValueError(f"Unrecognized file type '{config_file_extension}' for sample config, must be .txt or .json")
        except Exception as e:
            # deliberate catch-all boundary: a broken sample config must not
            # abort the caller
            logging.warning(
                f" * {Fore.LIGHTYELLOW_EX}Error trying to read sample config from {self.config_file_path}: {Style.RESET_ALL}{e}")
            logging.warning(
                f"   Edit {self.config_file_path} to fix the problem. It will be automatically reloaded next time samples are due to be generated."
            )
            if self.sample_requests is None:
                logging.warning(
                    "   Will generate samples from random training image captions until the problem is fixed.")
                self.sample_requests = self._make_random_caption_sample_requests()

    def update_random_captions(self, possible_captions: list[str]):
        """
        assigns captions to all requests that want a random caption

        Captions are handed out in order and wrap around when there are more
        requests than captions. With no captions available the requests are
        left unchanged.
        """
        if not possible_captions:
            # nothing to assign; also avoids a modulo-by-zero below
            return
        random_prompt_sample_requests = [r for r in self.sample_requests if r.wants_random_caption]
        for i, r in enumerate(random_prompt_sample_requests):
            r.prompt = possible_captions[i % len(possible_captions)]

    def _reload_sample_prompts_txt(self, path):
        """
        loads one sample request per line of a text file, using the default
        seed and a square default_resolution size; an empty file yields
        random-caption requests
        """
        with open(path, 'rt', encoding='utf-8') as f:
            self.sample_requests = [SampleRequest(prompt=line.strip(),
                                                  negative_prompt='',
                                                  seed=self.default_seed,
                                                  size=(self.default_resolution, self.default_resolution)
                                                  ) for line in f]
        if not self.sample_requests:
            self.sample_requests = self._make_random_caption_sample_requests()

    def _make_random_caption_sample_requests(self):
        """
        returns min(4, batch_size) requests with an empty prompt that are
        flagged as wanting a random caption
        """
        num_random_captions = min(4, self.batch_size)
        return [SampleRequest(prompt='',
                              negative_prompt='',
                              seed=self.default_seed,
                              size=(self.default_resolution, self.default_resolution),
                              wants_random_caption=True)
                for _ in range(num_random_captions)]

    def _reload_config_json(self, path):
        """
        loads generation settings and sample requests from a json file

        Top-level keys that are missing keep their current values. A missing
        or empty 'samples' list yields random-caption requests.
        """
        with open(path, 'rt', encoding='utf-8') as f:
            config = json.load(f)

        # if keys are missing, keep current values
        self.default_resolution = config.get('resolution', self.default_resolution)
        self.cfgs = config.get('cfgs', self.cfgs)
        self.batch_size = config.get('batch_size', self.batch_size)
        self.scheduler = config.get('scheduler', self.scheduler)
        self.num_inference_steps = config.get('num_inference_steps', self.num_inference_steps)
        self.show_progress_bars = config.get('show_progress_bars', self.show_progress_bars)
        self.generate_pretrain_samples = config.get('generate_pretrain_samples', self.generate_pretrain_samples)
        self.sample_steps = config.get('generate_samples_every_n_steps', self.sample_steps)

        sample_requests_config = config.get('samples', None)
        if sample_requests_config is None:
            self.sample_requests = self._make_random_caption_sample_requests()
            return

        default_seed = config.get('seed', self.default_seed)
        self.sample_requests = [
            SampleRequest(prompt=p.get('prompt', ''),
                          negative_prompt=p.get('negative_prompt', ''),
                          seed=p.get('seed', default_seed),
                          # an explicit 'size' wins; otherwise derive one from 'aspect_ratio'
                          size=tuple(p.get('size', None) or
                                     get_best_size_for_aspect_ratio(p.get('aspect_ratio', 1), self.default_resolution)),
                          wants_random_caption=p.get('random_caption', False)
                          ) for p in sample_requests_config]
        if not self.sample_requests:
            # the fallback list must be assigned, otherwise an empty 'samples'
            # list silently produces no samples at all
            self.sample_requests = self._make_random_caption_sample_requests()

    @staticmethod
    def _load_label_font():
        """returns the font for the cfg label: arial 20pt, else PIL's default font"""
        try:
            return ImageFont.truetype(font="arial.ttf", size=20)
        except (OSError, ImportError):
            # font file not found/readable, or FreeType support unavailable
            return ImageFont.load_default()

    @staticmethod
    def _draw_cfg_label(image, cfg, font):
        """draws 'cfg:<value>' black-on-white into the bottom right corner of image (in place)"""
        draw = ImageDraw.Draw(image)
        print_msg = f"cfg:{cfg:.1f}"

        l, t, r, b = draw.textbbox(xy=(0, 0), text=print_msg, font=font)
        text_width = r - l
        text_height = b - t

        x = float(image.width - text_width - 10)
        y = float(image.height - text_height - 10)

        draw.rectangle((x, y, image.width, image.height), fill="white")
        draw.text((x, y), print_msg, fill="black", font=font)

    def _generate_batch_images(self, pipe, batch, font, disable_progress_bars):
        """
        runs the pipeline once per cfg scale for all requests of a batch

        :return: list indexed [cfg_idx][prompt_idx] of labelled images
        """
        prompts = [p.prompt for p in batch]
        negative_prompts = [p.negative_prompt for p in batch]
        # a seed of -1 requests a fresh random seed
        seeds = [(p.seed if p.seed != -1 else random.randint(0, 2 ** 30))
                 for p in batch]
        # all sizes in a batch are the same
        size = batch[0].size
        generators = [torch.Generator(pipe.device).manual_seed(seed) for seed in seeds]

        batch_images = []
        for cfg in self.cfgs:
            pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
                                         desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
            images = pipe(prompt=prompts,
                          negative_prompt=negative_prompts,
                          num_inference_steps=self.num_inference_steps,
                          num_images_per_prompt=1,
                          guidance_scale=cfg,
                          generator=generators,
                          width=size[0],
                          height=size[1],
                          ).images

            for image in images:
                self._draw_cfg_label(image, cfg, font)
            batch_images.append(images)
        return batch_images

    def _save_and_log_sample(self, request, cfg_images, global_step, sample_index):
        """
        pastes the per-cfg images of one request side by side, then writes the
        .jpg, a .txt describing the request, and a summary-writer image
        """
        width = request.size[0] * len(self.cfgs)
        height = request.size[1]

        result = Image.new('RGB', (width, height))
        x_offset = 0
        for image in cfg_images:
            result.paste(image, (x_offset, 0))
            x_offset += image.width

        clean_prompt = clean_filename(request.prompt)
        file_stem = f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}"

        result.save(f"{file_stem}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
        with open(f"{file_stem}.txt", "w", encoding='utf-8') as f:
            f.write(str(request))

        tfimage = transforms.ToTensor()(result)
        if request.wants_random_caption:
            # random captions change between runs, so keep the tag stable
            tag = f"sample_{sample_index}"
        else:
            tag = f"sample_{sample_index}_{clean_prompt[:100]}"
        self.log_writer.add_image(tag=tag, img_tensor=tfimage, global_step=global_step)

    @torch.no_grad()
    def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
        """
        generates samples at different cfg scales and saves them to disk

        Requests are batched by size (at most batch_size per pipeline call).
        For each request one image is written that shows all cfg scales side
        by side; it is also logged to the summary writer.

        :param pipe: inference pipeline, e.g. from create_inference_pipe()
        :param global_step: training step, used in file names and log entries
        """
        disable_progress_bars = not self.show_progress_bars
        font = self._load_label_font()

        if not self.show_progress_bars:
            print(f" * Generating samples at gs:{global_step} for {len(self.sample_requests)} prompts")

        def sample_compatibility_test(a: SampleRequest, b: SampleRequest) -> bool:
            return a.size == b.size

        sample_index = 0
        with autocast():
            batches = list(chunk_list(self.sample_requests, self.batch_size,
                                      compatibility_test=sample_compatibility_test))
            # context manager guarantees the bar is closed, even on errors
            with tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
                      desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}") as pbar:
                for batch in batches:
                    batch_images = self._generate_batch_images(pipe, batch, font, disable_progress_bars)

                    for prompt_idx, request in enumerate(batch):
                        cfg_images = [images[prompt_idx] for images in batch_images]
                        self._save_and_log_sample(request, cfg_images, global_step, sample_index)
                        sample_index += 1

                    # drop the batch's images before rendering the next batch
                    del batch_images
                    pbar.update(1)

    @torch.no_grad()
    def create_inference_pipe(self, unet, text_encoder, tokenizer, vae, diffusers_scheduler_config: dict):
        """
        creates a pipeline for SD inference

        The scheduler is chosen by self.scheduler (see _create_scheduler);
        xformers attention is enabled when use_xformers is set.
        """
        scheduler = self._create_scheduler(diffusers_scheduler_config)
        pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,  # save vram
            requires_safety_checker=None,  # avoid nag
            feature_extractor=None,  # must be None if no safety checker
        )
        if self.use_xformers:
            pipe.enable_xformers_memory_efficient_attention()
        return pipe

    @torch.no_grad()
    def _create_scheduler(self, scheduler_config: dict):
        """
        builds the diffusers scheduler named by self.scheduler from
        scheduler_config; unsupported names fall back to ddim with a message
        """
        # name -> (scheduler class, extra from_config kwargs)
        scheduler_factories = {
            'ddim': (DDIMScheduler, {}),
            'dpm++': (DPMSolverMultistepScheduler, {'algorithm_type': "dpmsolver++"}),
            'pndm': (PNDMScheduler, {}),
            'ddpm': (DDPMScheduler, {}),
            'lms': (LMSDiscreteScheduler, {}),
            'euler': (EulerDiscreteScheduler, {}),
            'euler_a': (EulerAncestralDiscreteScheduler, {}),
            'kdpm2': (KDPM2AncestralDiscreteScheduler, {}),
        }

        scheduler = self.scheduler
        # isinstance guard: a non-string (possibly unhashable) config value is
        # simply "unsupported", not a TypeError from the dict lookup
        if not isinstance(scheduler, str) or scheduler not in scheduler_factories:
            print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
            scheduler = 'ddim'

        scheduler_class, extra_kwargs = scheduler_factories[scheduler]
        return scheduler_class.from_config(scheduler_config, **extra_kwargs)