optimizer splitting
parent 3639e36135
commit 72a47741f0
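
Summary: pass epoch_len into EveryDreamOptimizer and store it so the gradient-accumulation boundary check reads self.epoch_len; step the split UNet and text-encoder LR schedulers (lr_scheduler_unet, lr_scheduler_te) instead of the old single lr_scheduler; drop a stray optimizer_unet.step() and the per-model disable checks around gradient clipping; train.py now forwards epoch_len when constructing the optimizer.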

@@ -23,11 +23,12 @@ class EveryDreamOptimizer():
     text_encoder: text encoder model
     unet: unet model
     """
-    def __init__(self, args, optimizer_config, text_encoder_params, unet_params):
+    def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
         self.text_encoder_params = text_encoder_params
         self.unet_params = unet_params
+        self.epoch_len = epoch_len

         self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, optimizer_config, text_encoder_params, unet_params)
         self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args, optimizer_config)

@@ -67,19 +68,16 @@ class EveryDreamOptimizer():
         self.optimizer_unet.step()

         if self.clip_grad_norm is not None:
-            if not args.disable_unet_training:
             torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
-            if not args.disable_textenc_training:
             torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
-        if ((global_step + 1) % self.grad_accum == 0) or (step == epoch_len - 1):
+        if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
             self.scaler.step(self.optimizer_te)
             self.scaler.step(self.optimizer_unet)
             self.scaler.update()
             self._zero_grad(set_to_none=True)

-        self.lr_scheduler.step()
+        self.lr_scheduler_unet.step()
+        self.lr_scheduler_te.step()
-        self.optimizer_unet.step()

         self.update_grad_scaler(global_step)

     def _zero_grad(self, set_to_none=False):
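
For orientation, here is a minimal, self-contained sketch of the stepping pattern the updated code follows. It is not the EveryDream2 implementation: the Linear stand-ins, AdamW, LambdaLR, and the toy loss and constants are placeholder assumptions; only the boundary condition and the split optimizer/scheduler stepping mirror the diff above.

import torch

te = torch.nn.Linear(8, 8)     # stand-in for the text encoder
unet = torch.nn.Linear(8, 8)   # stand-in for the unet
optimizer_te = torch.optim.AdamW(te.parameters(), lr=1e-4)
optimizer_unet = torch.optim.AdamW(unet.parameters(), lr=1e-4)
lr_scheduler_te = torch.optim.lr_scheduler.LambdaLR(optimizer_te, lambda _: 1.0)
lr_scheduler_unet = torch.optim.lr_scheduler.LambdaLR(optimizer_unet, lambda _: 1.0)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

grad_accum = 2
clip_grad_norm = 1.0
epoch_len = 5        # steps per epoch, as passed into the constructor
global_step = 0

for step in range(epoch_len):
    x = torch.randn(4, 8)
    loss = (unet(te(x)) ** 2).mean()
    scaler.scale(loss).backward()

    # Clip both parameter groups whenever a max norm is configured.
    if clip_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=clip_grad_norm)
        torch.nn.utils.clip_grad_norm_(te.parameters(), max_norm=clip_grad_norm)

    # Step both optimizers on an accumulation boundary, or on the epoch's last batch.
    if ((global_step + 1) % grad_accum == 0) or (step == epoch_len - 1):
        scaler.step(optimizer_te)
        scaler.step(optimizer_unet)
        scaler.update()
        optimizer_te.zero_grad(set_to_none=True)
        optimizer_unet.zero_grad(set_to_none=True)

    # Each LR scheduler is stepped separately rather than through a single lr_scheduler.
    lr_scheduler_unet.step()
    lr_scheduler_te.step()
    global_step += 1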

train.py (2 changed lines)
@@ -549,7 +549,7 @@ def main(args):

     epoch_len = math.ceil(len(train_batch) / args.batch_size)

-    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters())
+    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)

     log_args(log_writer, args)
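
Note: any caller constructing EveryDreamOptimizer now has to supply epoch_len. train.py derives it as math.ceil(len(train_batch) / args.batch_size), so the end-of-epoch branch (step == self.epoch_len - 1) steps the optimizers on an epoch's final batch even when it does not land on a grad_accum boundary.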