optimizer splitting
parent 3639e36135
commit 72a47741f0
@@ -23,11 +23,12 @@ class EveryDreamOptimizer():
     text_encoder: text encoder model
     unet: unet model
     """
-    def __init__(self, args, optimizer_config, text_encoder_params, unet_params):
+    def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
         self.text_encoder_params = text_encoder_params
         self.unet_params = unet_params
+        self.epoch_len = epoch_len
 
         self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, optimizer_config, text_encoder_params, unet_params)
         self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args, optimizer_config)
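The constructor now keeps one optimizer and one LR scheduler per model (optimizer_te / lr_scheduler_te for the text encoder, optimizer_unet / lr_scheduler_unet for the unet) and records epoch_len so step() can flush accumulated gradients on the last batch of an epoch. The bodies of create_optimizers and create_lr_schedulers are not part of this diff; the sketch below only illustrates what a split construction of this shape could look like. The function names, the AdamW choice, the learning rates, and the constant schedule are all assumptions for illustration, not the repo's code.

import torch
from torch.optim.lr_scheduler import LambdaLR

def create_split_optimizers(text_encoder_params, unet_params,
                            te_lr=5e-7, unet_lr=1e-6):
    # One optimizer per model so each can get its own learning rate.
    # AdamW and the default lrs here are assumptions, not the repo's config.
    optimizer_te = torch.optim.AdamW(text_encoder_params, lr=te_lr)
    optimizer_unet = torch.optim.AdamW(unet_params, lr=unet_lr)
    return optimizer_te, optimizer_unet

def create_split_schedulers(optimizer_te, optimizer_unet):
    # Constant schedule as a stand-in; the real create_lr_schedulers presumably
    # builds whatever schedule optimizer_config asks for, once per optimizer.
    lr_scheduler_te = LambdaLR(optimizer_te, lambda step: 1.0)
    lr_scheduler_unet = LambdaLR(optimizer_unet, lambda step: 1.0)
    return lr_scheduler_te, lr_scheduler_unet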
@@ -67,19 +68,16 @@ class EveryDreamOptimizer():
         self.optimizer_unet.step()
 
         if self.clip_grad_norm is not None:
-            if not args.disable_unet_training:
-                torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
-            if not args.disable_textenc_training:
-                torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
-        if ((global_step + 1) % self.grad_accum == 0) or (step == epoch_len - 1):
+            torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
+            torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
+        if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
             self.scaler.step(self.optimizer_te)
             self.scaler.step(self.optimizer_unet)
             self.scaler.update()
             self._zero_grad(set_to_none=True)
 
-        self.lr_scheduler.step()
-
-        self.optimizer_unet.step()
+        self.lr_scheduler_unet.step()
+        self.lr_scheduler_te.step()
         self.update_grad_scaler(global_step)
 
     def _zero_grad(self, set_to_none=False):
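After the change, step() clips gradients for both parameter groups whenever clip_grad_norm is set (no longer gated on args, which is not stored on the class), steps both optimizers through the shared GradScaler only every grad_accum micro-batches or on the final batch of the epoch (step == self.epoch_len - 1), and then advances both LR schedulers. Below is a self-contained sketch of that accumulate-then-step pattern; it is an illustration, not the repo's code, and it deliberately differs in one respect: it unscales and clips only at the accumulation boundary, since AMP gradients should be unscaled before clip_grad_norm_ is applied.

import torch

def accum_step(scaler, optimizer_te, optimizer_unet,
               te_params, unet_params, loss,
               step, global_step, grad_accum, epoch_len,
               clip_grad_norm=None):
    # Illustrative sketch of split-optimizer gradient accumulation (hypothetical
    # helper, not the repo's step()). te_params / unet_params are parameter lists.
    scaler.scale(loss).backward()

    boundary = ((global_step + 1) % grad_accum == 0) or (step == epoch_len - 1)
    if boundary:
        if clip_grad_norm is not None:
            # Unscale before clipping so the norm is measured on real gradients.
            scaler.unscale_(optimizer_te)
            scaler.unscale_(optimizer_unet)
            torch.nn.utils.clip_grad_norm_(unet_params, max_norm=clip_grad_norm)
            torch.nn.utils.clip_grad_norm_(te_params, max_norm=clip_grad_norm)
        scaler.step(optimizer_te)     # skipped internally if grads are inf/NaN
        scaler.step(optimizer_unet)
        scaler.update()               # one update() per iteration, after all steps
        optimizer_te.zero_grad(set_to_none=True)
        optimizer_unet.zero_grad(set_to_none=True)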
train.py (2 changed lines)
@@ -549,7 +549,7 @@ def main(args):
 
     epoch_len = math.ceil(len(train_batch) / args.batch_size)
 
-    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters())
+    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
 
     log_args(log_writer, args)
 
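train.py already computed epoch_len, the number of batches in one epoch; the only change is passing it into EveryDreamOptimizer so the step == self.epoch_len - 1 check above can flush any partially accumulated gradients on the final batch instead of carrying them across the epoch boundary. A small worked example (the numbers are made up):

import math

items_in_epoch = 2001   # hypothetical len(train_batch)
batch_size = 4
epoch_len = math.ceil(items_in_epoch / batch_size)   # 501; ceil keeps the short last batch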