optimizer splitting

Victor Hall 2023-04-29 23:15:48 -04:00
parent 3639e36135
commit 72a47741f0
2 changed files with 8 additions and 10 deletions


@@ -23,11 +23,12 @@ class EveryDreamOptimizer():
     text_encoder: text encoder model
     unet: unet model
     """
-    def __init__(self, args, optimizer_config, text_encoder_params, unet_params):
+    def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
         self.text_encoder_params = text_encoder_params
         self.unet_params = unet_params
+        self.epoch_len = epoch_len
         self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, optimizer_config, text_encoder_params, unet_params)
         self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args, optimizer_config)
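
The new epoch_len field lets step() (next hunk) flush accumulated gradients on the last batch of an epoch even when the epoch length is not a multiple of grad_accum. A minimal sketch of that boundary test, with hypothetical values standing in for the real arguments:

    # Sketch of the accumulation boundary used in step(); grad_accum and
    # epoch_len here are hypothetical example values, not repo defaults.
    grad_accum = 4
    epoch_len = 10  # batches per epoch, deliberately not divisible by grad_accum

    for step in range(epoch_len):
        global_step = step  # single-epoch illustration
        if ((global_step + 1) % grad_accum == 0) or (step == epoch_len - 1):
            print(f"step {step}: optimizers step, gradients zeroed")
    # Fires at steps 3 and 7, and at step 9 only via the epoch_len check.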
@@ -67,19 +68,16 @@ class EveryDreamOptimizer():
-        self.optimizer_unet.step()
         if self.clip_grad_norm is not None:
-            if not args.disable_unet_training:
-                torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
-            if not args.disable_textenc_training:
-                torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
-        if ((global_step + 1) % self.grad_accum == 0) or (step == epoch_len - 1):
+            torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
+            torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
+        if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
             self.scaler.step(self.optimizer_te)
             self.scaler.step(self.optimizer_unet)
             self.scaler.update()
             self._zero_grad(set_to_none=True)
-        self.lr_scheduler.step()
-        self.optimizer_unet.step()
+        self.lr_scheduler_unet.step()
+        self.lr_scheduler_te.step()
         self.update_grad_scaler(global_step)

     def _zero_grad(self, set_to_none=False):
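
This hunk is the heart of the split: gradient clipping now covers both models unconditionally (note the removed branches referenced a bare args, which the class never stores), both optimizers step through the shared GradScaler at each accumulation boundary, and the single self.lr_scheduler gives way to one scheduler per model. A minimal sketch of the two-optimizer, two-scheduler pattern using plain torch.optim pieces; the Linear stand-ins and learning rates are hypothetical, and the real optimizers come from optimizer_config:

    import torch

    # Hypothetical stand-ins for the real unet and text encoder models.
    unet = torch.nn.Linear(8, 8)
    text_encoder = torch.nn.Linear(8, 8)

    # One optimizer and one LR scheduler per model, so each side of the
    # split can train at its own rate.
    optimizer_unet = torch.optim.AdamW(unet.parameters(), lr=1e-6)
    optimizer_te = torch.optim.AdamW(text_encoder.parameters(), lr=5e-7)
    lr_scheduler_unet = torch.optim.lr_scheduler.ConstantLR(optimizer_unet)
    lr_scheduler_te = torch.optim.lr_scheduler.ConstantLR(optimizer_te)

    loss = unet(torch.randn(2, 8)).sum() + text_encoder(torch.randn(2, 8)).sum()
    loss.backward()

    # Clip each parameter group, then step everything in tandem,
    # mirroring the order in the diff above.
    torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=1.0)
    torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), max_norm=1.0)
    optimizer_unet.step()
    optimizer_te.step()
    lr_scheduler_unet.step()
    lr_scheduler_te.step()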


@@ -549,7 +549,7 @@ def main(args):
     epoch_len = math.ceil(len(train_batch) / args.batch_size)
-    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters())
+    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
     log_args(log_writer, args)
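
On the caller side, epoch_len counts steps per epoch with math.ceil so a partial final batch still counts as a step, which is exactly the step the epoch_len - 1 check above is meant to catch. A quick worked example with hypothetical sizes:

    import math

    # Hypothetical dataset and batch sizes, for illustration only.
    num_items = 1034   # stands in for len(train_batch)
    batch_size = 8     # stands in for args.batch_size
    epoch_len = math.ceil(num_items / batch_size)
    print(epoch_len)   # 130: the final 2-item batch still counts as a step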