2022-12-17 20:32:48 -07:00
"""
2023-01-22 19:43:03 -07:00
Copyright [ 2022 - 2023 ] Victor C Hall
2022-12-17 20:32:48 -07:00
Licensed under the GNU Affero General Public License ;
You may not use this code except in compliance with the License .
You may obtain a copy of the License at
https : / / www . gnu . org / licenses / agpl - 3.0 . en . html
Unless required by applicable law or agreed to in writing , software
distributed under the License is distributed on an " AS IS " BASIS ,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License .
"""
2023-01-21 23:15:50 -07:00
2022-12-17 20:32:48 -07:00
import os
2023-01-29 19:11:34 -07:00
import pprint
2022-12-17 20:32:48 -07:00
import sys
import math
import signal
import argparse
import logging
2023-01-30 03:26:11 -07:00
import threading
2022-12-17 20:32:48 -07:00
import time
2022-12-20 01:30:42 -07:00
import gc
2022-12-27 12:25:32 -07:00
import random
2023-01-12 14:32:37 -07:00
import traceback
2023-01-18 11:07:05 -07:00
import shutil
2023-09-10 13:37:47 -06:00
from typing import Optional
2022-12-17 20:32:48 -07:00
2023-01-11 22:40:49 -07:00
import torch . nn . functional as F
2023-04-29 20:56:10 -06:00
from torch . cuda . amp import autocast
2022-12-17 20:32:48 -07:00
2023-02-18 07:51:50 -07:00
from colorama import Fore , Style
2022-12-17 20:32:48 -07:00
import numpy as np
import itertools
import torch
import datetime
import json
2023-04-30 09:52:30 -06:00
from tqdm . auto import tqdm
2022-12-17 20:32:48 -07:00
2023-02-18 07:51:50 -07:00
from diffusers import StableDiffusionPipeline , AutoencoderKL , UNet2DConditionModel , DDIMScheduler , DDPMScheduler , \
DPMSolverMultistepScheduler
2022-12-17 20:32:48 -07:00
#from diffusers.models import AttentionBlock
from diffusers . optimization import get_scheduler
from diffusers . utils . import_utils import is_xformers_available
from transformers import CLIPTextModel , CLIPTokenizer
#from accelerate import Accelerator
from accelerate . utils import set_seed
import wandb
2023-03-25 18:09:06 -06:00
import webbrowser
2022-12-17 20:32:48 -07:00
from torch . utils . tensorboard import SummaryWriter
2023-01-29 19:11:34 -07:00
from data . data_loader import DataLoaderMultiAspect
2022-12-17 20:32:48 -07:00
2023-02-07 09:32:54 -07:00
from data . every_dream import EveryDreamBatch , build_torch_dataloader
from data . every_dream_validation import EveryDreamValidator
2023-06-04 17:01:59 -06:00
from data . image_train_item import ImageTrainItem , DEFAULT_BATCH_ID
2023-01-12 14:32:37 -07:00
from utils . huggingface_downloader import try_download_model_from_hf
2023-01-12 18:59:04 -07:00
from utils . convert_diff_to_ckpt import convert as converter
2023-02-18 12:18:21 -07:00
from utils . isolate_rng import isolate_rng
2023-04-29 16:15:25 -06:00
from utils . check_git import check_git
2023-04-29 20:56:10 -06:00
from optimizer . optimizers import EveryDreamOptimizer
2023-09-06 04:38:52 -06:00
from copy import deepcopy
2023-02-18 12:18:21 -07:00
2023-02-08 03:28:45 -07:00
if torch . cuda . is_available ( ) :
from utils . gpu import GPU
2023-01-29 19:11:34 -07:00
import data . aspects as aspects
import data . resolver as resolver
2023-02-18 07:51:50 -07:00
from utils . sample_generator import SampleGenerator
2022-12-17 20:32:48 -07:00
_SIGTERM_EXIT_CODE = 130
2022-12-18 15:24:54 -07:00
_VERY_LARGE_NUMBER = 1e9
2022-12-17 20:32:48 -07:00
2023-01-23 13:10:04 -07:00
def get_hf_ckpt_cache_path ( ckpt_path ) :
return os . path . join ( " ckpt_cache " , os . path . basename ( ckpt_path ) )
2022-12-17 20:32:48 -07:00
def convert_to_hf ( ckpt_path ) :
2023-01-12 14:32:37 -07:00
2023-01-23 13:10:04 -07:00
hf_cache = get_hf_ckpt_cache_path ( ckpt_path )
2023-05-26 19:54:02 -06:00
from utils . unet_utils import get_attn_yaml
2022-12-17 20:32:48 -07:00
2023-04-13 21:31:22 -06:00
if os . path . isfile ( ckpt_path ) :
2022-12-17 20:32:48 -07:00
if not os . path . exists ( hf_cache ) :
os . makedirs ( hf_cache )
logging . info ( f " Converting { ckpt_path } to Diffusers format " )
2022-12-18 15:24:54 -07:00
try :
import utils . convert_original_stable_diffusion_to_diffusers as convert
convert . convert ( ckpt_path , f " ckpt_cache/ { ckpt_path } " )
except :
2023-01-20 07:42:24 -07:00
logging . info ( " Please manually convert the checkpoint to Diffusers format (one time setup), see readme. " )
2022-12-18 15:24:54 -07:00
exit ( )
2022-12-27 12:25:32 -07:00
else :
logging . info ( f " Found cached checkpoint at { hf_cache } " )
2023-04-13 21:31:22 -06:00
2023-01-25 18:55:24 -07:00
is_sd1attn , yaml = get_attn_yaml ( hf_cache )
2023-01-18 11:07:05 -07:00
return hf_cache , is_sd1attn , yaml
2022-12-17 20:32:48 -07:00
elif os . path . isdir ( hf_cache ) :
2023-01-25 18:55:24 -07:00
is_sd1attn , yaml = get_attn_yaml ( hf_cache )
2023-01-18 11:07:05 -07:00
return hf_cache , is_sd1attn , yaml
2022-12-17 20:32:48 -07:00
else :
2023-01-25 18:55:24 -07:00
is_sd1attn , yaml = get_attn_yaml ( ckpt_path )
2023-01-18 11:07:05 -07:00
return ckpt_path , is_sd1attn , yaml
2022-12-17 20:32:48 -07:00
2023-09-10 13:37:47 -06:00
class EveryDreamTrainingState :
def __init__ ( self ,
optimizer : EveryDreamOptimizer ,
train_batch : EveryDreamBatch ,
unet : UNet2DConditionModel ,
text_encoder : CLIPTextModel ,
tokenizer : CLIPTokenizer ,
scheduler ,
vae : AutoencoderKL ,
unet_ema : Optional [ UNet2DConditionModel ] ,
text_encoder_ema : Optional [ CLIPTextModel ]
) :
self . optimizer = optimizer
self . train_batch = train_batch
self . unet = unet
self . text_encoder = text_encoder
self . tokenizer = tokenizer
self . scheduler = scheduler
self . vae = vae
2023-09-17 12:17:54 -06:00
self . unet_ema = unet_ema
self . text_encoder_ema = text_encoder_ema
2023-09-10 13:37:47 -06:00
@torch.no_grad ( )
def save_model ( save_path , ed_state : EveryDreamTrainingState , global_step : int , save_ckpt_dir , yaml_name ,
save_full_precision = False , save_optimizer_flag = False , save_ckpt = True ) :
"""
Save the model to disk
"""
def save_ckpt_file ( diffusers_model_path , sd_ckpt_path ) :
nonlocal save_ckpt_dir
nonlocal save_full_precision
nonlocal yaml_name
if save_ckpt_dir is not None :
sd_ckpt_full = os . path . join ( save_ckpt_dir , sd_ckpt_path )
else :
sd_ckpt_full = os . path . join ( os . curdir , sd_ckpt_path )
save_ckpt_dir = os . curdir
half = not save_full_precision
logging . info ( f " * Saving SD model to { sd_ckpt_full } " )
converter ( model_path = diffusers_model_path , checkpoint_path = sd_ckpt_full , half = half )
if yaml_name and yaml_name != " v1-inference.yaml " :
yaml_save_path = f " { os . path . join ( save_ckpt_dir , os . path . basename ( diffusers_model_path ) ) } .yaml "
logging . info ( f " * Saving yaml to { yaml_save_path } " )
shutil . copyfile ( yaml_name , yaml_save_path )
if global_step is None or global_step == 0 :
logging . warning ( " No model to save, something likely blew up on startup, not saving " )
return
if args . ema_decay_rate != None :
pipeline_ema = StableDiffusionPipeline (
vae = ed_state . vae ,
text_encoder = ed_state . text_encoder_ema ,
tokenizer = ed_state . tokenizer ,
unet = ed_state . unet_ema ,
scheduler = ed_state . scheduler ,
safety_checker = None , # save vram
requires_safety_checker = None , # avoid nag
feature_extractor = None , # must be none of no safety checker
)
diffusers_model_path = save_path + " _ema "
logging . info ( f " * Saving diffusers EMA model to { diffusers_model_path } " )
pipeline_ema . save_pretrained ( diffusers_model_path )
if save_ckpt :
sd_ckpt_path_ema = f " { os . path . basename ( save_path ) } _ema.ckpt "
save_ckpt_file ( diffusers_model_path , sd_ckpt_path_ema )
pipeline = StableDiffusionPipeline (
vae = ed_state . vae ,
text_encoder = ed_state . text_encoder ,
tokenizer = ed_state . tokenizer ,
unet = ed_state . unet ,
scheduler = ed_state . scheduler ,
safety_checker = None , # save vram
requires_safety_checker = None , # avoid nag
feature_extractor = None , # must be none of no safety checker
)
diffusers_model_path = save_path
logging . info ( f " * Saving diffusers model to { diffusers_model_path } " )
pipeline . save_pretrained ( diffusers_model_path )
if save_ckpt :
sd_ckpt_path = f " { os . path . basename ( save_path ) } .ckpt "
save_ckpt_file ( diffusers_model_path , sd_ckpt_path )
if save_optimizer_flag :
logging . info ( f " Saving optimizer state to { save_path } " )
ed_state . optimizer . save ( save_path )
2022-12-17 20:32:48 -07:00
def setup_local_logger ( args ) :
"""
configures logger with file and console logging , logs args , and returns the datestamp
"""
2022-12-18 18:44:33 -07:00
log_path = args . logdir
2023-01-08 16:52:39 -07:00
2022-12-17 20:32:48 -07:00
if not os . path . exists ( log_path ) :
os . makedirs ( log_path )
json_config = json . dumps ( vars ( args ) , indent = 2 )
datetimestamp = datetime . datetime . now ( ) . strftime ( " % Y % m %d - % H % M % S " )
2023-01-03 13:10:57 -07:00
with open ( os . path . join ( log_path , f " { args . project_name } - { datetimestamp } _cfg.json " ) , " w " ) as f :
2023-01-03 12:27:26 -07:00
f . write ( f " { json_config } " )
2022-12-17 20:32:48 -07:00
2023-01-03 12:27:26 -07:00
logfilename = os . path . join ( log_path , f " { args . project_name } - { datetimestamp } .log " )
2023-01-08 16:52:39 -07:00
print ( f " logging to { logfilename } " )
2022-12-17 20:32:48 -07:00
logging . basicConfig ( filename = logfilename ,
level = logging . INFO ,
format = " %(asctime)s %(message)s " ,
datefmt = " % m/ %d / % Y % I: % M: % S % p " ,
)
2023-03-25 18:09:06 -06:00
console_handler = logging . StreamHandler ( sys . stdout )
2023-04-16 12:53:34 -06:00
console_handler . addFilter ( lambda msg : " Palette images with Transparency expressed in bytes " not in msg . getMessage ( ) )
2023-03-25 18:09:06 -06:00
logging . getLogger ( ) . addHandler ( console_handler )
import warnings
warnings . filterwarnings ( " ignore " , message = " UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images " )
#from PIL import Image
2023-01-08 16:52:39 -07:00
2022-12-17 20:32:48 -07:00
return datetimestamp
2023-04-29 20:56:10 -06:00
# def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
# """
# Saves the optimizer state
# """
# torch.save(optimizer.state_dict(), path)
2022-12-17 20:32:48 -07:00
2023-04-29 20:56:10 -06:00
# def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
# """
# Loads the optimizer state
# """
# optimizer.load_state_dict(torch.load(path))
2022-12-17 20:32:48 -07:00
def get_gpu_memory ( nvsmi ) :
"""
returns the gpu memory usage
"""
gpu_query = nvsmi . DeviceQuery ( ' memory.used, memory.total ' )
gpu_used_mem = int ( gpu_query [ ' gpu ' ] [ 0 ] [ ' fb_memory_usage ' ] [ ' used ' ] )
gpu_total_mem = int ( gpu_query [ ' gpu ' ] [ 0 ] [ ' fb_memory_usage ' ] [ ' total ' ] )
return gpu_used_mem , gpu_total_mem
def append_epoch_log ( global_step : int , epoch_pbar , gpu , log_writer , * * logs ) :
"""
updates the vram usage for the epoch
"""
2023-02-08 05:46:58 -07:00
if gpu is not None :
gpu_used_mem , gpu_total_mem = gpu . get_gpu_memory ( )
log_writer . add_scalar ( " performance/vram " , gpu_used_mem , global_step )
epoch_mem_color = Style . RESET_ALL
if gpu_used_mem > 0.93 * gpu_total_mem :
epoch_mem_color = Fore . LIGHTRED_EX
elif gpu_used_mem > 0.85 * gpu_total_mem :
epoch_mem_color = Fore . LIGHTYELLOW_EX
elif gpu_used_mem > 0.7 * gpu_total_mem :
epoch_mem_color = Fore . LIGHTGREEN_EX
elif gpu_used_mem < 0.5 * gpu_total_mem :
epoch_mem_color = Fore . LIGHTBLUE_EX
if logs is not None :
epoch_pbar . set_postfix ( * * logs , vram = f " { epoch_mem_color } { gpu_used_mem } / { gpu_total_mem } MB { Style . RESET_ALL } gs: { global_step } " )
2022-12-17 20:32:48 -07:00
2023-01-01 08:45:18 -07:00
def set_args_12gb ( args ) :
logging . info ( " Setting args to 12GB mode " )
2023-04-13 21:31:22 -06:00
if not args . gradient_checkpointing :
2023-01-11 22:40:49 -07:00
logging . info ( " - Overiding gradient checkpointing to True " )
2023-01-01 08:45:18 -07:00
args . gradient_checkpointing = True
2023-03-10 19:35:47 -07:00
if args . batch_size > 2 :
logging . info ( " - Overiding batch size to max 2 " )
args . batch_size = 2
2023-01-01 08:45:18 -07:00
args . grad_accum = 1
2023-01-11 22:40:49 -07:00
if args . resolution > 512 :
2023-03-10 19:35:47 -07:00
logging . info ( " - Overiding resolution to max 512 " )
2023-01-01 08:45:18 -07:00
args . resolution = 512
2023-09-07 10:53:20 -06:00
def find_last_checkpoint ( logdir , is_ema = False ) :
2022-12-17 20:32:48 -07:00
"""
2023-01-08 16:52:39 -07:00
Finds the last checkpoint in the logdir , recursively
2022-12-17 20:32:48 -07:00
"""
2023-01-08 16:52:39 -07:00
last_ckpt = None
last_date = None
2022-12-17 20:32:48 -07:00
2023-01-08 16:52:39 -07:00
for root , dirs , files in os . walk ( logdir ) :
for file in files :
if os . path . basename ( file ) == " model_index.json " :
2023-09-07 10:53:20 -06:00
if is_ema and ( not root . endswith ( " _ema " ) ) :
continue
elif ( not is_ema ) and root . endswith ( " _ema " ) :
continue
2023-01-08 16:52:39 -07:00
curr_date = os . path . getmtime ( os . path . join ( root , file ) )
if last_date is None or curr_date > last_date :
last_date = curr_date
last_ckpt = root
assert last_ckpt , f " Could not find last checkpoint in logdir: { logdir } "
assert " errored " not in last_ckpt , f " Found last checkpoint: { last_ckpt } , but it was errored, cancelling "
2023-01-02 15:33:31 -07:00
2023-01-08 16:52:39 -07:00
print ( f " { Fore . LIGHTCYAN_EX } Found last checkpoint: { last_ckpt } , resuming { Style . RESET_ALL } " )
2023-01-02 15:33:31 -07:00
2023-01-08 16:52:39 -07:00
return last_ckpt
def setup_args ( args ) :
"""
Sets defaults for missing args ( possible if missing from json config )
Forces some args to be set based on others for compatibility reasons
"""
2023-03-10 19:35:47 -07:00
if args . disable_amp :
logging . warning ( f " { Fore . LIGHTYELLOW_EX } Disabling AMP, not recommended. { Style . RESET_ALL } " )
2023-03-15 09:23:47 -06:00
args . amp = False
2023-03-10 19:35:47 -07:00
else :
args . amp = True
2023-01-15 20:07:37 -07:00
if args . disable_unet_training and args . disable_textenc_training :
raise ValueError ( " Both unet and textenc are disabled, nothing to train " )
2023-01-08 16:52:39 -07:00
if args . resume_ckpt == " findlast " :
logging . info ( f " { Fore . LIGHTCYAN_EX } Finding last checkpoint in logdir: { args . logdir } { Style . RESET_ALL } " )
# find the last checkpoint in the logdir
args . resume_ckpt = find_last_checkpoint ( args . logdir )
2023-09-12 19:37:27 -06:00
if ( args . ema_resume_model != None ) and ( args . ema_resume_model == " findlast " ) :
2023-09-07 10:53:20 -06:00
logging . info ( f " { Fore . LIGHTCYAN_EX } Finding last EMA decay checkpoint in logdir: { args . logdir } { Style . RESET_ALL } " )
2023-09-12 19:37:27 -06:00
args . ema_resume_model = find_last_checkpoint ( args . logdir , is_ema = True )
2023-09-07 10:53:20 -06:00
2023-01-08 16:52:39 -07:00
if args . lowvram :
set_args_12gb ( args )
2023-01-06 17:12:52 -07:00
if not args . shuffle_tags :
args . shuffle_tags = False
2023-01-06 14:33:33 -07:00
2023-01-01 08:45:18 -07:00
args . clip_skip = max ( min ( 4 , args . clip_skip ) , 0 )
2023-01-15 20:07:37 -07:00
2023-02-25 13:05:22 -07:00
if args . useadam8bit :
logging . warning ( f " { Fore . LIGHTYELLOW_EX } Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway { Style . RESET_ALL } " )
2022-12-18 18:44:33 -07:00
if args . ckpt_every_n_minutes is None and args . save_every_n_epochs is None :
2022-12-27 12:25:32 -07:00
logging . info ( f " { Fore . LIGHTCYAN_EX } No checkpoint saving specified, defaulting to every 20 minutes. { Style . RESET_ALL } " )
2022-12-18 18:44:33 -07:00
args . ckpt_every_n_minutes = 20
2022-12-17 20:32:48 -07:00
2022-12-18 15:24:54 -07:00
if args . ckpt_every_n_minutes is None or args . ckpt_every_n_minutes < 1 :
args . ckpt_every_n_minutes = _VERY_LARGE_NUMBER
if args . save_every_n_epochs is None or args . save_every_n_epochs < 1 :
args . save_every_n_epochs = _VERY_LARGE_NUMBER
2023-01-15 20:07:37 -07:00
2022-12-18 15:24:54 -07:00
if args . save_every_n_epochs < _VERY_LARGE_NUMBER and args . ckpt_every_n_minutes < _VERY_LARGE_NUMBER :
2023-01-02 15:33:31 -07:00
logging . warning ( f " { Fore . LIGHTYELLOW_EX } ** Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints { Style . RESET_ALL } " )
logging . warning ( f " { Fore . LIGHTYELLOW_EX } ** save_every_n_epochs: { args . save_every_n_epochs } , ckpt_every_n_minutes: { args . ckpt_every_n_minutes } { Style . RESET_ALL } " )
2022-12-17 22:43:25 -07:00
2022-12-18 11:03:44 -07:00
if args . cond_dropout > 0.26 :
2023-01-02 15:33:31 -07:00
logging . warning ( f " { Fore . LIGHTYELLOW_EX } ** cond_dropout is set fairly high: { args . cond_dropout } , make sure this was intended { Style . RESET_ALL } " )
2022-12-17 22:43:25 -07:00
2022-12-27 12:25:32 -07:00
if args . grad_accum > 1 :
logging . info ( f " { Fore . CYAN } Batch size: { args . batch_size } , grad accum: { args . grad_accum } , ' effective ' batch size: { args . batch_size * args . grad_accum } { Style . RESET_ALL } " )
2023-01-08 16:52:39 -07:00
total_batch_size = args . batch_size * args . grad_accum
2023-01-12 11:17:38 -07:00
if args . save_ckpt_dir is not None and not os . path . exists ( args . save_ckpt_dir ) :
2023-01-12 10:25:09 -07:00
os . makedirs ( args . save_ckpt_dir )
2023-01-14 06:00:30 -07:00
if args . rated_dataset :
args . rated_dataset_target_dropout_percent = min ( max ( args . rated_dataset_target_dropout_percent , 0 ) , 100 )
logging . info ( logging . info ( f " { Fore . CYAN } * Activating rated images learning with a target rate of { args . rated_dataset_target_dropout_percent } % { Style . RESET_ALL } " ) )
2023-04-13 21:31:22 -06:00
args . aspects = aspects . get_aspect_buckets ( args . resolution )
2023-01-29 19:11:34 -07:00
2023-01-08 16:52:39 -07:00
return args
2023-01-29 19:11:34 -07:00
2023-04-29 07:33:01 -06:00
def report_image_train_item_problems ( log_folder : str , items : list [ ImageTrainItem ] , batch_size ) - > None :
2023-01-29 19:11:34 -07:00
undersized_items = [ item for item in items if item . is_undersized ]
if len ( undersized_items ) > 0 :
underized_log_path = os . path . join ( log_folder , " undersized_images.txt " )
logging . warning ( f " { Fore . LIGHTRED_EX } ** Some images are smaller than the target size, consider using larger images { Style . RESET_ALL } " )
logging . warning ( f " { Fore . LIGHTRED_EX } ** Check { underized_log_path } for more information. { Style . RESET_ALL } " )
2023-03-15 20:14:45 -06:00
with open ( underized_log_path , " w " , encoding = ' utf-8 ' ) as undersized_images_file :
2023-01-29 19:11:34 -07:00
undersized_images_file . write ( f " The following images are smaller than the target size, consider removing or sourcing a larger copy: " )
for undersized_item in undersized_items :
message = f " *** { undersized_item . pathname } with size: { undersized_item . image_size } is smaller than target size: { undersized_item . target_wh } \n "
undersized_images_file . write ( message )
2023-04-13 21:31:22 -06:00
2023-04-19 03:06:02 -06:00
# warn on underfilled aspect ratio buckets
# Intuition: if there are too few images to fill a batch, duplicates will be appended.
# this is not a problem for large image counts but can seriously distort training if there
# are just a handful of images for a given aspect ratio.
# at a dupe ratio of 0.5, all images in this bucket have effective multiplier 1.5,
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
warn_bucket_dupe_ratio = 0.5
2023-06-04 17:01:59 -06:00
def make_bucket_key ( item ) :
return ( item . batch_id , int ( item . target_wh [ 0 ] ) , int ( item . target_wh [ 1 ] ) )
ar_buckets = set ( make_bucket_key ( i ) for i in items )
2023-04-19 03:06:02 -06:00
for ar_bucket in ar_buckets :
2023-06-04 17:01:59 -06:00
count = len ( [ i for i in items if make_bucket_key ( i ) == ar_bucket ] )
2023-04-19 03:06:02 -06:00
runt_size = batch_size - ( count % batch_size )
bucket_dupe_ratio = runt_size / count
if bucket_dupe_ratio > warn_bucket_dupe_ratio :
2023-06-04 17:01:59 -06:00
aspect_ratio_rational = aspects . get_rational_aspect_ratio ( ( ar_bucket [ 1 ] , ar_bucket [ 2 ] ) )
2023-04-19 03:06:02 -06:00
aspect_ratio_description = f " { aspect_ratio_rational [ 0 ] } : { aspect_ratio_rational [ 1 ] } "
2023-06-04 17:01:59 -06:00
batch_id_description = " " if ar_bucket [ 0 ] == DEFAULT_BATCH_ID else f " for batch id ' { ar_bucket [ 0 ] } ' "
2023-04-19 03:06:02 -06:00
effective_multiplier = round ( 1 + bucket_dupe_ratio , 1 )
logging . warning ( f " * { Fore . LIGHTRED_EX } Aspect ratio bucket { ar_bucket } has only { count } "
f " images { Style . RESET_ALL } . At batch size { batch_size } this makes for an effective multiplier "
2023-04-29 07:33:01 -06:00
f " of { effective_multiplier } , which may cause problems. Consider adding { runt_size } or "
2023-06-04 17:01:59 -06:00
f " more images with aspect ratio { aspect_ratio_description } { batch_id_description } , or reducing your batch_size. " )
2023-04-19 03:06:02 -06:00
2023-04-29 07:33:01 -06:00
def resolve_image_train_items ( args : argparse . Namespace ) - > list [ ImageTrainItem ] :
2023-01-29 19:20:40 -07:00
logging . info ( f " * DLMA resolution { args . resolution } , buckets: { args . aspects } " )
logging . info ( " Preloading images... " )
2023-04-13 21:31:22 -06:00
2023-01-29 19:20:40 -07:00
resolved_items = resolver . resolve ( args . data_root , args )
image_paths = set ( map ( lambda item : item . pathname , resolved_items ) )
2023-04-13 21:31:22 -06:00
2023-01-29 19:20:40 -07:00
# Remove erroneous items
2023-04-29 07:33:01 -06:00
for item in resolved_items :
if item . error is not None :
logging . error ( f " { Fore . LIGHTRED_EX } *** Error opening { Fore . LIGHTYELLOW_EX } { item . pathname } { Fore . LIGHTRED_EX } to get metadata. File may be corrupt and will be skipped. { Style . RESET_ALL } " )
logging . error ( f " *** exception: { item . error } " )
2023-01-29 19:20:40 -07:00
image_train_items = [ item for item in resolved_items if item . error is None ]
2023-02-08 03:28:45 -07:00
print ( f " * Found { len ( image_paths ) } files in ' { args . data_root } ' " )
2023-04-13 21:31:22 -06:00
2023-01-29 19:20:40 -07:00
return image_train_items
2023-04-13 21:31:22 -06:00
2023-01-29 19:28:07 -07:00
def write_batch_schedule ( args : argparse . Namespace , log_folder : str , train_batch : EveryDreamBatch , epoch : int ) :
2023-01-29 19:11:34 -07:00
if args . write_schedule :
with open ( f " { log_folder } /ep { epoch } _batch_schedule.txt " , " w " , encoding = ' utf-8 ' ) as f :
for i in range ( len ( train_batch . image_train_items ) ) :
try :
item = train_batch . image_train_items [ i ]
f . write ( f " step: { int ( i / train_batch . batch_size ) : 05 } , wh: { item . target_wh } , r: { item . runt_size } , path: { item . pathname } \n " )
except Exception as e :
logging . error ( f " * Error writing to batch schedule for file path: { item . pathname } " )
2023-01-18 11:07:05 -07:00
2023-01-29 16:43:16 -07:00
def read_sample_prompts ( sample_prompts_file_path : str ) :
sample_prompts = [ ]
with open ( sample_prompts_file_path , " r " ) as f :
for line in f :
sample_prompts . append ( line . strip ( ) )
return sample_prompts
2023-02-26 17:11:42 -07:00
def log_args ( log_writer , args ) :
arglog = " args: \n "
for arg , value in sorted ( vars ( args ) . items ( ) ) :
arglog + = f " { arg } = { value } , "
log_writer . add_text ( " config " , arglog )
2023-09-07 10:53:20 -06:00
def update_ema ( model , ema_model , decay , default_device , ema_device ) :
2023-09-06 04:38:52 -06:00
with torch . no_grad ( ) :
2023-09-06 13:37:10 -06:00
original_model_on_proper_device = model
2023-09-06 04:38:52 -06:00
need_to_delete_original = False
2023-09-07 10:53:20 -06:00
if ema_device != default_device :
2023-09-06 04:38:52 -06:00
original_model_on_other_device = deepcopy ( model )
2023-09-07 10:53:20 -06:00
original_model_on_proper_device = original_model_on_other_device . to ( ema_device , dtype = model . dtype )
2023-09-06 04:38:52 -06:00
del original_model_on_other_device
need_to_delete_original = True
2023-09-06 13:37:10 -06:00
params = dict ( original_model_on_proper_device . named_parameters ( ) )
2023-09-06 04:38:52 -06:00
ema_params = dict ( ema_model . named_parameters ( ) )
for name in ema_params :
#ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
ema_params [ name ] . data = ema_params [ name ] * decay + params [ name ] . data * ( 1.0 - decay )
if need_to_delete_original :
2023-09-06 13:37:10 -06:00
del ( original_model_on_proper_device )
2023-09-06 04:38:52 -06:00
def compute_snr ( timesteps , noise_scheduler ) :
"""
Computes SNR as per https : / / github . com / TiankaiHang / Min - SNR - Diffusion - Training / blob / 521 b624bd70c67cee4bdf49225915f5945a872e3 / guided_diffusion / gaussian_diffusion . py #L847-L849
"""
minimal_value = 1e-9
alphas_cumprod = noise_scheduler . alphas_cumprod
# Use .any() to check if any elements in the tensor are zero
if ( alphas_cumprod [ : - 1 ] == 0 ) . any ( ) :
logging . warning (
f " Alphas cumprod has zero elements! Resetting to { minimal_value } .. "
)
alphas_cumprod [ alphas_cumprod [ : - 1 ] == 0 ] = minimal_value
sqrt_alphas_cumprod = alphas_cumprod * * 0.5
sqrt_one_minus_alphas_cumprod = ( 1.0 - alphas_cumprod ) * * 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod . to ( device = timesteps . device ) [
timesteps
] . float ( )
while len ( sqrt_alphas_cumprod . shape ) < len ( timesteps . shape ) :
sqrt_alphas_cumprod = sqrt_alphas_cumprod [ . . . , None ]
alpha = sqrt_alphas_cumprod . expand ( timesteps . shape )
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod . to (
device = timesteps . device
) [ timesteps ] . float ( )
while len ( sqrt_one_minus_alphas_cumprod . shape ) < len ( timesteps . shape ) :
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod [ . . . , None ]
sigma = sqrt_one_minus_alphas_cumprod . expand ( timesteps . shape )
# Compute SNR, first without epsilon
snr = ( alpha / sigma ) * * 2
# Check if the first element in SNR tensor is zero
if torch . any ( snr == 0 ) :
snr [ snr == 0 ] = minimal_value
return snr
2023-09-07 10:53:20 -06:00
def load_train_json_from_file ( args , report_load = False ) :
2023-09-06 13:37:10 -06:00
try :
2023-09-07 10:53:20 -06:00
if report_load :
print ( f " Loading training config from { args . config } . " )
2023-09-06 13:37:10 -06:00
with open ( args . config , ' rt ' ) as f :
read_json = json . load ( f )
args . __dict__ . update ( read_json )
except Exception as config_read :
print ( f " Error on loading training config from { args . config } . " )
2023-01-29 16:43:16 -07:00
2023-01-08 16:52:39 -07:00
def main ( args ) :
"""
Main entry point
2023-04-13 21:31:22 -06:00
"""
2023-05-30 20:15:02 -06:00
if os . name == ' nt ' :
print ( " * Windows detected, disabling Triton " )
os . environ [ ' XFORMERS_FORCE_DISABLE_TRITON ' ] = " 1 "
2023-01-08 16:52:39 -07:00
log_time = setup_local_logger ( args )
args = setup_args ( args )
2023-04-16 16:48:44 -06:00
print ( f " Args: " )
pprint . pprint ( vars ( args ) )
2023-01-12 17:18:02 -07:00
2023-03-02 14:52:26 -07:00
if args . seed == - 1 :
args . seed = random . randint ( 0 , 2 * * 30 )
seed = args . seed
2023-01-15 20:07:37 -07:00
logging . info ( f " Seed: { seed } " )
2023-01-08 16:52:39 -07:00
set_seed ( seed )
2023-02-07 09:32:54 -07:00
if torch . cuda . is_available ( ) :
device = torch . device ( f " cuda: { args . gpuid } " )
2023-02-07 16:43:16 -07:00
gpu = GPU ( device )
2023-02-07 09:32:54 -07:00
torch . backends . cudnn . benchmark = True
else :
logging . warning ( " *** Running on CPU. This is for testing loading/config parsing code only. " )
device = ' cpu '
2023-02-08 05:46:58 -07:00
gpu = None
2023-01-08 16:52:39 -07:00
2023-09-06 04:38:52 -06:00
2022-12-29 19:11:06 -07:00
log_folder = os . path . join ( args . logdir , f " { args . project_name } _ { log_time } " )
2023-01-18 11:07:05 -07:00
2022-12-27 12:25:32 -07:00
if not os . path . exists ( log_folder ) :
os . makedirs ( log_folder )
2022-12-17 22:43:25 -07:00
2023-09-07 10:53:20 -06:00
def release_memory ( model_to_delete , original_device ) :
del model_to_delete
gc . collect ( )
if ' cuda ' in original_device . type :
torch . cuda . empty_cache ( )
2023-09-06 13:37:10 -06:00
2023-09-12 19:37:27 -06:00
use_ema_dacay_training = ( args . ema_decay_rate != None ) or ( args . ema_strength_target != None )
ema_model_loaded_from_file = False
2023-09-07 10:53:20 -06:00
if use_ema_dacay_training :
2023-09-12 19:37:27 -06:00
ema_device = torch . device ( args . ema_device )
2023-09-06 13:37:10 -06:00
2023-04-15 11:27:12 -06:00
optimizer_state_path = None
2023-09-10 13:37:47 -06:00
2023-01-12 14:32:37 -07:00
try :
2023-01-23 13:10:04 -07:00
# check for a local file
hf_cache_path = get_hf_ckpt_cache_path ( args . resume_ckpt )
if os . path . exists ( hf_cache_path ) or os . path . exists ( args . resume_ckpt ) :
2023-01-23 11:19:22 -07:00
model_root_folder , is_sd1attn , yaml = convert_to_hf ( args . resume_ckpt )
2023-02-20 14:00:25 -07:00
text_encoder = CLIPTextModel . from_pretrained ( model_root_folder , subfolder = " text_encoder " )
vae = AutoencoderKL . from_pretrained ( model_root_folder , subfolder = " vae " )
unet = UNet2DConditionModel . from_pretrained ( model_root_folder , subfolder = " unet " )
2023-01-23 13:10:04 -07:00
else :
# try to download from HF using resume_ckpt as a repo id
2023-02-20 14:00:25 -07:00
downloaded = try_download_model_from_hf ( repo_id = args . resume_ckpt )
if downloaded is None :
2023-01-23 13:10:04 -07:00
raise ValueError ( f " No local file/folder for { args . resume_ckpt } , and no matching huggingface.co repo could be downloaded " )
2023-02-20 14:00:25 -07:00
pipe , model_root_folder , is_sd1attn , yaml = downloaded
text_encoder = pipe . text_encoder
vae = pipe . vae
unet = pipe . unet
del pipe
2023-08-15 12:47:34 -06:00
2023-09-12 19:37:27 -06:00
if use_ema_dacay_training and args . ema_resume_model :
print ( f " Loading EMA model: { args . ema_resume_model } " )
ema_model_loaded_from_file = True
hf_cache_path = get_hf_ckpt_cache_path ( args . ema_resume_model )
2023-09-07 10:53:20 -06:00
2023-09-12 19:37:27 -06:00
if os . path . exists ( hf_cache_path ) or os . path . exists ( args . ema_resume_model ) :
2023-09-07 10:53:20 -06:00
ema_model_root_folder , ema_is_sd1attn , ema_yaml = convert_to_hf ( args . resume_ckpt )
text_encoder_ema = CLIPTextModel . from_pretrained ( ema_model_root_folder , subfolder = " text_encoder " )
unet_ema = UNet2DConditionModel . from_pretrained ( ema_model_root_folder , subfolder = " unet " )
else :
2023-09-12 19:37:27 -06:00
# try to download from HF using ema_resume_model as a repo id
ema_downloaded = try_download_model_from_hf ( repo_id = args . ema_resume_model )
2023-09-07 10:53:20 -06:00
if ema_downloaded is None :
raise ValueError (
2023-09-12 19:37:27 -06:00
f " No local file/folder for ema_resume_model { args . ema_resume_model } , and no matching huggingface.co repo could be downloaded " )
2023-09-07 10:53:20 -06:00
ema_pipe , ema_model_root_folder , ema_is_sd1attn , ema_yaml = ema_downloaded
text_encoder_ema = ema_pipe . text_encoder
unet_ema = ema_pipe . unet
del ema_pipe
# Make sure EMA model is on proper device, and memory released if moved
unet_ema_current_device = next ( unet_ema . parameters ( ) ) . device
if ema_device != unet_ema_current_device :
unet_ema_on_wrong_device = unet_ema
unet_ema = unet_ema . to ( ema_device )
release_memory ( unet_ema_on_wrong_device , unet_ema_current_device )
# Make sure EMA model is on proper device, and memory released if moved
text_encoder_ema_current_device = next ( text_encoder_ema . parameters ( ) ) . device
if ema_device != text_encoder_ema_current_device :
text_encoder_ema_on_wrong_device = text_encoder_ema
text_encoder_ema = text_encoder_ema . to ( ema_device )
release_memory ( text_encoder_ema_on_wrong_device , text_encoder_ema_current_device )
2023-09-06 13:37:10 -06:00
if args . enable_zero_terminal_snr :
# Use zero terminal SNR
2023-05-26 19:54:02 -06:00
from utils . unet_utils import enforce_zero_terminal_snr
2023-06-03 19:41:56 -06:00
temp_scheduler = DDIMScheduler . from_pretrained ( model_root_folder , subfolder = " scheduler " )
2023-06-03 15:17:04 -06:00
trained_betas = enforce_zero_terminal_snr ( temp_scheduler . betas ) . numpy ( ) . tolist ( )
2023-09-06 04:38:52 -06:00
inference_scheduler = DDIMScheduler . from_pretrained ( model_root_folder , subfolder = " scheduler " , trained_betas = trained_betas )
2023-06-03 19:41:56 -06:00
noise_scheduler = DDPMScheduler . from_pretrained ( model_root_folder , subfolder = " scheduler " , trained_betas = trained_betas )
2023-06-03 15:17:04 -06:00
else :
2023-09-06 04:38:52 -06:00
inference_scheduler = DDIMScheduler . from_pretrained ( model_root_folder , subfolder = " scheduler " )
2023-06-03 15:17:04 -06:00
noise_scheduler = DDPMScheduler . from_pretrained ( model_root_folder , subfolder = " scheduler " )
2023-05-26 19:54:02 -06:00
2023-01-12 14:32:37 -07:00
tokenizer = CLIPTokenizer . from_pretrained ( model_root_folder , subfolder = " tokenizer " , use_fast = False )
2023-02-20 14:00:25 -07:00
2023-01-12 14:32:37 -07:00
except Exception as e :
traceback . print_exc ( )
2023-01-25 03:06:52 -07:00
logging . error ( " * Failed to load checkpoint * " )
2023-02-08 06:15:54 -07:00
raise
2022-12-17 20:32:48 -07:00
2023-01-01 08:45:18 -07:00
if args . gradient_checkpointing :
unet . enable_gradient_checkpointing ( )
text_encoder . gradient_checkpointing_enable ( )
2023-02-25 13:05:22 -07:00
2023-01-20 07:42:24 -07:00
if not args . disable_xformers :
if ( args . amp and is_sd1attn ) or ( not is_sd1attn ) :
try :
unet . enable_xformers_memory_efficient_attention ( )
logging . info ( " Enabled xformers " )
except Exception as ex :
2023-01-21 23:15:50 -07:00
logging . warning ( " failed to load xformers, using attention slicing instead " )
unet . set_attention_slice ( " auto " )
2023-01-20 07:42:24 -07:00
pass
2023-04-16 16:48:44 -06:00
elif ( not args . amp and is_sd1attn ) :
logging . info ( " AMP is disabled but model is SD1.X, using attention slicing instead of xformers " )
unet . set_attention_slice ( " auto " )
2023-01-02 15:33:31 -07:00
else :
2023-04-16 16:48:44 -06:00
logging . info ( " xformers disabled via arg, using attention slicing instead " )
2023-01-21 23:15:50 -07:00
unet . set_attention_slice ( " auto " )
2022-12-17 20:32:48 -07:00
2023-01-16 13:48:06 -07:00
vae = vae . to ( device , dtype = torch . float16 if args . amp else torch . float32 )
2023-01-08 16:52:39 -07:00
unet = unet . to ( device , dtype = torch . float32 )
2023-01-16 13:48:06 -07:00
if args . disable_textenc_training and args . amp :
text_encoder = text_encoder . to ( device , dtype = torch . float16 )
else :
text_encoder = text_encoder . to ( device , dtype = torch . float32 )
2022-12-17 20:32:48 -07:00
2023-09-06 04:38:52 -06:00
2023-09-06 13:37:10 -06:00
2023-09-07 10:53:20 -06:00
if use_ema_dacay_training :
2023-09-12 19:37:27 -06:00
if not ema_model_loaded_from_file :
2023-09-07 10:53:20 -06:00
logging . info ( f " EMA decay enabled, creating EMA model. " )
2023-09-06 13:37:10 -06:00
2023-09-07 10:53:20 -06:00
with torch . no_grad ( ) :
2023-09-12 19:37:27 -06:00
if args . ema_device == device :
2023-09-07 10:53:20 -06:00
unet_ema = deepcopy ( unet )
text_encoder_ema = deepcopy ( text_encoder )
else :
unet_ema_first = deepcopy ( unet )
text_encoder_ema_first = deepcopy ( text_encoder )
unet_ema = unet_ema_first . to ( ema_device , dtype = unet . dtype )
text_encoder_ema = text_encoder_ema_first . to ( ema_device , dtype = text_encoder . dtype )
del unet_ema_first
del text_encoder_ema_first
else :
# Make sure correct types are used for models
unet_ema = unet_ema . to ( ema_device , dtype = unet . dtype )
text_encoder_ema = text_encoder_ema . to ( ema_device , dtype = text_encoder . dtype )
2023-09-10 13:37:47 -06:00
else :
unet_ema = None
text_encoder_ema = None
2023-09-06 04:38:52 -06:00
2023-06-03 09:26:53 -06:00
try :
2023-06-27 18:53:48 -06:00
#unet = torch.compile(unet)
#text_encoder = torch.compile(text_encoder)
#vae = torch.compile(vae)
torch . set_float32_matmul_precision ( ' high ' )
torch . backends . cudnn . allow_tf32 = True
#logging.info("Successfully compiled models")
2023-06-03 09:26:53 -06:00
except Exception as ex :
logging . warning ( f " Failed to compile model, continuing anyway, ex: { ex } " )
pass
2023-02-26 17:11:42 -07:00
optimizer_config = None
2023-04-29 20:56:10 -06:00
optimizer_config_path = args . optimizer_config if args . optimizer_config else " optimizer.json "
2023-02-26 17:11:42 -07:00
if os . path . exists ( os . path . join ( os . curdir , optimizer_config_path ) ) :
with open ( os . path . join ( os . curdir , optimizer_config_path ) , " r " ) as f :
optimizer_config = json . load ( f )
2023-03-25 18:09:06 -06:00
if args . wandb :
wandb . tensorboard . patch ( root_logdir = log_folder , pytorch = False , tensorboard_x = False , save = False )
wandb_run = wandb . init (
project = args . project_name ,
config = { " main_cfg " : vars ( args ) , " optimizer_cfg " : optimizer_config } ,
name = args . run_name ,
)
try :
if webbrowser . get ( ) :
webbrowser . open ( wandb_run . url , new = 2 )
except Exception :
pass
2023-02-25 13:05:22 -07:00
2023-02-07 09:32:54 -07:00
log_writer = SummaryWriter ( log_dir = log_folder ,
2023-03-25 18:09:06 -06:00
flush_secs = 20 ,
comment = args . run_name if args . run_name is not None else log_time ,
2023-02-25 13:05:22 -07:00
)
2023-02-07 09:32:54 -07:00
2023-04-29 07:33:01 -06:00
image_train_items = resolve_image_train_items ( args )
2023-01-29 19:11:34 -07:00
2023-02-08 11:34:49 -07:00
validator = None
2023-02-08 11:04:12 -07:00
if args . validation_config is not None :
validator = EveryDreamValidator ( args . validation_config ,
default_batch_size = args . batch_size ,
resolution = args . resolution ,
log_writer = log_writer ,
)
# the validation dataset may need to steal some items from image_train_items
image_train_items = validator . prepare_validation_splits ( image_train_items , tokenizer = tokenizer )
2023-02-07 09:32:54 -07:00
2023-04-19 03:06:02 -06:00
report_image_train_item_problems ( log_folder , image_train_items , batch_size = args . batch_size )
2023-01-29 19:11:34 -07:00
data_loader = DataLoaderMultiAspect (
image_train_items = image_train_items ,
seed = seed ,
batch_size = args . batch_size ,
2023-06-07 10:07:37 -06:00
grad_accum = args . grad_accum
2023-01-29 19:11:34 -07:00
)
2023-01-15 20:07:37 -07:00
2022-12-17 20:32:48 -07:00
train_batch = EveryDreamBatch (
2023-01-29 19:11:34 -07:00
data_loader = data_loader ,
2022-12-17 20:32:48 -07:00
debug_level = 1 ,
2022-12-17 22:43:25 -07:00
conditional_dropout = args . cond_dropout ,
2022-12-17 20:32:48 -07:00
tokenizer = tokenizer ,
2022-12-27 12:25:32 -07:00
seed = seed ,
2023-01-06 17:12:52 -07:00
shuffle_tags = args . shuffle_tags ,
2023-01-14 06:00:30 -07:00
rated_dataset = args . rated_dataset ,
rated_dataset_dropout_target = ( 1.0 - ( args . rated_dataset_target_dropout_percent / 100.0 ) )
2022-12-17 20:32:48 -07:00
)
2023-04-13 21:31:22 -06:00
2022-12-18 11:03:44 -07:00
torch . cuda . benchmark = False
2022-12-17 20:32:48 -07:00
2022-12-17 22:43:25 -07:00
epoch_len = math . ceil ( len ( train_batch ) / args . batch_size )
2023-09-06 13:37:10 -06:00
if use_ema_dacay_training :
2023-09-12 19:37:27 -06:00
args . ema_update_interval = args . ema_update_interval * args . grad_accum
if args . ema_strength_target != None :
2023-09-06 13:37:10 -06:00
total_number_of_steps : float = epoch_len * args . max_epochs
2023-09-12 19:37:27 -06:00
total_number_of_ema_update : float = total_number_of_steps / args . ema_update_interval
args . ema_decay_rate = args . ema_strength_target * * ( 1 / total_number_of_ema_update )
2023-09-06 13:37:10 -06:00
2023-09-12 19:37:27 -06:00
logging . info ( f " ema_strength_target is { args . ema_strength_target } , calculated ema_decay_rate will be: { args . ema_decay_rate } . " )
2023-09-06 13:37:10 -06:00
logging . info (
2023-09-12 19:37:27 -06:00
f " EMA decay enabled, with ema_decay_rate { args . ema_decay_rate } , ema_update_interval: { args . ema_update_interval } , ema_device: { args . ema_device } . " )
2023-09-06 13:37:10 -06:00
2023-05-14 03:49:11 -06:00
ed_optimizer = EveryDreamOptimizer ( args ,
optimizer_config ,
text_encoder ,
unet ,
epoch_len )
2022-12-17 22:43:25 -07:00
2022-12-17 20:32:48 -07:00
log_args ( log_writer , args )
2023-02-18 07:51:50 -07:00
sample_generator = SampleGenerator ( log_folder = log_folder , log_writer = log_writer ,
default_resolution = args . resolution , default_seed = args . seed ,
config_file_path = args . sample_prompts ,
2023-02-18 07:54:41 -07:00
batch_size = max ( 1 , args . batch_size / / 2 ) ,
2023-03-02 05:03:50 -07:00
default_sample_steps = args . sample_steps ,
2023-04-14 11:12:06 -06:00
use_xformers = is_xformers_available ( ) and not args . disable_xformers ,
2023-08-15 12:47:34 -06:00
use_penultimate_clip_layer = ( args . clip_skip > = 2 ) ,
2023-09-06 13:37:10 -06:00
guidance_rescale = 0.7 if args . enable_zero_terminal_snr else 0
2023-04-14 11:12:06 -06:00
)
2022-12-17 20:32:48 -07:00
"""
Train the model
"""
print ( f " { Fore . LIGHTGREEN_EX } ** Welcome to EveryDream trainer 2.0!** { Style . RESET_ALL } " )
2023-01-22 19:43:03 -07:00
print ( f " (C) 2022-2023 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html " )
2022-12-17 20:32:48 -07:00
print ( )
print ( " ** Trainer Starting ** " )
global interrupted
interrupted = False
def sigterm_handler ( signum , frame ) :
"""
handles sigterm
"""
2023-01-30 03:32:39 -07:00
is_main_thread = ( torch . utils . data . get_worker_info ( ) == None )
if is_main_thread :
2023-01-30 03:26:11 -07:00
global interrupted
if not interrupted :
interrupted = True
global global_step
#TODO: save model on ctrl-c
interrupted_checkpoint_path = os . path . join ( f " { log_folder } /ckpts/interrupted-gs { global_step } " )
print ( )
logging . error ( f " { Fore . LIGHTRED_EX } ************************************************************************ { Style . RESET_ALL } " )
logging . error ( f " { Fore . LIGHTRED_EX } CTRL-C received, attempting to save model to { interrupted_checkpoint_path } { Style . RESET_ALL } " )
logging . error ( f " { Fore . LIGHTRED_EX } ************************************************************************ { Style . RESET_ALL } " )
time . sleep ( 2 ) # give opportunity to ctrl-C again to cancel save
2023-09-10 13:37:47 -06:00
save_model ( interrupted_checkpoint_path , global_step = global_step , ed_state = make_current_ed_state ( ) ,
save_ckpt_dir = args . save_ckpt_dir , yaml_name = yaml , save_full_precision = args . save_full_precision ,
save_optimizer_flag = args . save_optimizer , save_ckpt = not args . no_save_ckpt )
2023-01-30 06:19:18 -07:00
exit ( _SIGTERM_EXIT_CODE )
else :
# non-main threads (i.e. dataloader workers) should exit cleanly
exit ( 0 )
2022-12-17 20:32:48 -07:00
signal . signal ( signal . SIGINT , sigterm_handler )
2023-04-13 21:31:22 -06:00
2022-12-17 20:32:48 -07:00
if not os . path . exists ( f " { log_folder } /samples/ " ) :
os . makedirs ( f " { log_folder } /samples/ " )
2023-02-08 05:46:58 -07:00
if gpu is not None :
gpu_used_mem , gpu_total_mem = gpu . get_gpu_memory ( )
logging . info ( f " Pretraining GPU Memory: { gpu_used_mem } / { gpu_total_mem } MB " )
2022-12-17 20:32:48 -07:00
logging . info ( f " saving ckpts every { args . ckpt_every_n_minutes } minutes " )
2022-12-18 18:52:58 -07:00
logging . info ( f " saving ckpts every { args . save_every_n_epochs } epochs " )
2022-12-17 20:32:48 -07:00
2023-02-07 09:32:54 -07:00
train_dataloader = build_torch_dataloader ( train_batch , batch_size = args . batch_size )
2022-12-17 20:32:48 -07:00
2023-01-15 20:07:37 -07:00
unet . train ( ) if not args . disable_unet_training else unet . eval ( )
2023-04-13 21:31:22 -06:00
text_encoder . train ( ) if not args . disable_textenc_training else text_encoder . eval ( )
2022-12-17 20:32:48 -07:00
logging . info ( f " unet device: { unet . device } , precision: { unet . dtype } , training: { unet . training } " )
logging . info ( f " text_encoder device: { text_encoder . device } , precision: { text_encoder . dtype } , training: { text_encoder . training } " )
logging . info ( f " vae device: { vae . device } , precision: { vae . dtype } , training: { vae . training } " )
2023-01-06 14:33:33 -07:00
logging . info ( f " scheduler: { noise_scheduler . __class__ } " )
2022-12-17 20:32:48 -07:00
logging . info ( f " { Fore . GREEN } Project name: { Style . RESET_ALL } { Fore . LIGHTGREEN_EX } { args . project_name } { Style . RESET_ALL } " )
2023-04-13 21:31:22 -06:00
logging . info ( f " { Fore . GREEN } grad_accum: { Style . RESET_ALL } { Fore . LIGHTGREEN_EX } { args . grad_accum } { Style . RESET_ALL } " ) ,
2022-12-17 20:32:48 -07:00
logging . info ( f " { Fore . GREEN } batch_size: { Style . RESET_ALL } { Fore . LIGHTGREEN_EX } { args . batch_size } { Style . RESET_ALL } " )
logging . info ( f " { Fore . GREEN } epoch_len: { Fore . LIGHTGREEN_EX } { epoch_len } { Style . RESET_ALL } " )
2023-03-02 10:29:28 -07:00
epoch_pbar = tqdm ( range ( args . max_epochs ) , position = 0 , leave = True , dynamic_ncols = True )
2022-12-17 20:32:48 -07:00
epoch_pbar . set_description ( f " { Fore . LIGHTCYAN_EX } Epochs { Style . RESET_ALL } " )
epoch_times = [ ]
global global_step
global_step = 0
training_start_time = time . time ( )
last_epoch_saved_time = training_start_time
append_epoch_log ( global_step = global_step , epoch_pbar = epoch_pbar , gpu = gpu , log_writer = log_writer )
2023-01-08 16:52:39 -07:00
loss_log_step = [ ]
2023-01-20 07:42:24 -07:00
assert len ( train_batch ) > 0 , " train_batch is empty, check that your data_root is correct "
2023-02-07 09:32:54 -07:00
# actual prediction function - shared between train and validate
2023-09-06 04:38:52 -06:00
def get_model_prediction_and_target ( image , tokens , zero_frequency_noise_ratio = 0.0 , return_loss = False ) :
2023-02-07 09:32:54 -07:00
with torch . no_grad ( ) :
with autocast ( enabled = args . amp ) :
pixel_values = image . to ( memory_format = torch . contiguous_format ) . to ( unet . device )
latents = vae . encode ( pixel_values , return_dict = False )
del pixel_values
latents = latents [ 0 ] . sample ( ) * 0.18215
2023-09-06 13:37:10 -06:00
if zero_frequency_noise_ratio != None :
if zero_frequency_noise_ratio < 0 :
zero_frequency_noise_ratio = 0
2023-02-18 09:23:02 -07:00
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
2023-02-15 16:53:08 -07:00
zero_frequency_noise = zero_frequency_noise_ratio * torch . randn ( latents . shape [ 0 ] , latents . shape [ 1 ] , 1 , 1 , device = latents . device )
2023-04-13 21:31:22 -06:00
noise = torch . randn_like ( latents ) + zero_frequency_noise
2023-02-07 09:32:54 -07:00
bsz = latents . shape [ 0 ]
timesteps = torch . randint ( 0 , noise_scheduler . config . num_train_timesteps , ( bsz , ) , device = latents . device )
timesteps = timesteps . long ( )
cuda_caption = tokens . to ( text_encoder . device )
encoder_hidden_states = text_encoder ( cuda_caption , output_hidden_states = True )
if args . clip_skip > 0 :
encoder_hidden_states = text_encoder . text_model . final_layer_norm (
encoder_hidden_states . hidden_states [ - args . clip_skip ] )
else :
encoder_hidden_states = encoder_hidden_states . last_hidden_state
noisy_latents = noise_scheduler . add_noise ( latents , noise , timesteps )
if noise_scheduler . config . prediction_type == " epsilon " :
target = noise
elif noise_scheduler . config . prediction_type in [ " v_prediction " , " v-prediction " ] :
target = noise_scheduler . get_velocity ( latents , noise , timesteps )
else :
raise ValueError ( f " Unknown prediction type { noise_scheduler . config . prediction_type } " )
del noise , latents , cuda_caption
with autocast ( enabled = args . amp ) :
2023-04-30 21:45:13 -06:00
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
2023-02-07 09:32:54 -07:00
model_pred = unet ( noisy_latents , timesteps , encoder_hidden_states ) . sample
2023-09-06 04:38:52 -06:00
if return_loss :
if args . min_snr_gamma is None :
loss = F . mse_loss ( model_pred . float ( ) , target . float ( ) , reduction = " mean " )
else :
snr = compute_snr ( timesteps , noise_scheduler )
mse_loss_weights = (
torch . stack (
[ snr , args . min_snr_gamma * torch . ones_like ( timesteps ) ] , dim = 1
) . min ( dim = 1 ) [ 0 ]
/ snr
)
mse_loss_weights [ snr == 0 ] = 1.0
loss = F . mse_loss ( model_pred . float ( ) , target . float ( ) , reduction = " none " )
loss = loss . mean ( dim = list ( range ( 1 , len ( loss . shape ) ) ) ) * mse_loss_weights
loss = loss . mean ( )
return model_pred , target , loss
else :
return model_pred , target
2023-02-07 09:32:54 -07:00
2023-03-02 05:03:50 -07:00
def generate_samples ( global_step : int , batch ) :
2023-09-06 13:37:10 -06:00
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
2023-03-02 05:03:50 -07:00
with isolate_rng ( ) :
2023-03-02 10:29:28 -07:00
prev_sample_steps = sample_generator . sample_steps
2023-03-02 05:03:50 -07:00
sample_generator . reload_config ( )
2023-03-02 10:29:28 -07:00
if prev_sample_steps != sample_generator . sample_steps :
next_sample_step = math . ceil ( ( global_step + 1 ) / sample_generator . sample_steps ) * sample_generator . sample_steps
print ( f " * SampleGenerator config changed, now generating images samples every " +
f " { sample_generator . sample_steps } training steps (next= { next_sample_step } ) " )
2023-03-02 05:03:50 -07:00
sample_generator . update_random_captions ( batch [ " captions " ] )
2023-09-06 13:37:10 -06:00
models_info = [ ]
2023-09-18 14:13:22 -06:00
if ( args . ema_decay_rate is None ) or args . ema_sample_nonema_model :
2023-09-06 13:37:10 -06:00
models_info . append ( { " is_ema " : False , " swap_required " : False } )
2023-09-12 19:37:27 -06:00
if ( args . ema_decay_rate is not None ) and args . ema_sample_ema_model :
2023-09-07 10:53:20 -06:00
models_info . append ( { " is_ema " : True , " swap_required " : ema_device != device } )
2023-09-06 13:37:10 -06:00
for model_info in models_info :
extra_info : str = " "
if model_info [ " is_ema " ] :
current_unet , current_text_encoder = unet_ema , text_encoder_ema
2023-09-10 15:13:26 -06:00
extra_info = " _ema "
2023-09-06 13:37:10 -06:00
else :
current_unet , current_text_encoder = unet , text_encoder
2023-09-10 15:13:26 -06:00
torch . cuda . empty_cache ( )
2023-09-06 13:37:10 -06:00
if model_info [ " swap_required " ] :
with torch . no_grad ( ) :
2023-09-07 10:53:20 -06:00
unet_unloaded = unet . to ( ema_device )
2023-09-06 13:37:10 -06:00
del unet
2023-09-07 10:53:20 -06:00
text_encoder_unloaded = text_encoder . to ( ema_device )
2023-09-06 13:37:10 -06:00
del text_encoder
current_unet = unet_ema . to ( device )
del unet_ema
current_text_encoder = text_encoder_ema . to ( device )
del text_encoder_ema
gc . collect ( )
2023-09-10 13:42:01 -06:00
torch . cuda . empty_cache ( )
2023-09-06 13:37:10 -06:00
inference_pipe = sample_generator . create_inference_pipe ( unet = current_unet ,
text_encoder = current_text_encoder ,
tokenizer = tokenizer ,
vae = vae ,
diffusers_scheduler_config = inference_scheduler . config
) . to ( device )
sample_generator . generate_samples ( inference_pipe , global_step , extra_info = extra_info )
# Cleanup
del inference_pipe
if model_info [ " swap_required " ] :
with torch . no_grad ( ) :
unet = unet_unloaded . to ( device )
del unet_unloaded
text_encoder = text_encoder_unloaded . to ( device )
del text_encoder_unloaded
2023-09-07 10:53:20 -06:00
unet_ema = current_unet . to ( ema_device )
2023-09-06 13:37:10 -06:00
del current_unet
2023-09-07 10:53:20 -06:00
text_encoder_ema = current_text_encoder . to ( ema_device )
2023-09-06 13:37:10 -06:00
del current_text_encoder
gc . collect ( )
2023-09-10 13:42:01 -06:00
torch . cuda . empty_cache ( )
2023-03-02 05:03:50 -07:00
2023-06-17 11:18:04 -06:00
def make_save_path ( epoch , global_step , prepend = " " ) :
2023-09-10 13:37:47 -06:00
basename = f " { prepend } { args . project_name } "
if epoch is not None :
basename + = f " -ep { epoch : 02 } "
if global_step is not None :
basename + = f " -gs { global_step : 05 } "
return os . path . join ( log_folder , " ckpts " , basename )
2023-06-17 11:18:04 -06:00
2023-09-06 13:37:10 -06:00
2023-02-17 11:35:25 -07:00
# Pre-train validation to establish a starting point on the loss graph
if validator :
2023-03-09 12:42:17 -07:00
validator . do_validation ( global_step = 0 ,
get_model_prediction_and_target_callable = get_model_prediction_and_target )
2023-02-07 09:32:54 -07:00
2023-03-02 05:03:50 -07:00
# the sample generator might be configured to generate samples before step 0
if sample_generator . generate_pretrain_samples :
_ , batch = next ( enumerate ( train_dataloader ) )
generate_samples ( global_step = 0 , batch = batch )
2023-06-29 18:44:14 -06:00
from plugins . plugins import load_plugin
2023-06-29 20:00:16 -06:00
if args . plugins is not None :
plugins = [ load_plugin ( name ) for name in args . plugins ]
else :
2023-07-05 14:02:18 -06:00
logging . info ( " No plugins specified " )
2023-06-29 20:00:16 -06:00
plugins = [ ]
2023-09-06 04:38:52 -06:00
2023-07-04 15:29:22 -06:00
from plugins . plugins import PluginRunner
plugin_runner = PluginRunner ( plugins = plugins )
2023-06-27 18:53:48 -06:00
2023-09-10 13:37:47 -06:00
def make_current_ed_state ( ) - > EveryDreamTrainingState :
return EveryDreamTrainingState ( optimizer = ed_optimizer ,
train_batch = train_batch ,
unet = unet ,
text_encoder = text_encoder ,
tokenizer = tokenizer ,
scheduler = noise_scheduler ,
vae = vae ,
unet_ema = unet_ema ,
text_encoder_ema = text_encoder_ema )
epoch = None
2023-01-08 16:52:39 -07:00
try :
2023-02-26 17:11:42 -07:00
write_batch_schedule ( args , log_folder , train_batch , epoch = 0 )
2023-09-10 13:37:47 -06:00
plugin_runner . run_on_training_start ( log_folder = log_folder , project_name = args . project_name )
2023-04-13 21:31:22 -06:00
2022-12-17 20:32:48 -07:00
for epoch in range ( args . max_epochs ) :
2023-09-06 04:38:52 -06:00
if args . load_settings_every_epoch :
2023-09-06 13:37:10 -06:00
load_train_json_from_file ( args )
2023-09-06 04:38:52 -06:00
2023-09-10 13:37:47 -06:00
epoch_len = math . ceil ( len ( train_batch ) / args . batch_size )
2023-09-06 04:38:52 -06:00
2023-09-10 13:37:47 -06:00
plugin_runner . run_on_epoch_start (
epoch = epoch ,
global_step = global_step ,
epoch_length = epoch_len ,
project_name = args . project_name ,
log_folder = log_folder ,
data_root = args . data_root
)
2023-06-27 18:53:48 -06:00
2023-01-08 16:52:39 -07:00
loss_epoch = [ ]
2022-12-17 20:32:48 -07:00
epoch_start_time = time . time ( )
2023-01-03 13:32:58 -07:00
images_per_sec_log_step = [ ]
2022-12-17 20:32:48 -07:00
2023-03-02 10:29:28 -07:00
steps_pbar = tqdm ( range ( epoch_len ) , position = 1 , leave = False , dynamic_ncols = True )
2023-01-14 06:00:30 -07:00
steps_pbar . set_description ( f " { Fore . LIGHTCYAN_EX } Steps { Style . RESET_ALL } " )
2023-03-09 12:42:17 -07:00
validation_steps = (
[ ] if validator is None
2023-03-10 15:13:53 -07:00
else validator . get_validation_step_indices ( epoch , len ( train_dataloader ) )
2023-03-09 12:42:17 -07:00
)
2022-12-27 12:25:32 -07:00
for step , batch in enumerate ( train_dataloader ) :
2023-09-10 13:37:47 -06:00
2022-12-17 20:32:48 -07:00
step_start_time = time . time ( )
2023-09-18 14:13:22 -06:00
2023-07-04 15:29:22 -06:00
plugin_runner . run_on_step_start ( epoch = epoch ,
2023-09-10 13:37:47 -06:00
local_step = step ,
2023-07-04 15:29:22 -06:00
global_step = global_step ,
project_name = args . project_name ,
log_folder = log_folder ,
2023-09-10 13:37:47 -06:00
batch = batch ,
ed_state = make_current_ed_state ( ) )
2022-12-17 20:32:48 -07:00
2023-09-06 04:38:52 -06:00
model_pred , target , loss = get_model_prediction_and_target ( batch [ " image " ] , batch [ " tokens " ] , args . zero_frequency_noise_ratio , return_loss = True )
2023-01-06 14:33:33 -07:00
2023-01-03 13:17:24 -07:00
del target , model_pred
2023-02-13 10:18:33 -07:00
if batch [ " runt_size " ] > 0 :
2023-06-27 18:53:48 -06:00
loss_scale = ( batch [ " runt_size " ] / args . batch_size ) * * 1.5 # further discount runts by **1.5
2023-02-13 10:18:33 -07:00
loss = loss * loss_scale
2023-04-30 21:45:13 -06:00
ed_optimizer . step ( loss , step , global_step )
2022-12-17 20:32:48 -07:00
2023-09-06 04:38:52 -06:00
if args . ema_decay_rate != None :
2023-09-12 19:37:27 -06:00
if ( ( global_step + 1 ) % args . ema_update_interval ) == 0 :
2023-09-07 10:53:20 -06:00
# debug_start_time = time.time() # Measure time
2023-09-06 13:37:10 -06:00
if args . disable_unet_training != True :
2023-09-07 10:53:20 -06:00
update_ema ( unet , unet_ema , args . ema_decay_rate , default_device = device , ema_device = ema_device )
2023-09-06 13:37:10 -06:00
if args . disable_textenc_training != True :
2023-09-07 10:53:20 -06:00
update_ema ( text_encoder , text_encoder_ema , args . ema_decay_rate , default_device = device , ema_device = ema_device )
2023-09-06 13:37:10 -06:00
2023-09-07 10:53:20 -06:00
# debug_end_time = time.time() # Measure time
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
2023-09-06 04:38:52 -06:00
2023-01-08 16:52:39 -07:00
loss_step = loss . detach ( ) . item ( )
2023-02-13 10:18:33 -07:00
steps_pbar . set_postfix ( { " loss/step " : loss_step } , { " gs " : global_step } )
2022-12-17 20:32:48 -07:00
steps_pbar . update ( 1 )
images_per_sec = args . batch_size / ( time . time ( ) - step_start_time )
2023-01-03 13:32:58 -07:00
images_per_sec_log_step . append ( images_per_sec )
2022-12-17 20:32:48 -07:00
2023-01-08 16:52:39 -07:00
loss_log_step . append ( loss_step )
loss_epoch . append ( loss_step )
2022-12-17 20:32:48 -07:00
if ( global_step + 1 ) % args . log_step == 0 :
2023-06-27 18:53:48 -06:00
loss_step = sum ( loss_log_step ) / len ( loss_log_step )
2023-04-29 20:56:10 -06:00
lr_unet = ed_optimizer . get_unet_lr ( )
lr_textenc = ed_optimizer . get_textenc_lr ( )
2023-01-08 16:52:39 -07:00
loss_log_step = [ ]
2023-09-06 04:38:52 -06:00
2023-06-04 09:13:37 -06:00
log_writer . add_scalar ( tag = " hyperparameter/lr unet " , scalar_value = lr_unet , global_step = global_step )
log_writer . add_scalar ( tag = " hyperparameter/lr text encoder " , scalar_value = lr_textenc , global_step = global_step )
2023-06-27 18:53:48 -06:00
log_writer . add_scalar ( tag = " loss/log_step " , scalar_value = loss_step , global_step = global_step )
2023-04-29 20:56:10 -06:00
2023-01-03 13:32:58 -07:00
sum_img = sum ( images_per_sec_log_step )
avg = sum_img / len ( images_per_sec_log_step )
images_per_sec_log_step = [ ]
2023-01-08 16:52:39 -07:00
if args . amp :
2023-06-04 09:13:37 -06:00
log_writer . add_scalar ( tag = " hyperparameter/grad scale " , scalar_value = ed_optimizer . get_scale ( ) , global_step = global_step )
2022-12-17 20:32:48 -07:00
log_writer . add_scalar ( tag = " performance/images per second " , scalar_value = avg , global_step = global_step )
2023-04-29 20:56:10 -06:00
2023-06-27 18:53:48 -06:00
logs = { " loss/log_step " : loss_step , " lr_unet " : lr_unet , " lr_te " : lr_textenc , " img/s " : images_per_sec }
2022-12-17 20:32:48 -07:00
append_epoch_log ( global_step = global_step , epoch_pbar = epoch_pbar , gpu = gpu , log_writer = log_writer , * * logs )
2023-01-03 12:27:26 -07:00
torch . cuda . empty_cache ( )
2022-12-17 20:32:48 -07:00
2023-03-09 12:42:17 -07:00
if validator and step in validation_steps :
validator . do_validation ( global_step , get_model_prediction_and_target )
2023-03-02 05:03:50 -07:00
if ( global_step + 1 ) % sample_generator . sample_steps == 0 :
generate_samples ( global_step = global_step , batch = batch )
2022-12-17 20:32:48 -07:00
2023-01-02 15:33:31 -07:00
min_since_last_ckpt = ( time . time ( ) - last_epoch_saved_time ) / 60
2022-12-17 20:32:48 -07:00
2023-09-10 13:37:47 -06:00
needs_save = False
2022-12-17 20:32:48 -07:00
if args . ckpt_every_n_minutes is not None and ( min_since_last_ckpt > args . ckpt_every_n_minutes ) :
last_epoch_saved_time = time . time ( )
2023-01-01 08:45:18 -07:00
logging . info ( f " Saving model, { args . ckpt_every_n_minutes } mins at step { global_step } " )
2023-09-10 13:37:47 -06:00
needs_save = True
2023-02-18 09:23:02 -07:00
if epoch > 0 and epoch % args . save_every_n_epochs == 0 and step == 0 and epoch < args . max_epochs - 1 and epoch > = args . save_ckpts_from_n_epochs :
2023-01-01 08:45:18 -07:00
logging . info ( f " Saving model, { args . save_every_n_epochs } epochs at step { global_step } " )
2023-09-10 13:37:47 -06:00
needs_save = True
if needs_save :
2023-06-17 11:18:04 -06:00
save_path = make_save_path ( epoch , global_step )
2023-09-10 13:37:47 -06:00
save_model ( save_path , global_step = global_step , ed_state = make_current_ed_state ( ) ,
save_ckpt_dir = None , yaml_name = None ,
save_full_precision = args . save_full_precision ,
save_optimizer_flag = args . save_optimizer , save_ckpt = False )
2022-12-17 20:32:48 -07:00
2023-07-04 15:29:22 -06:00
plugin_runner . run_on_step_end ( epoch = epoch ,
global_step = global_step ,
2023-09-10 13:37:47 -06:00
local_step = step ,
2023-07-04 15:29:22 -06:00
project_name = args . project_name ,
log_folder = log_folder ,
data_root = args . data_root ,
2023-09-10 13:37:47 -06:00
batch = batch ,
ed_state = make_current_ed_state ( ) )
2023-07-04 15:29:22 -06:00
2023-01-08 16:52:39 -07:00
del batch
global_step + = 1
2022-12-17 20:32:48 -07:00
# end of step
2023-01-14 06:00:30 -07:00
steps_pbar . close ( )
2023-01-16 13:48:06 -07:00
2023-01-01 08:45:18 -07:00
elapsed_epoch_time = ( time . time ( ) - epoch_start_time ) / 60
2022-12-17 20:32:48 -07:00
epoch_times . append ( dict ( epoch = epoch , time = elapsed_epoch_time ) )
log_writer . add_scalar ( " performance/minutes per epoch " , elapsed_epoch_time , global_step )
epoch_pbar . update ( 1 )
2023-01-01 08:45:18 -07:00
if epoch < args . max_epochs - 1 :
2023-01-14 06:00:30 -07:00
train_batch . shuffle ( epoch_n = epoch , max_epochs = args . max_epochs )
2023-01-29 19:28:07 -07:00
write_batch_schedule ( args , log_folder , train_batch , epoch + 1 )
2023-01-16 13:48:06 -07:00
2023-09-10 13:37:47 -06:00
if len ( loss_epoch ) > 0 :
loss_epoch = sum ( loss_epoch ) / len ( loss_epoch )
log_writer . add_scalar ( tag = " loss/epoch " , scalar_value = loss_epoch , global_step = global_step )
2023-02-07 09:32:54 -07:00
2023-07-04 15:29:22 -06:00
plugin_runner . run_on_epoch_end ( epoch = epoch ,
global_step = global_step ,
project_name = args . project_name ,
log_folder = log_folder ,
data_root = args . data_root )
2023-09-06 04:38:52 -06:00
gc . collect ( )
2022-12-19 00:07:14 -07:00
# end of epoch
2022-12-17 20:32:48 -07:00
# end of training
2023-06-17 11:18:04 -06:00
epoch = args . max_epochs
2023-09-10 13:37:47 -06:00
plugin_runner . run_on_training_end ( )
2023-06-17 11:18:04 -06:00
save_path = make_save_path ( epoch , global_step , prepend = ( " " if args . no_prepend_last else " last- " ) )
2023-09-10 13:37:47 -06:00
save_model ( save_path , global_step = global_step , ed_state = make_current_ed_state ( ) ,
save_ckpt_dir = args . save_ckpt_dir , yaml_name = yaml , save_full_precision = args . save_full_precision ,
save_optimizer_flag = args . save_optimizer , save_ckpt = not args . no_save_ckpt )
2022-12-17 20:32:48 -07:00
total_elapsed_time = time . time ( ) - training_start_time
logging . info ( f " { Fore . CYAN } Training complete { Style . RESET_ALL } " )
2023-01-03 12:27:26 -07:00
logging . info ( f " Total training time took { total_elapsed_time / 60 : .2f } minutes, total steps: { global_step } " )
2022-12-19 00:07:14 -07:00
logging . info ( f " Average epoch time: { np . mean ( [ t [ ' time ' ] for t in epoch_times ] ) : .2f } minutes " )
2022-12-17 20:32:48 -07:00
except Exception as ex :
logging . error ( f " { Fore . LIGHTYELLOW_EX } Something went wrong, attempting to save model { Style . RESET_ALL } " )
2023-06-17 11:18:04 -06:00
save_path = make_save_path ( epoch , global_step , prepend = " errored- " )
2023-09-10 13:37:47 -06:00
save_model ( save_path , global_step = global_step , ed_state = make_current_ed_state ( ) ,
save_ckpt_dir = args . save_ckpt_dir , yaml_name = yaml , save_full_precision = args . save_full_precision ,
save_optimizer_flag = args . save_optimizer , save_ckpt = not args . no_save_ckpt )
2023-07-05 14:02:18 -06:00
logging . info ( f " { Fore . LIGHTYELLOW_EX } Model saved, re-raising exception and exiting. Exception was: { Style . RESET_ALL } { Fore . LIGHTRED_EX } { ex } { Style . RESET_ALL } " )
2022-12-17 20:32:48 -07:00
raise ex
2022-12-18 19:16:14 -07:00
logging . info ( f " { Fore . LIGHTWHITE_EX } *************************** { Style . RESET_ALL } " )
logging . info ( f " { Fore . LIGHTWHITE_EX } **** Finished training **** { Style . RESET_ALL } " )
logging . info ( f " { Fore . LIGHTWHITE_EX } *************************** { Style . RESET_ALL } " )
2022-12-17 20:32:48 -07:00
2023-01-08 16:52:39 -07:00
2022-12-17 20:32:48 -07:00
if __name__ == " __main__ " :
2023-04-29 16:15:25 -06:00
check_git ( )
2023-03-25 18:09:06 -06:00
supported_resolutions = aspects . get_supported_resolutions ( )
2023-01-01 08:45:18 -07:00
argparser = argparse . ArgumentParser ( description = " EveryDream2 Training options " )
2023-01-03 12:27:26 -07:00
argparser . add_argument ( " --config " , type = str , required = False , default = None , help = " JSON config file to load options from " )
2023-02-05 23:40:59 -07:00
args , argv = argparser . parse_known_args ( )
2023-01-03 12:27:26 -07:00
2023-09-07 10:53:20 -06:00
load_train_json_from_file ( args , report_load = True )
2022-12-17 20:32:48 -07:00
2023-02-05 23:40:59 -07:00
argparser = argparse . ArgumentParser ( description = " EveryDream2 Training options " )
2023-04-30 07:28:55 -06:00
argparser . add_argument ( " --amp " , action = " store_true " , default = True , help = " deprecated, use --disable_amp if you wish to disable AMP " )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --batch_size " , type = int , default = 2 , help = " Batch size (def: 2) " )
argparser . add_argument ( " --ckpt_every_n_minutes " , type = int , default = None , help = " Save checkpoint every n minutes, def: 20 " )
argparser . add_argument ( " --clip_grad_norm " , type = float , default = None , help = " Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan? " )
argparser . add_argument ( " --clip_skip " , type = int , default = 0 , help = " Train using penultimate layer (def: 0) (2 is ' penultimate ' ) " , choices = [ 0 , 1 , 2 , 3 , 4 ] )
argparser . add_argument ( " --cond_dropout " , type = float , default = 0.04 , help = " Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04) " )
argparser . add_argument ( " --data_root " , type = str , default = " input " , help = " folder where your training images are " )
2023-03-10 19:35:47 -07:00
argparser . add_argument ( " --disable_amp " , action = " store_true " , default = False , help = " disables training of text encoder (def: False) " )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --disable_textenc_training " , action = " store_true " , default = False , help = " disables training of text encoder (def: False) " )
argparser . add_argument ( " --disable_unet_training " , action = " store_true " , default = False , help = " disables training of unet (def: False) NOT RECOMMENDED " )
argparser . add_argument ( " --disable_xformers " , action = " store_true " , default = False , help = " disable xformers, may reduce performance (def: False) " )
argparser . add_argument ( " --flip_p " , type = float , default = 0.0 , help = " probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces! " )
2023-09-12 19:37:27 -06:00
argparser . add_argument ( " --gpuid " , type = int , default = 0 , help = " id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids " )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --gradient_checkpointing " , action = " store_true " , default = False , help = " enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False) " )
argparser . add_argument ( " --grad_accum " , type = int , default = 1 , help = " Gradient accumulation factor (def: 1), (ex, 2) " )
argparser . add_argument ( " --logdir " , type = str , default = " logs " , help = " folder to save logs to (def: logs) " )
argparser . add_argument ( " --log_step " , type = int , default = 25 , help = " How often to log training stats, def: 25, recommend default! " )
argparser . add_argument ( " --lowvram " , action = " store_true " , default = False , help = " automatically overrides various args to support 12GB gpu " )
argparser . add_argument ( " --lr " , type = float , default = None , help = " Learning rate, if using scheduler is maximum LR at top of curve " )
argparser . add_argument ( " --lr_decay_steps " , type = int , default = 0 , help = " Steps to reach minimum LR, default: automatically set " )
argparser . add_argument ( " --lr_scheduler " , type = str , default = " constant " , help = " LR scheduler, (default: constant) " , choices = [ " constant " , " linear " , " cosine " , " polynomial " ] )
argparser . add_argument ( " --lr_warmup_steps " , type = int , default = None , help = " Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant " )
argparser . add_argument ( " --max_epochs " , type = int , default = 300 , help = " Maximum number of epochs to train for " )
2023-06-17 11:18:04 -06:00
argparser . add_argument ( " --no_prepend_last " , action = " store_true " , help = " Do not prepend ' last- ' to the final checkpoint filename " )
argparser . add_argument ( " --no_save_ckpt " , action = " store_true " , help = " Save only diffusers files, no .ckpts " )
2023-02-25 13:05:22 -07:00
argparser . add_argument ( " --optimizer_config " , default = " optimizer.json " , help = " Path to a JSON configuration file for the optimizer. Default is ' optimizer.json ' " )
2023-06-27 18:53:48 -06:00
argparser . add_argument ( ' --plugins ' , nargs = ' + ' , help = ' Names of plugins to use ' )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --project_name " , type = str , default = " myproj " , help = " Project name for logs and checkpoints, ex. ' tedbennett ' , ' superduperV1 ' " )
argparser . add_argument ( " --resolution " , type = int , default = 512 , help = " resolution to train " , choices = supported_resolutions )
argparser . add_argument ( " --resume_ckpt " , type = str , required = not ( ' resume_ckpt ' in args ) , default = " sd_v1-5_vae.ckpt " , help = " The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 " )
2023-02-25 13:05:22 -07:00
argparser . add_argument ( " --run_name " , type = str , required = False , default = None , help = " Run name for wandb (child of project name), and comment for tensorboard, (def: None) " )
2023-02-18 07:51:50 -07:00
argparser . add_argument ( " --sample_prompts " , type = str , default = " sample_prompts.txt " , help = " Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt) " )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --sample_steps " , type = int , default = 250 , help = " Number of steps between samples (def: 250) " )
argparser . add_argument ( " --save_ckpt_dir " , type = str , default = None , help = " folder to save checkpoints to (def: root training folder) " )
argparser . add_argument ( " --save_every_n_epochs " , type = int , default = None , help = " Save checkpoint every n epochs, def: 0 (disabled) " )
2023-02-13 10:18:33 -07:00
argparser . add_argument ( " --save_ckpts_from_n_epochs " , type = int , default = 0 , help = " Only saves checkpoints starting an N epochs, def: 0 (disabled) " )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --save_full_precision " , action = " store_true " , default = False , help = " save ckpts at full FP32 " )
argparser . add_argument ( " --save_optimizer " , action = " store_true " , default = False , help = " saves optimizer state with ckpt, useful for resuming training later " )
argparser . add_argument ( " --seed " , type = int , default = 555 , help = " seed used for samples and shuffling, use -1 for random " )
argparser . add_argument ( " --shuffle_tags " , action = " store_true " , default = False , help = " randomly shuffles CSV tags in captions, for booru datasets " )
2023-02-25 13:05:22 -07:00
argparser . add_argument ( " --useadam8bit " , action = " store_true " , default = False , help = " deprecated, use --optimizer_config and optimizer.json instead " )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --wandb " , action = " store_true " , default = False , help = " enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY " )
2023-02-08 14:05:19 -07:00
argparser . add_argument ( " --validation_config " , default = None , help = " Path to a JSON configuration file for the validator. Default is no validation. " )
2023-02-05 23:40:59 -07:00
argparser . add_argument ( " --write_schedule " , action = " store_true " , default = False , help = " write schedule of images and their batches to file (def: False) " )
argparser . add_argument ( " --rated_dataset " , action = " store_true " , default = False , help = " enable rated image set training, to less often train on lower rated images through the epochs " )
argparser . add_argument ( " --rated_dataset_target_dropout_percent " , type = int , default = 50 , help = " how many images (in percent) should be included in the last epoch (Default 50) " )
2023-09-06 13:37:10 -06:00
argparser . add_argument ( " --zero_frequency_noise_ratio " , type = float , default = 0.02 , help = " adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15 " )
argparser . add_argument ( " --enable_zero_terminal_snr " , action = " store_true " , default = None , help = " Use zero terminal SNR noising beta schedule " )
2023-09-10 11:06:50 -06:00
argparser . add_argument ( " --load_settings_every_epoch " , action = " store_true " , default = None , help = " Will load ' train.json ' at start of every epoch. Disabled by default and enabled when used. " )
2023-09-12 19:37:27 -06:00
argparser . add_argument ( " --min_snr_gamma " , type = int , default = None , help = " min-SNR-gamma parameter is the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556 " )
argparser . add_argument ( " --ema_decay_rate " , type = float , default = None , help = " EMA decay rate. EMA model will be updated with (1 - ema_rate) from training, and the ema_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature. " )
argparser . add_argument ( " --ema_strength_target " , type = float , default = None , help = " EMA decay target value in range (0,1). emarate will be calculated from equation: ' ema_decay_rate=ema_strength_target^(total_steps/ema_update_interval) ' . Using this parameter will enable the ema feature and overide ema_decay_rate. " )
argparser . add_argument ( " --ema_update_interval " , type = int , default = 500 , help = " How many steps between optimizer steps that EMA decay updates. EMA model will be update on every step modulo grad_accum times ema_update_interval. " )
argparser . add_argument ( " --ema_device " , type = str , default = ' cpu ' , help = " EMA decay device values: cpu, cuda. Using ' cpu ' is taking around 4 seconds per update vs fraction of a second on ' cuda ' . Using ' cuda ' will reserve around 3.2GB VRAM for a model, with ' cpu ' the system RAM will be used. " )
argparser . add_argument ( " --ema_sample_nonema_model " , action = " store_true " , default = False , help = " Will show samples from non-EMA trained model, just like regular training. Can be used with: --ema_sample_ema_model " )
argparser . add_argument ( " --ema_sample_ema_model " , action = " store_true " , default = False , help = " Will show samples from EMA model. May be slower when using ema cpu offloading. Can be used with: --ema_sample_nonema_model " )
argparser . add_argument ( " --ema_resume_model " , type = str , default = None , help = " The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay " )
2023-09-10 11:06:50 -06:00
2023-02-05 23:40:59 -07:00
# load CLI args to overwrite existing config args
args = argparser . parse_args ( args = argv , namespace = args )
2023-09-06 04:38:52 -06:00
2022-12-17 20:32:48 -07:00
main ( args )