2023-05-04 18:11:11 -06:00
"""
Copyright [ 2022 - 2023 ] Victor C Hall
Licensed under the GNU Affero General Public License ;
You may not use this code except in compliance with the License .
You may obtain a copy of the License at
https : / / www . gnu . org / licenses / agpl - 3.0 . en . html
Unless required by applicable law or agreed to in writing , software
distributed under the License is distributed on an " AS IS " BASIS ,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License .
"""
2023-04-29 20:56:10 -06:00
import logging
import itertools
import os
2023-05-14 03:49:11 -06:00
from itertools import chain
from typing import Generator , Any
2023-04-29 20:56:10 -06:00
import torch
from torch . cuda . amp import autocast , GradScaler
from diffusers . optimization import get_scheduler
from colorama import Fore , Style
2023-04-30 21:45:13 -06:00
import pprint
2023-04-29 20:56:10 -06:00
# Fallback optimizer hyperparameters; _create_optimizer starts from these values and
# lets the per-model optimizer config override them.
BETAS_DEFAULT = [ 0.9 , 0.999 ]
EPSILON_DEFAULT = 1e-8
WEIGHT_DECAY_DEFAULT = 0.01
# Default learning rate.
LR_DEFAULT = 1e-6
# File names of the text encoder / unet optimizer state files; save() and load() join
# them onto the checkpoint directory path.
OPTIMIZER_TE_STATE_FILENAME = " optimizer_te.pt "
OPTIMIZER_UNET_STATE_FILENAME = " optimizer_unet.pt "
class EveryDreamOptimizer():
    """
    Wrapper to manage optimizers

    Owns the (optional) text encoder and unet optimizers, one LR scheduler per optimizer
    and the AMP grad scaler, and drives all of them from step().

    args: trainer arguments; this class reads lr, lr_scheduler, lr_warmup_steps,
          lr_decay_steps, max_epochs, grad_accum, clip_grad_norm, amp, resume_ckpt,
          log_step, disable_textenc_training and disable_unet_training
    optimizer_config: config for the optimizers (dict, mutated in place while defaults are resolved)
    text_encoder: text encoder model
    unet: unet model
    epoch_len: number of steps in one epoch
    log_writer: optional object with add_scalar(tag, value, step) used for norm logging
    """

    # global_step -> (growth factor, growth interval) applied to the grad scaler at that step;
    # the backoff factor is always set to 1 / growth factor.
    _GRAD_SCALER_SCHEDULE = {
        500: (1.8, 100),
        1000: (1.6, 200),
        2000: (1.3, 500),
        4000: (1.15, 2000),
    }

    def __init__(self, args, optimizer_config, text_encoder, unet, epoch_len, log_writer=None):
        # "doc" is free-form documentation inside the config; tolerate configs without it
        optimizer_config.pop("doc", None)
        print("\n raw optimizer_config:")
        pprint.pprint(optimizer_config)
        self.epoch_len = epoch_len
        self.unet = unet
        self.log_writer = log_writer
        self.te_config, self.base_config = self.get_final_optimizer_configs(args, optimizer_config)
        self.te_freeze_config = optimizer_config.get("text_encoder_freezing") or {}
        print(" Final unet optimizer config:")
        pprint.pprint(self.base_config)
        print(" Final text encoder optimizer config:")
        pprint.pprint(self.te_config)

        self.grad_accum = args.grad_accum
        self.clip_grad_norm = args.clip_grad_norm
        self.apply_grad_scaler_step_tweaks = optimizer_config.get("apply_grad_scaler_step_tweaks", True)
        self.log_grad_norm = optimizer_config.get("log_grad_norm", True)

        # Keep lists, not one-shot iterators: the optimizers consume the iterable they are
        # given, and step() has to walk the same parameters again for gradient clipping
        # and gradient-norm logging.
        self.text_encoder_params = list(self._apply_text_encoder_freeze(text_encoder))
        self.unet_params = list(unet.parameters())

        with torch.no_grad():
            log_action = lambda n, label: logging.info(f"{Fore.LIGHTBLUE_EX} {label} weight normal: {n}{Style.RESET_ALL}")
            self._log_weight_normal(text_encoder.text_model.encoder.layers.parameters(), "text encoder", log_action)
            self._log_weight_normal(unet.parameters(), "unet", log_action)

        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args,
                                                                        self.text_encoder_params,
                                                                        self.unet_params)
        self.optimizers = []
        if self.optimizer_te is not None:
            self.optimizers.append(self.optimizer_te)
        if self.optimizer_unet is not None:
            self.optimizers.append(self.optimizer_unet)

        self.lr_schedulers = []
        schedulers = self.create_lr_schedulers(args, optimizer_config)
        self.lr_schedulers.extend(schedulers)

        self.load(args.resume_ckpt)

        self.scaler = GradScaler(
            enabled=args.amp,
            init_scale=2**17.5,
            growth_factor=2,
            backoff_factor=1.0/2,
            growth_interval=25,
        )

        logging.info(f" Grad scaler enabled: {self.scaler.is_enabled()} (amp mode)")

    def _log_gradient_normal(self, parameters: Generator, label: str, log_action=None):
        """
        Computes the total L2 norm of the gradients of parameters and hands it to
        log_action(total_norm, label). Parameters without a gradient contribute 0.
        """
        total_norm = self._get_norm(parameters, lambda p: None if p.grad is None else p.grad.data)
        if log_action is not None:
            log_action(total_norm, label)

    def _log_weight_normal(self, parameters: Generator, label: str, log_action=None):
        """
        Computes the total L2 norm of the weights of parameters and hands it to
        log_action(total_norm, label).
        """
        total_norm = self._get_norm(parameters, lambda p: p.data)
        if log_action is not None:
            log_action(total_norm, label)

    @staticmethod
    def _calculate_normal(param, param_type):
        """
        Squared L2 norm of param_type(param), or 0.0 when param_type(param) is None.
        Unused in this class; kept for compatibility (static so it is callable from an instance too).
        """
        value = param_type(param)
        if value is not None:
            return value.norm(2).item() ** 2
        else:
            return 0.0

    def _get_norm(self, parameters: Generator, param_type):
        """
        Returns the L2 norm over all tensors selected by param_type(p) for p in parameters.
        param_type may return None for a parameter, which then counts as 0.
        """
        total_norm = 0
        for p in parameters:
            param = param_type(p)
            total_norm += self._calculate_norm(param, p)
        total_norm = total_norm ** (1. / 2)
        return total_norm

    def _calculate_norm(self, param, p):
        """
        Squared L2 norm of the tensor param, or 0.0 when param is None (p is unused).
        """
        if param is not None:
            return param.norm(2).item() ** 2
        else:
            return 0.0

    def step(self, loss, step, global_step):
        """
        Backpropagates loss and, on gradient accumulation boundaries (or on the last step
        of the epoch), clips gradients, steps the optimizers and clears the gradients.

        loss: loss tensor for the current batch
        step: index of the current step within the epoch
        global_step: index of the current step across the whole run
        """
        self.scaler.scale(loss).backward()

        if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
            if self.clip_grad_norm is not None:
                # gradients have to be unscaled before their norm is measured or clipped
                for optimizer in self.optimizers:
                    self.scaler.unscale_(optimizer)

                writer = self.log_writer  # may be None: logging is optional, clipping is not

                if self.log_grad_norm and writer is not None:
                    # max_norm=inf leaves the gradients untouched and just returns the total norm
                    pre_clip_norm = torch.nn.utils.clip_grad_norm_(parameters=self.unet.parameters(), max_norm=float('inf'))
                    writer.add_scalar("optimizer/unet_pre_clip_norm", pre_clip_norm, global_step)

                    pre_clip_norm = torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=float('inf'))
                    writer.add_scalar("optimizer/te_pre_clip_norm", pre_clip_norm, global_step)

                unet_grad_norm = torch.nn.utils.clip_grad_norm_(parameters=self.unet.parameters(), max_norm=self.clip_grad_norm)
                if writer is not None:
                    writer.add_scalar("optimizer/unet_grad_norm", unet_grad_norm, global_step)

                te_grad_norm = torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
                if writer is not None:
                    writer.add_scalar("optimizer/te_grad_norm", te_grad_norm, global_step)

            for optimizer in self.optimizers:
                self.scaler.step(optimizer)

            self.scaler.update()

            if self.log_grad_norm and self.log_writer is not None:
                log_info_unet_fn = lambda n, label: self.log_writer.add_scalar(label, n, global_step)
                log_info_te_fn = lambda n, label: self.log_writer.add_scalar(label, n, global_step)
                with torch.no_grad():
                    self._log_gradient_normal(self.unet_params, "optimizer/unet_grad_norm", log_info_unet_fn)
                    self._log_gradient_normal(self.text_encoder_params, "optimizer/te_grad_norm", log_info_te_fn)

            self._zero_grad(set_to_none=True)

        # Schedulers advance on every call: the warmup/decay lengths resolved in
        # get_final_optimizer_configs are expressed in training steps (epoch_len * max_epochs).
        for scheduler in self.lr_schedulers:
            scheduler.step()

        # The tweaks are keyed on exact global_step values, so they are checked on every call.
        if self.apply_grad_scaler_step_tweaks:
            self._update_grad_scaler(global_step)

    def _zero_grad(self, set_to_none=False):
        """
        Clears the gradients of every managed optimizer.
        """
        for optimizer in self.optimizers:
            optimizer.zero_grad(set_to_none=set_to_none)

    def get_scale(self):
        """
        Returns the current loss scale of the grad scaler.
        """
        return self.scaler.get_scale()

    def get_unet_lr(self):
        """
        Returns the LR of the first unet param group, or 0 when the unet is not trained.
        """
        if self.optimizer_unet is None:
            return 0
        return self.optimizer_unet.param_groups[0]['lr']

    def get_textenc_lr(self):
        """
        Returns the LR of the first text encoder param group, or 0 when it is not trained.
        """
        if self.optimizer_te is None:
            return 0
        return self.optimizer_te.param_groups[0]['lr']

    def save(self, ckpt_path: str):
        """
        Saves the optimizer states to path
        """
        if self.optimizer_te is not None:
            self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME))
        if self.optimizer_unet is not None:
            self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME))

    def load(self, ckpt_path: str):
        """
        Loads the optimizer states from path

        State files that do not exist are skipped, as is an empty/None path.
        """
        if not ckpt_path:
            return

        te_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME)
        unet_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME)
        if os.path.exists(te_optimizer_state_path) and self.optimizer_te is not None:
            self._load_optimizer(self.optimizer_te, te_optimizer_state_path)
        if os.path.exists(unet_optimizer_state_path) and self.optimizer_unet is not None:
            self._load_optimizer(self.optimizer_unet, unet_optimizer_state_path)

    def create_optimizers(self, args, text_encoder_params, unet_params):
        """
        creates optimizers from config and args for unet and text encoder
        returns (optimizer_te, optimizer_unet)

        An entry is None when training of that model is disabled in args.
        """
        optimizer_te = None
        optimizer_unet = None

        if not args.disable_textenc_training:
            optimizer_te = self._create_optimizer("text encoder", args, self.te_config, text_encoder_params)
        if not args.disable_unet_training:
            optimizer_unet = self._create_optimizer("unet", args, self.base_config, unet_params)

        return optimizer_te, optimizer_unet

    def get_final_optimizer_configs(self, args, global_optimizer_config):
        """
        defaults and overrides based on priority
        cli LR arg will override LR for both unet and text encoder for legacy reasons

        Fills in args.lr_decay_steps / args.lr_warmup_steps when unset and resolves the
        "base" and "text_encoder_overrides" sections of global_optimizer_config in place.
        returns (te_config, base_config)
        """
        # Missing/null sections become empty dicts stored back into the global config so
        # that create_lr_schedulers sees the resolved values.
        base_config = global_optimizer_config.get("base")
        if base_config is None:
            base_config = {}
            global_optimizer_config["base"] = base_config
        te_config = global_optimizer_config.get("text_encoder_overrides")
        if te_config is None:
            te_config = {}
            global_optimizer_config["text_encoder_overrides"] = te_config

        if args.lr_decay_steps is None or args.lr_decay_steps < 1:
            # sets cosine so the zero crossing is past the end of training, this results in a terminal LR that is about 25% of the nominal LR
            args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)

        if args.lr_warmup_steps is None:
            # set warmup to 2% of decay, if decay was autoset to 150% of max epochs then warmup will end up about 3% of max epochs
            args.lr_warmup_steps = int(args.lr_decay_steps / 50)

        if args.lr is not None:
            # override for legacy support reasons
            base_config["lr"] = args.lr

        # NOTE(review): "or" treats 0 / 0.0 like "not set", so an explicit zero falls
        # through to the fallback; kept as-is to preserve how existing configs resolve.
        base_config["optimizer"] = base_config.get("optimizer", None) or "adamw8bit"
        base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", None) or args.lr_warmup_steps
        base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
        base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler

        # The text encoder inherits anything it does not override. A key missing from
        # both sections resolves to None, which _create_optimizer replaces with its defaults.
        inherited_keys = ("lr", "optimizer", "lr_scheduler", "lr_warmup_steps",
                          "lr_decay_steps", "weight_decay", "betas", "epsilon")
        for key in inherited_keys:
            te_config[key] = te_config.get(key, None) or base_config.get(key)

        return te_config, base_config

    def create_lr_schedulers(self, args, optimizer_config):
        """
        Creates one LR scheduler per existing optimizer (text encoder first, then unet).
        returns a list of schedulers
        """
        unet_config = optimizer_config["base"]
        te_config = optimizer_config["text_encoder_overrides"]
        ret_val = []

        if self.optimizer_te is not None:
            # resolve the fallback before int(): int(None) would raise
            te_warmup_steps = te_config.get("lr_warmup_steps", None) or unet_config["lr_warmup_steps"]
            te_decay_steps = te_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"]
            lr_scheduler = get_scheduler(
                te_config.get("lr_scheduler", None) or args.lr_scheduler,
                optimizer=self.optimizer_te,
                num_warmup_steps=int(te_warmup_steps),
                num_training_steps=int(te_decay_steps),
            )
            ret_val.append(lr_scheduler)

        if self.optimizer_unet is not None:
            lr_scheduler = get_scheduler(
                unet_config["lr_scheduler"],
                optimizer=self.optimizer_unet,
                num_warmup_steps=int(unet_config["lr_warmup_steps"]),
                num_training_steps=int(unet_config["lr_decay_steps"]),
            )
            ret_val.append(lr_scheduler)

        return ret_val

    def _update_grad_scaler(self, global_step):
        """
        Applies the grad scaler settings scheduled for exactly this global_step, if any.
        """
        scheduled = self._GRAD_SCALER_SCHEDULE.get(global_step)
        if scheduled is None:
            return
        factor, growth_interval = scheduled
        self.scaler.set_growth_factor(factor)
        self.scaler.set_backoff_factor(1 / factor)
        self.scaler.set_growth_interval(growth_interval)

    @staticmethod
    def _save_optimizer(optimizer, path: str):
        """
        Saves the optimizer state to specific path/filename
        """
        torch.save(optimizer.state_dict(), path)

    @staticmethod
    def _load_optimizer(optimizer: torch.optim.Optimizer, path: str):
        """
        Loads the optimizer state to an Optimizer object
        optimizer: torch.optim.Optimizer
        path: .pt file

        Best effort: a failure is logged as a warning and the optimizer keeps its fresh state.
        """
        try:
            # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
            optimizer.load_state_dict(torch.load(path))
            logging.info(f" Loaded optimizer state from {path}")
        except Exception as e:
            logging.warning(f"{Fore.LIGHTYELLOW_EX}**Failed to load optimizer state from {path}, optimizer state will not be loaded, \n * Exception: {e}{Style.RESET_ALL}")

    def _create_optimizer(self, label, args, local_optimizer_config, parameters):
        """
        Builds one optimizer over parameters.

        label: name used in the log output ("text encoder" / "unet")
        args: trainer arguments (lr and log_step are read)
        local_optimizer_config: resolved config dict for this model, or None for all defaults
        parameters: iterable of parameters to optimize
        returns the optimizer instance
        """
        betas = BETAS_DEFAULT
        epsilon = EPSILON_DEFAULT
        weight_decay = WEIGHT_DECAY_DEFAULT
        opt_class = None
        optimizer = None
        optimizer_name = "adamw8bit"

        default_lr = LR_DEFAULT
        curr_lr = args.lr
        d0 = 1e-6 # dadapt
        decouple = True # seems bad to turn off, dadapt_adam only
        momentum = 0.0 # dadapt_sgd
        no_prox = False # ????, dadapt_adan
        use_bias_correction = True # suggest by prodigy github
        growth_rate = float("inf") # dadapt various, no idea what a sane default is
        safeguard_warmup = True # per recommendation from prodigy documentation

        if local_optimizer_config is not None:
            def setting(key, default):
                # a key that is absent or explicitly null keeps the default
                value = local_optimizer_config.get(key)
                return default if value is None else value

            betas = setting("betas", betas)
            epsilon = setting("epsilon", epsilon)
            weight_decay = setting("weight_decay", weight_decay)
            no_prox = setting("no_prox", False)
            optimizer_name = setting("optimizer", "adamw8bit") or "adamw8bit"
            curr_lr = setting("lr", curr_lr)
            d0 = setting("d0", d0)
            decouple = setting("decouple", decouple)
            momentum = setting("momentum", momentum)
            growth_rate = setting("growth_rate", growth_rate)
            safeguard_warmup = setting("safeguard_warmup", safeguard_warmup)

        if args.lr is not None:
            curr_lr = args.lr
            logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")

        if curr_lr is None:
            curr_lr = default_lr
            logging.warning(f"No LR setting found, defaulting to {default_lr}")

        # A single if/elif chain: exactly one branch picks the optimizer, so e.g. "adamw"
        # is no longer overwritten by the AdamW8bit fallback at the end.
        if optimizer_name == "lion":
            from lion_pytorch import Lion
            opt_class = Lion
            optimizer = opt_class(
                itertools.chain(parameters),
                lr=curr_lr,
                betas=(betas[0], betas[1]),
                weight_decay=weight_decay,
            )
        elif optimizer_name == "lion8bit":
            from bitsandbytes.optim import Lion8bit
            opt_class = Lion8bit
            optimizer = opt_class(
                itertools.chain(parameters),
                lr=curr_lr,
                betas=(betas[0], betas[1]),
                weight_decay=weight_decay,
                percentile_clipping=100,
                min_8bit_size=4096,
            )
        elif optimizer_name == "prodigy":
            from prodigyopt import Prodigy
            opt_class = Prodigy
            optimizer = opt_class(
                itertools.chain(parameters),
                lr=curr_lr,
                weight_decay=weight_decay,
                use_bias_correction=use_bias_correction,
                growth_rate=growth_rate,
                d0=d0,
                safeguard_warmup=safeguard_warmup
            )
        elif optimizer_name == "adamw":
            opt_class = torch.optim.AdamW
        elif "dowg" in optimizer_name:
            # coordinate_dowg, scalar_dowg require no additional parameters. Epsilon is overrideable but is unnecessary in all stable diffusion training situations.
            # NOTE(review): these are still built by the generic constructor call at the
            # end of this method (lr/betas/eps/weight_decay/amsgrad), as before - confirm
            # the dowg classes accept those arguments.
            import dowg
            if optimizer_name == "coordinate_dowg":
                opt_class = dowg.CoordinateDoWG
            elif optimizer_name == "scalar_dowg":
                opt_class = dowg.ScalarDoWG
            else:
                raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are 'coordinate_dowg' and 'scalar_dowg'")
        elif optimizer_name in ["dadapt_adam", "dadapt_lion", "dadapt_sgd"]:
            import dadaptation

            if curr_lr < 1e-4:
                logging.warning(f"{Fore.YELLOW} LR, {curr_lr}, is very low for Dadaptation.  Consider reviewing Dadaptation documentation, but proceeding anyway.{Style.RESET_ALL}")
            if weight_decay < 1e-3:
                logging.warning(f"{Fore.YELLOW} Weight decay, {weight_decay}, is very low for Dadaptation.  Consider reviewing Dadaptation documentation, but proceeding anyway.{Style.RESET_ALL}")

            if optimizer_name == "dadapt_adam":
                opt_class = dadaptation.DAdaptAdam
                optimizer = opt_class(
                    itertools.chain(parameters),
                    lr=curr_lr,
                    betas=(betas[0], betas[1]),
                    weight_decay=weight_decay,
                    eps=epsilon, #unused for lion
                    d0=d0,
                    log_every=args.log_step,
                    growth_rate=growth_rate,
                    decouple=decouple,
                )
            elif optimizer_name == "dadapt_adan":
                # NOTE(review): unreachable - "dadapt_adan" is not in the list checked
                # above, so that name falls through to AdamW8bit. Left as found; confirm
                # whether it was disabled on purpose before adding it to the list.
                opt_class = dadaptation.DAdaptAdan
                optimizer = opt_class(
                    itertools.chain(parameters),
                    lr=curr_lr,
                    betas=(betas[0], betas[1]),
                    no_prox=no_prox,
                    weight_decay=weight_decay,
                    eps=epsilon,
                    d0=d0,
                    log_every=args.log_step,
                    growth_rate=growth_rate,
                )
            elif optimizer_name == "dadapt_lion":
                opt_class = dadaptation.DAdaptLion
                optimizer = opt_class(
                    itertools.chain(parameters),
                    lr=curr_lr,
                    betas=(betas[0], betas[1]),
                    weight_decay=weight_decay,
                    d0=d0,
                    log_every=args.log_step,
                )
            elif optimizer_name == "dadapt_sgd":
                opt_class = dadaptation.DAdaptSGD
                optimizer = opt_class(
                    itertools.chain(parameters),
                    lr=curr_lr,
                    momentum=momentum,
                    weight_decay=weight_decay,
                    d0=d0,
                    log_every=args.log_step,
                    growth_rate=growth_rate,
                )
        else:
            if optimizer_name != "adamw8bit":
                logging.warning(f"Optimizer '{optimizer_name}' is not handled explicitly, using AdamW8bit")
            import bitsandbytes as bnb
            opt_class = bnb.optim.AdamW8bit

        if optimizer is None:
            # branches that only selected a class share the AdamW-style constructor call
            optimizer = opt_class(
                itertools.chain(parameters),
                lr=curr_lr,
                betas=(betas[0], betas[1]),
                eps=epsilon,
                weight_decay=weight_decay,
                amsgrad=False,
            )

        log_optimizer(label, optimizer, betas, epsilon, weight_decay, curr_lr)
        return optimizer

    def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
        """
        Selects the text encoder parameters to train according to self.te_freeze_config.

        Recognised keys: "unfreeze_last_n_layers", the legacy "freeze_front_n_layers",
        "freeze_embeddings" and "freeze_final_layer_norm".
        returns an iterator over the selected parameters (embeddings, encoder layers,
        final layer norm - in that order, each only if not frozen)
        """
        num_layers = len(text_encoder.text_model.encoder.layers)
        unfreeze_embeddings = True
        unfreeze_last_n_layers = None
        unfreeze_final_layer_norm = True

        if "freeze_front_n_layers" in self.te_freeze_config:
            logging.warning(
                ' * Found "freeze_front_n_layers" in JSON, please use "unfreeze_last_n_layers" instead')
            freeze_front_n_layers = self.te_freeze_config["freeze_front_n_layers"]
            if freeze_front_n_layers < 0:
                # eg -2 = freeze all but the last 2
                unfreeze_last_n_layers = -freeze_front_n_layers
            else:
                unfreeze_last_n_layers = num_layers - freeze_front_n_layers
        if "unfreeze_last_n_layers" in self.te_freeze_config:
            # the new key wins when both are present
            unfreeze_last_n_layers = self.te_freeze_config["unfreeze_last_n_layers"]

        if unfreeze_last_n_layers is None:
            # nothing specified: default behaviour
            unfreeze_last_n_layers = num_layers
        else:
            # something specified:
            assert unfreeze_last_n_layers > 0, "at least one text encoder layer has to stay unfrozen"
            if unfreeze_last_n_layers < num_layers:
                # if we're unfreezing layers then by default we ought to freeze the embeddings
                unfreeze_embeddings = False

        # explicit settings override the defaults derived above
        if "freeze_embeddings" in self.te_freeze_config:
            unfreeze_embeddings = not self.te_freeze_config["freeze_embeddings"]
        if "freeze_final_layer_norm" in self.te_freeze_config:
            unfreeze_final_layer_norm = not self.te_freeze_config["freeze_final_layer_norm"]

        parameters = itertools.chain([])

        if unfreeze_embeddings:
            parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters())
        else:
            print(" ❄️ freezing embeddings")

        if unfreeze_last_n_layers >= num_layers:
            parameters = itertools.chain(parameters, text_encoder.text_model.encoder.layers.parameters())
        else:
            # freeze the specified CLIP text encoder layers
            layers = text_encoder.text_model.encoder.layers
            first_layer_to_unfreeze = num_layers - unfreeze_last_n_layers
            print(f" ❄️ freezing text encoder layers 1-{first_layer_to_unfreeze} out of {num_layers} layers total")
            parameters = itertools.chain(parameters, layers[first_layer_to_unfreeze:].parameters())

        if unfreeze_final_layer_norm:
            parameters = itertools.chain(parameters, text_encoder.text_model.final_layer_norm.parameters())
        else:
            print(" ❄️ freezing final layer norm")

        return parameters
def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
    """
    logs the optimizer settings

    Emits two info lines: the optimizer class with its parameter count (plus the number
    of parameters that do not require grad, when there are any), then the hyperparameters.
    """
    # flatten the parameters of all param groups into one list
    params = [p for group in optimizer.param_groups for p in group['params']]
    n_total = len(params)
    n_frozen = sum(1 for p in params if not p.requires_grad)

    param_info = f"({n_total} parameters)"
    if n_frozen > 0:
        param_info = f"({n_total} parameters, {n_frozen} frozen)"

    logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
    logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
2023-05-14 03:49:11 -06:00