2023-05-04 18:11:11 -06:00
"""
Copyright [ 2022 - 2023 ] Victor C Hall
Licensed under the GNU Affero General Public License ;
You may not use this code except in compliance with the License .
You may obtain a copy of the License at
https : / / www . gnu . org / licenses / agpl - 3.0 . en . html
Unless required by applicable law or agreed to in writing , software
distributed under the License is distributed on an " AS IS " BASIS ,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License .
"""
2023-04-29 20:56:10 -06:00
import logging
import itertools
import os
2023-05-14 03:49:11 -06:00
from itertools import chain
from typing import Generator , Any
2023-04-29 20:56:10 -06:00
import torch
from torch . cuda . amp import autocast , GradScaler
from diffusers . optimization import get_scheduler
from colorama import Fore , Style
2023-04-30 21:45:13 -06:00
import pprint
2023-04-29 20:56:10 -06:00
BETAS_DEFAULT = [ 0.9 , 0.999 ]
EPSILON_DEFAULT = 1e-8
WEIGHT_DECAY_DEFAULT = 0.01
LR_DEFAULT = 1e-6
OPTIMIZER_TE_STATE_FILENAME = " optimizer_te.pt "
OPTIMIZER_UNET_STATE_FILENAME = " optimizer_unet.pt "
class EveryDreamOptimizer ( ) :
"""
Wrapper to manage optimizers
2023-04-30 21:45:13 -06:00
resume_ckpt_path : path to resume checkpoint , will try to load state ( . pt ) files if they exist
optimizer_config : config for the optimizers
text_encoder : text encoder model parameters
unet : unet model parameters
2023-04-29 20:56:10 -06:00
"""
2023-05-14 03:49:11 -06:00
def __init__ ( self , args , optimizer_config , text_encoder , unet , epoch_len ) :
2023-04-30 21:45:13 -06:00
del optimizer_config [ " doc " ]
2023-05-04 18:11:11 -06:00
print ( f " \n raw optimizer_config: " )
2023-04-30 21:45:13 -06:00
pprint . pprint ( optimizer_config )
2023-05-04 18:11:11 -06:00
self . epoch_len = epoch_len
self . te_config , self . base_config = self . get_final_optimizer_configs ( args , optimizer_config )
2023-05-14 03:49:11 -06:00
self . te_freeze_config = optimizer_config . get ( " text_encoder_freezing " , { } )
2023-05-04 18:11:11 -06:00
print ( f " final unet optimizer config: " )
pprint . pprint ( self . base_config )
print ( f " final text encoder optimizer config: " )
pprint . pprint ( self . te_config )
2023-04-29 20:56:10 -06:00
self . grad_accum = args . grad_accum
self . clip_grad_norm = args . clip_grad_norm
2023-05-14 03:49:11 -06:00
self . text_encoder_params = self . _apply_text_encoder_freeze ( text_encoder )
self . unet_params = unet . parameters ( )
2023-04-29 20:56:10 -06:00
2023-05-04 18:11:11 -06:00
self . optimizers = [ ]
2023-05-14 03:49:11 -06:00
self . optimizer_te , self . optimizer_unet = self . create_optimizers ( args ,
self . text_encoder_params ,
self . unet_params )
2023-05-04 18:11:11 -06:00
self . optimizers . append ( self . optimizer_te ) if self . optimizer_te is not None else None
self . optimizers . append ( self . optimizer_unet ) if self . optimizer_unet is not None else None
2023-04-29 20:56:10 -06:00
2023-05-04 18:11:11 -06:00
self . lr_schedulers = [ ]
schedulers = self . create_lr_schedulers ( args , optimizer_config )
self . lr_schedulers . extend ( schedulers )
print ( self . lr_schedulers )
2023-04-29 20:56:10 -06:00
2023-05-04 18:11:11 -06:00
self . load ( args . resume_ckpt )
2023-04-29 20:56:10 -06:00
self . scaler = GradScaler (
enabled = args . amp ,
init_scale = 2 * * 17.5 ,
growth_factor = 2 ,
backoff_factor = 1.0 / 2 ,
growth_interval = 25 ,
)
logging . info ( f " Grad scaler enabled: { self . scaler . is_enabled ( ) } (amp mode) " )
def step ( self , loss , step , global_step ) :
self . scaler . scale ( loss ) . backward ( )
2023-04-29 21:15:48 -06:00
if ( ( global_step + 1 ) % self . grad_accum == 0 ) or ( step == self . epoch_len - 1 ) :
2023-05-21 12:10:18 -06:00
if self . clip_grad_norm is not None :
for optimizer in self . optimizers :
self . scaler . unscale_ ( optimizer )
torch . nn . utils . clip_grad_norm_ ( parameters = self . unet_params , max_norm = self . clip_grad_norm )
torch . nn . utils . clip_grad_norm_ ( parameters = self . text_encoder_params , max_norm = self . clip_grad_norm )
2023-05-04 18:11:11 -06:00
for optimizer in self . optimizers :
self . scaler . step ( optimizer )
2023-04-30 21:45:13 -06:00
2023-04-29 20:56:10 -06:00
self . scaler . update ( )
self . _zero_grad ( set_to_none = True )
2023-05-04 18:11:11 -06:00
for scheduler in self . lr_schedulers :
scheduler . step ( )
2023-04-30 21:45:13 -06:00
self . _update_grad_scaler ( global_step )
2023-04-29 20:56:10 -06:00
def _zero_grad ( self , set_to_none = False ) :
2023-05-04 18:11:11 -06:00
for optimizer in self . optimizers :
optimizer . zero_grad ( set_to_none = set_to_none )
2023-04-29 20:56:10 -06:00
def get_scale ( self ) :
return self . scaler . get_scale ( )
def get_unet_lr ( self ) :
2023-05-04 18:11:11 -06:00
return self . optimizer_unet . param_groups [ 0 ] [ ' lr ' ] if self . optimizer_unet is not None else 0
2023-04-29 20:56:10 -06:00
2023-04-30 21:45:13 -06:00
def get_textenc_lr ( self ) :
2023-05-04 18:11:11 -06:00
return self . optimizer_te . param_groups [ 0 ] [ ' lr ' ] if self . optimizer_te is not None else 0
2023-04-29 20:56:10 -06:00
def save ( self , ckpt_path : str ) :
"""
Saves the optimizer states to path
"""
2023-05-04 18:11:11 -06:00
self . _save_optimizer ( self . optimizer_te , os . path . join ( ckpt_path , OPTIMIZER_TE_STATE_FILENAME ) ) if self . optimizer_te is not None else None
self . _save_optimizer ( self . optimizer_unet , os . path . join ( ckpt_path , OPTIMIZER_UNET_STATE_FILENAME ) ) if self . optimizer_unet is not None else None
2023-04-29 20:56:10 -06:00
2023-04-30 21:45:13 -06:00
def load ( self , ckpt_path : str ) :
"""
Loads the optimizer states from path
"""
te_optimizer_state_path = os . path . join ( ckpt_path , OPTIMIZER_TE_STATE_FILENAME )
unet_optimizer_state_path = os . path . join ( ckpt_path , OPTIMIZER_UNET_STATE_FILENAME )
2023-05-14 03:24:13 -06:00
if os . path . exists ( te_optimizer_state_path ) and self . optimizer_te is not None :
self . _load_optimizer ( self . optimizer_te , te_optimizer_state_path )
if os . path . exists ( unet_optimizer_state_path ) and self . optimizer_unet is not None :
self . _load_optimizer ( self . optimizer_unet , unet_optimizer_state_path )
2023-04-30 21:45:13 -06:00
2023-05-04 18:11:11 -06:00
def create_optimizers ( self , args , text_encoder_params , unet_params ) :
2023-04-29 20:56:10 -06:00
"""
2023-05-04 18:11:11 -06:00
creates optimizers from config and args for unet and text encoder
2023-04-29 20:56:10 -06:00
returns ( optimizer_te , optimizer_unet )
"""
2023-04-30 07:28:55 -06:00
2023-04-29 20:56:10 -06:00
if args . disable_textenc_training :
2023-05-04 18:11:11 -06:00
optimizer_te = None
2023-04-29 20:56:10 -06:00
else :
2023-05-14 03:49:11 -06:00
optimizer_te = self . _create_optimizer ( " text encoder " , args , self . te_config , text_encoder_params )
2023-04-29 20:56:10 -06:00
if args . disable_unet_training :
2023-05-04 18:11:11 -06:00
optimizer_unet = None
2023-04-29 20:56:10 -06:00
else :
2023-05-14 03:49:11 -06:00
optimizer_unet = self . _create_optimizer ( " unet " , args , self . base_config , unet_params )
2023-04-29 20:56:10 -06:00
return optimizer_te , optimizer_unet
2023-05-04 18:11:11 -06:00
def get_final_optimizer_configs ( self , args , global_optimizer_config ) :
2023-04-30 07:28:55 -06:00
"""
2023-05-26 19:54:02 -06:00
defaults and overrides based on priority
cli LR arg will override LR for both unet and text encoder for legacy reasons
2023-04-30 07:28:55 -06:00
"""
2023-05-04 18:11:11 -06:00
base_config = global_optimizer_config . get ( " base " )
te_config = global_optimizer_config . get ( " text_encoder_overrides " )
2023-04-30 07:28:55 -06:00
2023-04-30 21:45:13 -06:00
if args . lr_decay_steps is None or args . lr_decay_steps < 1 :
2023-05-26 19:54:02 -06:00
# sets cosine so the zero crossing is past the end of training, this results in a terminal LR that is about 25% of the nominal LR
2023-04-30 21:45:13 -06:00
args . lr_decay_steps = int ( self . epoch_len * args . max_epochs * 1.5 )
2023-05-04 18:11:11 -06:00
if args . lr_warmup_steps is None :
2023-05-26 19:54:02 -06:00
# set warmup to 2% of decay, if decay was autoset to 150% of max epochs then warmup will end up about 3% of max epochs
2023-05-04 18:11:11 -06:00
args . lr_warmup_steps = int ( args . lr_decay_steps / 50 )
if args . lr is not None :
2023-05-26 19:54:02 -06:00
# override for legacy support reasons
2023-05-04 18:11:11 -06:00
base_config [ " lr " ] = args . lr
base_config [ " optimizer " ] = base_config . get ( " optimizer " , None ) or " adamw8bit "
base_config [ " lr_warmup_steps " ] = base_config . get ( " lr_warmup_steps " , None ) or args . lr_warmup_steps
base_config [ " lr_decay_steps " ] = base_config . get ( " lr_decay_steps " , None ) or args . lr_decay_steps
base_config [ " lr_scheduler " ] = base_config . get ( " lr_scheduler " , None ) or args . lr_scheduler
base_config [ " lr_warmup_steps " ] = base_config . get ( " lr_warmup_steps " , None ) or args . lr_warmup_steps
base_config [ " lr_decay_steps " ] = base_config . get ( " lr_decay_steps " , None ) or args . lr_decay_steps
base_config [ " lr_scheduler " ] = base_config . get ( " lr_scheduler " , None ) or args . lr_scheduler
te_config [ " lr " ] = te_config . get ( " lr " , None ) or base_config [ " lr " ]
te_config [ " optimizer " ] = te_config . get ( " optimizer " , None ) or base_config [ " optimizer " ]
te_config [ " lr_scheduler " ] = te_config . get ( " lr_scheduler " , None ) or base_config [ " lr_scheduler " ]
te_config [ " lr_warmup_steps " ] = te_config . get ( " lr_warmup_steps " , None ) or base_config [ " lr_warmup_steps " ]
te_config [ " lr_decay_steps " ] = te_config . get ( " lr_decay_steps " , None ) or base_config [ " lr_decay_steps " ]
te_config [ " weight_decay " ] = te_config . get ( " weight_decay " , None ) or base_config [ " weight_decay " ]
te_config [ " betas " ] = te_config . get ( " betas " , None ) or base_config [ " betas " ]
te_config [ " epsilon " ] = te_config . get ( " epsilon " , None ) or base_config [ " epsilon " ]
return te_config , base_config
def create_lr_schedulers ( self , args , optimizer_config ) :
unet_config = optimizer_config [ " base " ]
te_config = optimizer_config [ " text_encoder_overrides " ]
ret_val = [ ]
if self . optimizer_te is not None :
lr_scheduler = get_scheduler (
te_config . get ( " lr_scheduler " , args . lr_scheduler ) ,
optimizer = self . optimizer_te ,
num_warmup_steps = te_config . get ( " lr_warmup_steps " , None ) ,
num_training_steps = unet_config . get ( " lr_decay_steps " , None ) or unet_config [ " lr_decay_steps " ]
)
ret_val . append ( lr_scheduler )
if self . optimizer_unet is not None :
unet_config = optimizer_config [ " base " ]
lr_scheduler = get_scheduler (
unet_config [ " lr_scheduler " ] ,
optimizer = self . optimizer_unet ,
num_warmup_steps = int ( unet_config [ " lr_warmup_steps " ] ) ,
num_training_steps = int ( unet_config [ " lr_decay_steps " ] ) ,
)
ret_val . append ( lr_scheduler )
return ret_val
2023-04-29 20:56:10 -06:00
2023-04-30 21:45:13 -06:00
def _update_grad_scaler ( self , global_step ) :
2023-04-29 20:56:10 -06:00
if global_step == 500 :
factor = 1.8
self . scaler . set_growth_factor ( factor )
self . scaler . set_backoff_factor ( 1 / factor )
2023-04-30 21:45:13 -06:00
self . scaler . set_growth_interval ( 100 )
2023-04-29 20:56:10 -06:00
if global_step == 1000 :
factor = 1.6
self . scaler . set_growth_factor ( factor )
self . scaler . set_backoff_factor ( 1 / factor )
2023-04-30 21:45:13 -06:00
self . scaler . set_growth_interval ( 200 )
2023-04-29 20:56:10 -06:00
if global_step == 2000 :
factor = 1.3
self . scaler . set_growth_factor ( factor )
self . scaler . set_backoff_factor ( 1 / factor )
2023-05-04 18:11:11 -06:00
self . scaler . set_growth_interval ( 500 )
2023-04-29 20:56:10 -06:00
if global_step == 4000 :
factor = 1.15
self . scaler . set_growth_factor ( factor )
self . scaler . set_backoff_factor ( 1 / factor )
2023-04-30 21:45:13 -06:00
self . scaler . set_growth_interval ( 2000 )
2023-04-29 20:56:10 -06:00
@staticmethod
def _save_optimizer ( optimizer , path : str ) :
"""
Saves the optimizer state to specific path / filename
"""
torch . save ( optimizer . state_dict ( ) , path )
@staticmethod
2023-04-30 21:45:13 -06:00
def _load_optimizer ( optimizer : torch . optim . Optimizer , path : str ) :
2023-04-29 20:56:10 -06:00
"""
Loads the optimizer state to an Optimizer object
2023-04-30 21:45:13 -06:00
optimizer : torch . optim . Optimizer
path : . pt file
2023-04-29 20:56:10 -06:00
"""
2023-04-30 21:45:13 -06:00
try :
optimizer . load_state_dict ( torch . load ( path ) )
logging . info ( f " Loaded optimizer state from { path } " )
except Exception as e :
logging . warning ( f " { Fore . LIGHTYELLOW_EX } **Failed to load optimizer state from { path } , optimizer state will not be loaded, \n * Exception: { e } { Style . RESET_ALL } " )
pass
2023-04-29 20:56:10 -06:00
2023-05-14 03:49:11 -06:00
def _create_optimizer ( self , label , args , local_optimizer_config , parameters ) :
2023-04-29 20:56:10 -06:00
betas = BETAS_DEFAULT
epsilon = EPSILON_DEFAULT
weight_decay = WEIGHT_DECAY_DEFAULT
opt_class = None
optimizer = None
default_lr = 1e-6
curr_lr = args . lr
if local_optimizer_config is not None :
2023-04-30 07:28:55 -06:00
betas = local_optimizer_config [ " betas " ] or betas
epsilon = local_optimizer_config [ " epsilon " ] or epsilon
weight_decay = local_optimizer_config [ " weight_decay " ] or weight_decay
2023-05-04 18:11:11 -06:00
optimizer_name = local_optimizer_config [ " optimizer " ] or " adamw8bit "
2023-04-29 20:56:10 -06:00
curr_lr = local_optimizer_config . get ( " lr " , curr_lr )
if args . lr is not None :
curr_lr = args . lr
logging . info ( f " Overriding LR from optimizer config with main config/cli LR setting: { curr_lr } " )
if curr_lr is None :
curr_lr = default_lr
logging . warning ( f " No LR setting found, defaulting to { default_lr } " )
if optimizer_name :
if optimizer_name == " lion " :
from lion_pytorch import Lion
opt_class = Lion
optimizer = opt_class (
itertools . chain ( parameters ) ,
lr = curr_lr ,
betas = ( betas [ 0 ] , betas [ 1 ] ) ,
weight_decay = weight_decay ,
)
2023-05-04 18:11:11 -06:00
elif optimizer_name == " adamw " :
2023-04-29 20:56:10 -06:00
opt_class = torch . optim . AdamW
else :
import bitsandbytes as bnb
opt_class = bnb . optim . AdamW8bit
if not optimizer :
optimizer = opt_class (
itertools . chain ( parameters ) ,
lr = curr_lr ,
betas = ( betas [ 0 ] , betas [ 1 ] ) ,
eps = epsilon ,
weight_decay = weight_decay ,
amsgrad = False ,
)
2023-05-14 03:49:11 -06:00
log_optimizer ( label , optimizer , betas , epsilon , weight_decay , curr_lr )
2023-04-29 20:56:10 -06:00
return optimizer
2023-05-14 03:49:11 -06:00
def _apply_text_encoder_freeze ( self , text_encoder ) - > chain [ Any ] :
parameters = itertools . chain ( [ ] )
if self . te_freeze_config . get ( ' freeze_embeddings ' , False ) :
# freeze embeddings
print ( " ❄️ freezing embeddings " )
else :
parameters = itertools . chain ( parameters , text_encoder . text_model . embeddings . parameters ( ) )
freeze_front_n_layers = self . te_freeze_config . get ( ' freeze_front_n_layers ' , None )
if freeze_front_n_layers is None :
parameters = itertools . chain ( parameters , text_encoder . text_model . encoder . layers . parameters ( ) )
else :
# freeze the specified CLIP text encoder layers
layers = text_encoder . text_model . encoder . layers
print ( f " ❄️ freezing text encoder layers 0- { len ( layers [ : freeze_front_n_layers ] ) } of { len ( layers ) } " )
parameters = itertools . chain ( parameters , layers [ freeze_front_n_layers : ] . parameters ( ) )
if self . te_freeze_config . get ( ' freeze_final_layer_norm ' , False ) :
# instead of freezing the final layer norm parameters, we simply do not return them
print ( " ❄️ freezing final layer norm " )
else :
parameters = itertools . chain ( parameters , text_encoder . text_model . final_layer_norm . parameters ( ) )
return parameters
def log_optimizer ( label : str , optimizer : torch . optim . Optimizer , betas , epsilon , weight_decay , lr ) :
2023-04-29 20:56:10 -06:00
"""
logs the optimizer settings
"""
2023-05-14 03:49:11 -06:00
all_params = sum ( [ g [ ' params ' ] for g in optimizer . param_groups ] , [ ] )
frozen_parameter_count = len ( [ p for p in all_params if not p . requires_grad ] )
total_parameter_count = len ( all_params )
if frozen_parameter_count > 0 :
param_info = f " ( { total_parameter_count } parameters, { frozen_parameter_count } frozen) "
else :
param_info = f " ( { total_parameter_count } parameters) "
logging . info ( f " { Fore . CYAN } * { label } optimizer: { optimizer . __class__ . __name__ } { param_info } * { Style . RESET_ALL } " )
2023-04-29 20:56:10 -06:00
logging . info ( f " { Fore . CYAN } lr: { lr } , betas: { betas } , epsilon: { epsilon } , weight_decay: { weight_decay } * { Style . RESET_ALL } " )
2023-05-14 03:49:11 -06:00