# Example Usage:
# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
import torch
import torchvision
import transformers
import diffusers
import os
import glob
import random
import tqdm
import resource
import psutil
import pynvml
import wandb
import gc
import time
import itertools
import numpy as np
import json
try:
    pynvml.nvmlInit()
except pynvml.nvml.NVMLError_LibraryNotFound:
    pynvml = None
from typing import Iterable
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image
from PIL.Image import Resampling
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d

torch.backends.cuda.matmul.allow_tf32 = True
# defaults should be good for everyone
# TODO: add custom VAE support. should be simple with diffusers
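
# argparse's type=bool treats any non-empty string as True (so '--fp16=False'
# would silently enable fp16); boolean flags go through this helper instead.
def bool_t(x: str) -> bool:
    return x.lower() in ('true', '1', 'yes')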
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Can be a HuggingFace ID or a local directory.')
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--use_ema', type=bool_t, default=False, help='Use EMA for finetuning')
parser.add_argument('--ucg', type=float, default=0.1, help='Chance of dropping out the text condition per batch. Ranges from 0.0 to 1.0 where 1.0 means 100%% text condition dropout.')  # 10% dropout probability
parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=bool_t, default=False, help='Enable gradient checkpointing')
parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=bool_t, default=False, help='Use 8-bit Adam optimizer')
parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
parser.add_argument('--seed', type=int, default=42, help='Seed for the random number generator, for reproducibility.')
parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower-res images will be scaled up to this resolution and higher-res images will be scaled down.')
parser.add_argument('--shuffle', dest='shuffle', type=bool_t, default=True, help='Shuffle dataset')
parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
parser.add_argument('--fp16', dest='fp16', type=bool_t, default=False, help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--clip_penultimate', type=bool_t, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=bool_t, default=False, help='Outputs bucket information and exits')
args = parser.parse_args()
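
# torchrun sets MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE in the environment;
# init_method='env://' tells torch.distributed to read its rendezvous info from there.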
def setup():
    torch.distributed.init_process_group("nccl", init_method="env://")

def cleanup():
    torch.distributed.destroy_process_group()

def get_rank() -> int:
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()

def get_world_size() -> int:
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
# Inform the user of the host and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
print("HOST:", socket.gethostname())
print("CUDA:", torch.version.cuda)
print("TORCH:", torch.__version__)
print("TRANSFORMERS:", transformers.__version__)
print("DIFFUSERS:", diffusers.__version__)
print("MODEL:", args.model)
print("FP16:", args.fp16)
print("RESOLUTION:", args.resolution)
def get_gpu_ram() -> str:
    """
    Returns memory usage statistics for the CPU, GPU, and Torch.
    :return:
    """
    gpu_str = ""
    torch_str = ""
    try:
        cudadev = torch.cuda.current_device()
        nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
        gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
        gpu_total = int(gpu_info.total / 1E6)
        gpu_free = int(gpu_info.free / 1E6)
        gpu_used = int(gpu_info.used / 1E6)
        gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \
                  f"T: {gpu_total:,}mb) "
        torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6)
        torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6)
        torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6)
        torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6)
        torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \
                    f"{torch_reserved_max:,}mb, " \
                    f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
    except (AssertionError, AttributeError):
        # AssertionError: no CUDA device; AttributeError: pynvml failed to load above
        pass
    cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 +
                     resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss / 1E3)
    cpu_vmem = psutil.virtual_memory()
    cpu_free = int(cpu_vmem.free / 1E6)
    return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \
           f"{gpu_str}" \
           f"{torch_str}"
def _sort_by_ratio(bucket: tuple) -> float:
    return bucket[0] / bucket[1]

def _sort_by_area(bucket: tuple) -> float:
    return bucket[0] * bucket[1]
class ImageStore:
    def __init__(self, data_dir: str) -> None:
        self.data_dir = data_dir

        self.image_files = []
        for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
            self.image_files.extend(glob.glob(f'{data_dir}/*.{e}'))
        self.image_files = [x for x in self.image_files if self.__valid_file(x)]

    def __len__(self) -> int:
        return len(self.image_files)

    def __valid_file(self, f) -> bool:
        try:
            Image.open(f)
            return True
        except Exception:
            print(f'WARNING: Unable to open file: {f}')
            return False

    # iterator returns images as PIL images and their index in the store
    def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
        for f in range(len(self)):
            yield Image.open(self.image_files[f]), f

    # get image by index
    def get_image(self, index: int) -> Image.Image:
        # convert to RGB so grayscale/RGBA files stack cleanly into a batch
        return Image.open(self.image_files[index]).convert('RGB')

    # gets caption by removing the extension from the filename and replacing it with .txt
    def get_caption(self, index: int) -> str:
        filename = os.path.splitext(self.image_files[index])[0] + '.txt'
        with open(filename, 'r') as f:
            return f.read()
# ====================================== #
# Bucketing code stolen from hasuwoof: #
# https://github.com/hasuwoof/huskystack #
# ====================================== #
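
# Aspect bucketing groups images by aspect ratio into a fixed set of (w, h)
# resolutions capped at max_image_area, so every sample in a batch shares one
# resolution instead of being center-cropped to a single square size.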
class AspectBucket:
    def __init__(self, store: ImageStore,
                 num_buckets: int,
                 batch_size: int,
                 bucket_side_min: int = 256,
                 bucket_side_max: int = 768,
                 bucket_side_increment: int = 64,
                 max_image_area: int = 512 * 768,
                 max_ratio: float = 2):

        self.requested_bucket_count = num_buckets
        self.bucket_length_min = bucket_side_min
        self.bucket_length_max = bucket_side_max
        self.bucket_increment = bucket_side_increment
        self.max_image_area = max_image_area
        self.batch_size = batch_size
        self.total_dropped = 0

        if max_ratio <= 0:
            self.max_ratio = float('inf')
        else:
            self.max_ratio = max_ratio

        self.store = store
        self.buckets = []
        self._bucket_ratios = []
        self._bucket_interp = None
        self.bucket_data: Dict[tuple, List[int]] = dict()
        self.init_buckets()
        self.fill_buckets()

    def init_buckets(self):
        possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
        possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
                                if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)

        buckets_by_ratio = {}

        # group the buckets by their aspect ratios
        for bucket in possible_buckets:
            w, h = bucket
            # use precision to avoid spooky floats messing up your day
            ratio = '{:.4e}'.format(w / h)

            if ratio not in buckets_by_ratio:
                group = set()
                buckets_by_ratio[ratio] = group
            else:
                group = buckets_by_ratio[ratio]

            group.add(bucket)

        # now we take the list of buckets we generated and pick the largest by area for each (the first sorted)
        # then we put all of those in a list, sorted by the aspect ratio
        # the square bucket (LxL) will be the first
        unique_ratio_buckets = sorted([sorted(buckets, key=_sort_by_area)[-1]
                                       for buckets in buckets_by_ratio.values()], key=_sort_by_ratio)

        # how many buckets to create for each side of the distribution
        bucket_count_each = int(np.clip((self.requested_bucket_count + 1) / 2, 1, len(unique_ratio_buckets)))

        # we know that the requested_bucket_count must be an odd number, so the indices we calculate
        # will include the square bucket and some linearly spaced buckets along the distribution
        indices = {*np.linspace(0, len(unique_ratio_buckets) - 1, bucket_count_each, dtype=int)}

        # make the buckets, make sure they are unique (to remove the duplicated square bucket), and sort them by ratio
        # here we add the portrait buckets by reversing the dimensions of the landscape buckets we generated above
        buckets = sorted({*(unique_ratio_buckets[i] for i in indices),
                          *(tuple(reversed(unique_ratio_buckets[i])) for i in indices)}, key=_sort_by_ratio)

        self.buckets = buckets

        # cache the bucket ratios and the interpolator that will be used for calculating the best bucket later
        # the interpolator makes a 1d piecewise interpolation where the input (x-axis) is the bucket ratio,
        # and the output is the bucket index in the self.buckets array
        # to find the best fit we can just round that number to get the index
        self._bucket_ratios = [w / h for w, h in buckets]
        self._bucket_interp = interp1d(self._bucket_ratios, list(range(len(buckets))), assume_sorted=True,
                                       fill_value=None)

        for b in buckets:
            self.bucket_data[b] = []
    def get_batch_count(self):
        return sum(len(b) // self.batch_size for b in self.bucket_data.values())

    def get_bucket_info(self):
        return json.dumps({"buckets": self.buckets, "bucket_ratios": self._bucket_ratios})
    def get_batch_iterator(self) -> Generator[List[int], None, None]:
        """
        Generator that provides batches where all the images in a batch fall into the same bucket.

        Each element generated is a list [image1, image2, ..., image{batch_size}]
        where each image is an index into the dataset.
        :return:
        """
        max_bucket_len = max(len(b) for b in self.bucket_data.values())
        index_schedule = list(range(max_bucket_len))
        random.shuffle(index_schedule)

        bucket_len_table = {
            b: len(self.bucket_data[b]) for b in self.buckets
        }

        bucket_schedule = []
        for i, b in enumerate(self.buckets):
            bucket_schedule.extend([i] * (bucket_len_table[b] // self.batch_size))

        random.shuffle(bucket_schedule)

        bucket_pos = {
            b: 0 for b in self.buckets
        }

        total_generated_by_bucket = {
            b: 0 for b in self.buckets
        }

        for bucket_index in bucket_schedule:
            b = self.buckets[bucket_index]
            i = bucket_pos[b]
            bucket_len = bucket_len_table[b]

            batch = []
            while len(batch) != self.batch_size:
                # advance in the schedule until we find an index that is contained in the bucket
                k = index_schedule[i]
                if k < bucket_len:
                    entry = self.bucket_data[b][k]
                    batch.append(entry)

                i += 1

            total_generated_by_bucket[b] += self.batch_size
            bucket_pos[b] = i
            yield list(batch)
    def fill_buckets(self):
        entries = self.store.entries_iterator()
        total_dropped = 0

        for entry, index in tqdm.tqdm(entries, total=len(self.store)):
            if not self._process_entry(entry, index):
                total_dropped += 1

        for b, values in self.bucket_data.items():
            # shuffle the entries for extra randomness and to make sure dropped elements are also random
            random.shuffle(values)

            # make sure the buckets have an exact number of elements for the batch
            to_drop = len(values) % self.batch_size
            self.bucket_data[b] = list(values[:len(values) - to_drop])
            total_dropped += to_drop

        self.total_dropped = total_dropped

    def _process_entry(self, entry: Image.Image, index: int) -> bool:
        aspect = entry.width / entry.height

        if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
            return False

        best_bucket = self._bucket_interp(aspect)

        if best_bucket is None:
            return False

        bucket = self.buckets[round(float(best_bucket))]
        self.bucket_data[bucket].append(index)

        del entry

        return True
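
# Shards whole batches across ranks: each rank takes every num_replicas-th batch
# from the bucket iterator, offset by its rank, so each batch stays bucket-consistent.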
class AspectBucketSampler(torch.utils.data.Sampler):
    def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
        super().__init__(None)
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # subsample the bucket to only include the elements that are assigned to this rank
        indices = self.bucket.get_batch_iterator()
        indices = list(indices)[self.rank::self.num_replicas]
        return iter(indices)

    def __len__(self):
        return self.bucket.get_batch_count() // self.num_replicas
class AspectDataset(torch.utils.data.Dataset):
    def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
        self.store = store
        self.tokenizer = tokenizer
        self.ucg = ucg

        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.store)
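    # With probability `ucg` the caption is dropped to an empty string; this trains
    # the unconditional branch used for classifier-free guidance at inference time.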
    def __getitem__(self, item: int):
        return_dict = {'pixel_values': None, 'input_ids': None}

        image_file = self.store.get_image(item)
        return_dict['pixel_values'] = self.transforms(image_file)

        if random.random() > self.ucg:
            caption_file = self.store.get_caption(item)
        else:
            caption_file = ''

        return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
        return return_dict

    def collate_fn(self, examples):
        pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = [example['input_ids'] for example in examples if example is not None]
        padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)

        return {
            'pixel_values': pixel_values,
            'input_ids': padded_tokens.input_ids,
            'attention_mask': padded_tokens.attention_mask,
        }
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
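# EMA keeps a shadow copy of the weights, updated each step as
#   shadow = decay * shadow + (1 - decay) * param
# with a (1+step)/(10+step) warmup so the shadow tracks the live weights closely
# early in training before settling at the configured decay (0.9999).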
class EMAModel:
    """
    Exponential Moving Average of model weights
    """
    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the update weight (1 - decay) for the exponential moving average,
        applying a (1 + step) / (10 + step) warmup to the decay.
        """
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    @torch.no_grad()
    def step(self, parameters):
        parameters = list(parameters)

        self.optimization_step += 1
        # keep the update weight in a local: self.decay must remain the base decay
        # that get_decay() compares against on every step
        one_minus_decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                tmp = one_minus_decay * (s_param - param)
                s_param.sub_(tmp)
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
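
# torchrun launches one copy of this script per GPU; rank indexes both the
# process and its device (torch.cuda.set_device(rank) assumes one GPU per rank
# on a single node).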
def main():
    rank = get_rank()
    world_size = get_world_size()

    torch.cuda.set_device(rank)

    if args.hf_token is None:
        args.hf_token = os.environ['HF_API_TOKEN']

    if rank == 0:
        os.makedirs(args.output_path, exist_ok=True)
        # remove hf_token from args so sneaky people don't steal it from the wandb logs
        sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
        run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path + '/wandb')

    device = torch.device('cuda')
    print("DEVICE:", device)

    # setup fp16 stuff
    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

    # Set seed
    torch.manual_seed(args.seed)
    print('RANDOM SEED:', args.seed)

    tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
    text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
    vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
    unet = UNet2DConditionModel.from_pretrained(args.model, subfolder='unet', use_auth_token=args.hf_token)

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.use_8bit_adam:  # bitsandbytes is only supported on certain CUDA setups, so default to regular adam if it fails
        try:
            import bitsandbytes as bnb
            optimizer_cls = bnb.optim.AdamW8bit
        except Exception:
            print('bitsandbytes not supported, using regular Adam optimizer')
            optimizer_cls = torch.optim.AdamW
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.lr,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule='scaled_linear',
        num_train_timesteps=1000,
        tensor_format='pt'
    )

    # load dataset
    store = ImageStore(args.dataset)
    dataset = AspectDataset(store, tokenizer)
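    # max_image_area = resolution^2 keeps every bucket near the pixel count of a
    # square resolution x resolution image, so per-batch VRAM use stays comparable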
    bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
    sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)

    print(f'STORE_LEN: {len(store)}')

    if args.output_bucket_info:
        print(bucket.get_bucket_info())
        exit(0)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=0,
        collate_fn=dataset.collate_fn
    )
    lr_scheduler = get_scheduler(
        'constant',
        optimizer=optimizer
    )

    weight_dtype = torch.float16 if args.fp16 else torch.float32

    # move models to device
    vae = vae.to(device, dtype=weight_dtype)
    unet = unet.to(device, dtype=torch.float32)
    text_encoder = text_encoder.to(device, dtype=weight_dtype)

    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
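    # NB: with the DistributedDataParallel wrapper commented out, gradients are
    # NOT synchronized across ranks; each process fine-tunes its own copy of the UNet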
2022-10-29 10:26:03 -06:00
# create ema
if args . use_ema :
ema_unet = EMAModel ( unet . parameters ( ) )
print ( get_gpu_ram ( ) )
    num_steps_per_epoch = len(train_dataloader)
    progress_bar = tqdm.tqdm(range(args.epochs * num_steps_per_epoch), desc="Total Steps", leave=False)
    global_step = 0
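    # Only rank 0 writes checkpoints, but every rank must call save_checkpoint()
    # since it ends with a collective barrier.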
    def save_checkpoint():
        if rank == 0:
            if args.use_ema:
                ema_unet.copy_to(unet.parameters())
            pipeline = StableDiffusionPipeline(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=PNDMScheduler(
                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                ),
                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
            )
            pipeline.save_pretrained(args.output_path)
        # barrier
        torch.distributed.barrier()
    # train!
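    # Each step: encode images into VAE latents, add noise at a random timestep t,
    # and train the UNet to predict that noise (the epsilon objective):
    #   loss = MSE(unet(noisy_latents, t, text_embedding), noise)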
    loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
    for epoch in range(args.epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            b_start = time.perf_counter()

            latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215
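            # (0.18215 is the latent scaling factor from the Stable Diffusion training
            # config; it scales the VAE latents to roughly unit variance for diffusion)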

            # Sample noise
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
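            # optionally condition on the penultimate hidden state; CLIP only applies
            # final_layer_norm after the last layer, so it must be applied manually here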
            if args.clip_penultimate:
                encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
            else:
                encoder_hidden_states = encoder_hidden_states.last_hidden_state

            # Predict the noise residual and compute loss
            with torch.autocast('cuda', enabled=args.fp16):
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

            # Backprop and all reduce
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Update EMA
            if args.use_ema:
                ema_unet.step(unet.parameters())

            # perf
            b_end = time.perf_counter()
            seconds_per_step = b_end - b_start
            steps_per_second = 1 / seconds_per_step
            rank_images_per_second = args.batch_size * steps_per_second
            world_images_per_second = rank_images_per_second * world_size
            samples_seen = global_step * args.batch_size * world_size

            # All reduce loss
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)

            # every rank advances global_step so the save/image-log checks below
            # (which contain collective barriers) fire on all ranks at the same step
            global_step += 1

            if rank == 0:
                progress_bar.update(1)
                logs = {
                    "train/loss": loss.detach().item() / world_size,
                    "train/lr": lr_scheduler.get_last_lr()[0],
                    "train/epoch": epoch,
                    "train/samples_seen": samples_seen,
                    "perf/rank_samples_per_second": rank_images_per_second,
                    "perf/global_samples_per_second": world_images_per_second,
                }
                progress_bar.set_postfix(logs)
                run.log(logs)

            if global_step % args.save_steps == 0:
                save_checkpoint()

            if global_step % args.image_log_steps == 0:
                if rank == 0:
                    # get prompt from random batch
                    prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids']) - 1)].tolist())
                    pipeline = StableDiffusionPipeline(
                        text_encoder=text_encoder,
                        vae=vae,
                        unet=unet,
                        tokenizer=tokenizer,
                        scheduler=PNDMScheduler(
                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
                        ),
                        safety_checker=None,  # disable the safety checker to save memory
                        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
                    ).to(device)
                    # inference
                    images = []
                    with torch.no_grad():
                        with torch.autocast('cuda', enabled=args.fp16):
                            for _ in range(args.image_log_amount):
                                images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
                    # log images under single caption
                    run.log({'images': images})
                    # cleanup so we don't run out of memory
                    del pipeline
                    gc.collect()
                torch.distributed.barrier()

    # all ranks call save_checkpoint() here: it gates the actual write to rank 0
    # internally and ends with a barrier, so calling it from rank 0 alone would hang
    save_checkpoint()
    cleanup()

    print(get_gpu_ram())
    print('Done!')

if __name__ == "__main__":
    setup()
    main()