more wip optimizer splitting
This commit is contained in:
parent
72a47741f0
commit
970065c206
|
@ -16,7 +16,7 @@
|
|||
"unet": {
|
||||
"optimizer": "adamw8bit",
|
||||
"lr": 1e-6,
|
||||
"lr_scheduler": null,
|
||||
"lr_scheduler": "constant",
|
||||
"betas": [0.9, 0.999],
|
||||
"epsilon": 1e-8,
|
||||
"weight_decay": 0.010
|
||||
|
|
|
@ -24,6 +24,7 @@ class EveryDreamOptimizer():
|
|||
unet: unet model
|
||||
"""
|
||||
def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
|
||||
print(f"\noptimizer_config: \n{optimizer_config}\n")
|
||||
self.grad_accum = args.grad_accum
|
||||
self.clip_grad_norm = args.clip_grad_norm
|
||||
self.text_encoder_params = text_encoder_params
|
||||
|
@ -34,14 +35,6 @@ class EveryDreamOptimizer():
|
|||
self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args, optimizer_config)
|
||||
|
||||
self.unet_config = optimizer_config.get("unet", {})
|
||||
if args.lr is not None:
|
||||
self.unet_config["lr"] = args.lr
|
||||
self.te_config = optimizer_config.get("text_encoder", {})
|
||||
if self.te_config.get("lr", None) is None:
|
||||
self.te_config["lr"] = self.unet_config["lr"]
|
||||
te_scale = self.optimizer_config.get("text_encoder_lr_scale", None)
|
||||
if te_scale is not None:
|
||||
self.te_config["lr"] = self.unet_config["lr"] * te_scale
|
||||
|
||||
optimizer_te_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_TE_STATE_FILENAME)
|
||||
optimizer_unet_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_UNET_STATE_FILENAME)
|
||||
|
@ -100,37 +93,72 @@ class EveryDreamOptimizer():
|
|||
self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME))
|
||||
self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME))
|
||||
|
||||
def create_optimizers(self, args, global_optimizer_config, text_encoder, unet):
|
||||
def create_optimizers(self, args, global_optimizer_config, text_encoder_params, unet_params):
|
||||
"""
|
||||
creates optimizers from config and argsfor unet and text encoder
|
||||
returns (optimizer_te, optimizer_unet)
|
||||
"""
|
||||
text_encoder_lr_scale = global_optimizer_config.get("text_encoder_lr_scale")
|
||||
unet_config = global_optimizer_config.get("unet")
|
||||
te_config = global_optimizer_config.get("text_encoder")
|
||||
te_config, unet_config = self.fold(te_config=te_config,
|
||||
unet_config=unet_config,
|
||||
text_encoder_lr_scale=text_encoder_lr_scale)
|
||||
|
||||
if args.disable_textenc_training:
|
||||
optimizer_te = create_null_optimizer()
|
||||
else:
|
||||
optimizer_te = self.create_optimizer(global_optimizer_config.get("text_encoder"), text_encoder)
|
||||
optimizer_te = self.create_optimizer(args, te_config, text_encoder_params)
|
||||
if args.disable_unet_training:
|
||||
optimizer_unet = create_null_optimizer()
|
||||
else:
|
||||
optimizer_unet = self.create_optimizer(global_optimizer_config, unet)
|
||||
optimizer_unet = self.create_optimizer(args, unet_config, unet_params)
|
||||
|
||||
return optimizer_te, optimizer_unet
|
||||
|
||||
@staticmethod
|
||||
def fold(te_config, unet_config, text_encoder_lr_scale):
|
||||
"""
|
||||
defaults text encoder config values to unet config values if not specified per property
|
||||
"""
|
||||
if te_config.get("optimizer", None) is None:
|
||||
te_config["optimizer"] = unet_config["optimizer"]
|
||||
if te_config.get("lr", None) is None:
|
||||
te_config["lr"] = unet_config["lr"]
|
||||
te_scale = text_encoder_lr_scale
|
||||
if te_scale is not None:
|
||||
te_config["lr"] = unet_config["lr"] * te_scale
|
||||
if te_config.get("weight_decay", None) is None:
|
||||
te_config["weight_decay"] = unet_config["weight_decay"]
|
||||
if te_config.get("betas", None) is None:
|
||||
te_config["betas"] = unet_config["betas"]
|
||||
if te_config.get("epsilon", None) is None:
|
||||
te_config["epsilon"] = unet_config["epsilon"]
|
||||
if te_config.get("lr_scheduler", None) is None:
|
||||
te_config["lr_scheduler"] = unet_config["lr_scheduler"]
|
||||
|
||||
return te_config, unet_config
|
||||
|
||||
def create_lr_schedulers(self, args, optimizer_config):
|
||||
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
|
||||
lr_scheduler_type_te = optimizer_config.get("lr_scheduler", self.unet_config.lr_scheduler)
|
||||
lr_scheduler_type_unet = optimizer_config["unet"].get("lr_scheduler", None)
|
||||
assert lr_scheduler_type_unet is not None, "lr_scheduler must be specified in optimizer config"
|
||||
lr_scheduler_type_te = optimizer_config.get("lr_scheduler", lr_scheduler_type_unet)
|
||||
|
||||
self.lr_scheduler_te = get_scheduler(
|
||||
lr_scheduler_type_te,
|
||||
optimizer=self.optimizer_te,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
|
||||
self.lr_scheduler_unet = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
lr_scheduler_type_unet,
|
||||
optimizer=self.optimizer_unet,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
|
||||
return self.lr_scheduler_te, self.lr_scheduler_unet
|
||||
|
||||
def update_grad_scaler(self, global_step):
|
||||
|
@ -169,8 +197,8 @@ class EveryDreamOptimizer():
|
|||
"""
|
||||
optimizer.load_state_dict(torch.load(path))
|
||||
|
||||
@staticmethod
|
||||
def create_optimizer(args, local_optimizer_config, parameters):
|
||||
def create_optimizer(self, args, local_optimizer_config, parameters):
|
||||
print(f"Creating optimizer from {local_optimizer_config}")
|
||||
betas = BETAS_DEFAULT
|
||||
epsilon = EPSILON_DEFAULT
|
||||
weight_decay = WEIGHT_DECAY_DEFAULT
|
||||
|
@ -182,10 +210,10 @@ class EveryDreamOptimizer():
|
|||
text_encoder_lr_scale = 1.0
|
||||
|
||||
if local_optimizer_config is not None:
|
||||
betas = local_optimizer_config["betas"]
|
||||
epsilon = local_optimizer_config["epsilon"]
|
||||
weight_decay = local_optimizer_config["weight_decay"]
|
||||
optimizer_name = local_optimizer_config["optimizer"]
|
||||
betas = local_optimizer_config["betas"] or betas
|
||||
epsilon = local_optimizer_config["epsilon"] or epsilon
|
||||
weight_decay = local_optimizer_config["weight_decay"] or weight_decay
|
||||
optimizer_name = local_optimizer_config["optimizer"] or None
|
||||
curr_lr = local_optimizer_config.get("lr", curr_lr)
|
||||
if args.lr is not None:
|
||||
curr_lr = args.lr
|
||||
|
@ -228,18 +256,10 @@ class EveryDreamOptimizer():
|
|||
)
|
||||
|
||||
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
|
||||
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
|
||||
args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)
|
||||
|
||||
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||
return optimizer
|
||||
|
||||
|
|
13
train.py
13
train.py
|
@ -261,11 +261,6 @@ def setup_args(args):
|
|||
|
||||
total_batch_size = args.batch_size * args.grad_accum
|
||||
|
||||
if args.scale_lr is not None and args.scale_lr:
|
||||
tmp_lr = args.lr
|
||||
args.lr = args.lr * (total_batch_size**0.55)
|
||||
logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
|
||||
|
||||
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
|
||||
os.makedirs(args.save_ckpt_dir)
|
||||
|
||||
|
@ -550,7 +545,7 @@ def main(args):
|
|||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
|
||||
|
||||
exit()
|
||||
log_args(log_writer, args)
|
||||
|
||||
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
|
||||
|
@ -864,7 +859,7 @@ if __name__ == "__main__":
|
|||
print("No config file specified, using command line args")
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
#argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
|
||||
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
|
||||
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
|
||||
|
@ -884,7 +879,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
|
||||
argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
|
||||
argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
|
||||
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
|
||||
#argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
|
||||
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
|
||||
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
|
||||
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
|
||||
|
@ -899,7 +894,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--save_ckpts_from_n_epochs", type=int, default=0, help="Only saves checkpoints starting an N epochs, def: 0 (disabled)")
|
||||
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
|
||||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||
#argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
def check_git():
|
||||
import subprocess
|
||||
try:
|
||||
result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
|
||||
branch = result.stdout.strip()
|
||||
|
||||
result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
|
||||
branch = result.stdout.strip()
|
||||
result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
|
||||
ahead, behind = map(int, result.stdout.split())
|
||||
|
||||
result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
|
||||
ahead, behind = map(int, result.stdout.split())
|
||||
|
||||
if behind > 0:
|
||||
print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
|
||||
elif ahead > 0:
|
||||
print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
|
||||
else:
|
||||
print(f"** Your branch '{branch}' is up to date with the remote")
|
||||
if behind > 0:
|
||||
print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
|
||||
elif ahead > 0:
|
||||
print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
|
||||
else:
|
||||
print(f"** Your branch '{branch}' is up to date with the remote")
|
||||
except:
|
||||
pass
|
Loading…
Reference in New Issue