trainer ymls and fix main

Victor Hall 2022-10-24 22:58:36 -04:00
parent fbbd370661
commit 09d67853f0
3 changed files with 31 additions and 39 deletions

trainer yaml #1 (filename not shown)

@@ -1,5 +1,5 @@
 model:
-  base_learning_rate: 9.0e-07
+  base_learning_rate: 1.0e-06
   target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
     reg_weight: 1.0
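Note on the learning-rate bump: in the stock latent-diffusion main.py, base_learning_rate is not applied directly; it is scaled by the batch geometry before training starts. A minimal sketch of that upstream convention (variable names as in upstream main.py; abbreviated):

    # sketch of upstream latent-diffusion LR handling, assumed unchanged here
    bs = config.data.params.batch_size
    base_lr = config.model.base_learning_rate   # 1.0e-06 after this change
    # effective LR grows with grad accumulation, GPU count, and batch size
    model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr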
@@ -20,15 +20,6 @@ model:
     embedding_reg_weight: 0.0
     unfreeze_model: True
     model_lr: 5.0e-7
-    # scheduler_config:
-    #   target: ldm.lr_scheduler.LambdaLinearScheduler
-    #   params:
-    #     verbosity_interval: 200
-    #     warm_up_steps: 5
-    #     max_decay_steps: 100
-    #     lr_start: 6.0e-7
-    #     lr_max: 8.0e-7
-    #     lr_min: 1.0e-7
     personalization_config:
       target: ldm.modules.embedding_manager.EmbeddingManager
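The deleted scheduler block was already commented out, so behavior is unchanged. For reference, when a scheduler_config is present, stock LDM turns it into a LambdaLR multiplier inside LatentDiffusion.configure_optimizers, roughly:

    # sketch, abbreviated from stock ldm.models.diffusion.ddpm
    from torch.optim.lr_scheduler import LambdaLR
    scheduler = instantiate_from_config(self.scheduler_config)
    return [opt], [{
        "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
        "interval": "step",
        "frequency": 1,
    }]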
@@ -93,7 +84,7 @@ data:
       params:
         size: 512
        set: train
-        repeats: 5
+        repeats: 3
     validation:
       target: ldm.data.personalized.PersonalizedBase
       params:
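Lowering repeats from 5 to 3 shortens each epoch: in the textual-inversion PersonalizedBase, repeats simply multiplies the dataset length for the train split, roughly:

    # sketch, abbreviated from ldm/data/personalized.py (textual-inversion)
    self.num_images = len(self.image_paths)
    self._length = self.num_images
    if set == "train":
        self._length = self.num_images * repeats  # each image seen 3x per epoch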
@@ -116,6 +107,6 @@ lightning:
   trainer:
     benchmark: True
     max_epochs: 3
-    #precision: 16 # need lightning 1.6+
+    #precision: 16 # need lightning 1.6+ ??
     #num_nodes: 2 # for multigpu
-    #check_val_every_n_epoch: 1
+    #check_val_every_n_epoch: 2
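These keys under lightning.trainer effectively become pytorch_lightning.Trainer kwargs once main.py assembles the trainer. A hedged sketch of the resulting call, with the still-commented options shown for reference:

    # sketch: the yaml block above maps onto Trainer kwargs
    from pytorch_lightning import Trainer
    trainer = Trainer(
        benchmark=True,              # cudnn autotuning for fixed-size inputs
        max_epochs=3,
        # precision=16,              # mixed precision; per the comment, lightning 1.6+
        # num_nodes=2,               # multi-GPU / multi-node
        # check_val_every_n_epoch=2,
    )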

trainer yaml #2 (filename not shown)

@@ -17,9 +17,9 @@ model:
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
     use_ema: False
-    embedding_reg_weight: 0.0
+    #embedding_reg_weight: 0.0
     unfreeze_model: True
-    model_lr: 5.0e-7
+    model_lr: 6.0e-7
     # scheduler_config:
     #   target: ldm.lr_scheduler.LambdaLinearScheduler
     #   params:
@@ -30,14 +30,14 @@ model:
     #     lr_max: 8.0e-7
     #     lr_min: 1.0e-7
-    personalization_config:
-      target: ldm.modules.embedding_manager.EmbeddingManager
-      params:
-        placeholder_strings: ["*"]
-        initializer_words: ["sculpture"]
-        per_image_tokens: false
-        num_vectors_per_token: 1
-        progressive_words: False
+    # personalization_config:
+    #   target: ldm.modules.embedding_manager.EmbeddingManager
+    #   params:
+    #     placeholder_strings: ["*"]
+    #     initializer_words: ["sculpture"]
+    #     per_image_tokens: false
+    #     num_vectors_per_token: 1
+    #     progressive_words: False
     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
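Commenting out personalization_config disables the textual-inversion EmbeddingManager entirely, leaving plain unfrozen-model fine-tuning (unfreeze_model: True above). For reference, the block had been instantiating roughly the following; constructor arguments mirror the yaml params, and the embedder is supplied by the model:

    # sketch of what the now-commented block built (textual-inversion API)
    from ldm.modules.embedding_manager import EmbeddingManager
    manager = EmbeddingManager(
        embedder=model.cond_stage_model,   # text encoder, passed in by the model
        placeholder_strings=["*"],
        initializer_words=["sculpture"],
        per_image_tokens=False,
        num_vectors_per_token=1,
        progressive_words=False,
    )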
@@ -109,13 +109,13 @@ lightning:
     image_logger:
       target: main.ImageLogger
       params:
-        batch_frequency: 300
-        max_images: 16
+        batch_frequency: 500
+        max_images: 12
         increase_log_steps: False
   trainer:
     benchmark: True
-    max_epochs: 3
+    max_epochs: 5
     #precision: 16 # need lightning 1.6+
     #num_nodes: 2 # for multigpu
     #check_val_every_n_epoch: 1
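The image_logger change trades frequency for size: sample grids render every 500 batches instead of 300, capped at 12 images instead of 16. Inside main.ImageLogger this amounts to, roughly (abbreviated; method names are illustrative):

    # sketch of ImageLogger's cadence with the new values
    if batch_idx % 500 == 0:            # batch_frequency
        n = min(images.shape[0], 12)    # max_images cap
        log_local(save_dir, split, images[:n], global_step, current_epoch)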

main.py

@@ -153,7 +153,7 @@ def get_parser(**parser_kwargs):
     parser.add_argument(
         "--max_training_steps",
         type=int,
         required=False,
-        default=9400,
+        default=35000,
         help="Number of iterations to run")
     parser.add_argument("--actual_resume",
@@ -605,21 +605,22 @@ if __name__ == "__main__":
         logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
         trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
 
-        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
-        # specify which metric is used to determine best models
-        # default_modelckpt_cfg = {
-        #     "target": "pytorch_lightning.callbacks.ModelCheckpoint",
-        #     "params": {
-        #         "dirpath": ckptdir,
-        #         "filename": "{epoch:03}",
-        #         "verbose": True,
-        #         "save_last": True,
-        #     }
-        # }
+        #modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
+        #specify which metric is used to determine best models
+        default_modelckpt_cfg = {
+            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+            "params": {
+                "dirpath": ckptdir,
+                "filename": "{epoch:03}",
+                "verbose": True,
+                "save_last": True,
+            }
+        }
         if hasattr(model, "monitor"):
             print(f"Monitoring {model.monitor} as checkpoint metric.")
             default_modelckpt_cfg["params"]["monitor"] = model.monitor
-            default_modelckpt_cfg["params"]["save_top_k"] = 1
+            default_modelckpt_cfg["params"]["save_top_k"] = 3
         if "modelcheckpoint" in lightning_config:
             modelckpt_cfg = lightning_config.modelcheckpoint
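Un-commenting default_modelckpt_cfg, presumably the "fix main" in the commit title, restores checkpoint saving. After the subsequent OmegaConf.merge with any lightning_config.modelcheckpoint overrides, the callback the config instantiates is roughly:

    # sketch of the callback the merged config produces (pytorch_lightning API)
    from pytorch_lightning.callbacks import ModelCheckpoint
    ckpt_cb = ModelCheckpoint(
        dirpath=ckptdir,
        filename="{epoch:03}",
        verbose=True,
        save_last=True,
        monitor="val/loss_simple_ema",  # only set when the model defines `monitor`
        save_top_k=3,                   # keep the 3 best checkpoints, up from 1
    )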