trainer ymls and fix main
This commit is contained in:
parent
fbbd370661
commit
09d67853f0
|
@ -1,5 +1,5 @@
|
|||
model:
|
||||
base_learning_rate: 9.0e-07
|
||||
base_learning_rate: 1.0e-06
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
reg_weight: 1.0
|
||||
|
@ -20,15 +20,6 @@ model:
|
|||
embedding_reg_weight: 0.0
|
||||
unfreeze_model: True
|
||||
model_lr: 5.0e-7
|
||||
# scheduler_config:
|
||||
# target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
# params:
|
||||
# verbosity_interval: 200
|
||||
# warm_up_steps: 5
|
||||
# max_decay_steps: 100
|
||||
# lr_start: 6.0e-7
|
||||
# lr_max: 8.0e-7
|
||||
# lr_min: 1.0e-7
|
||||
|
||||
personalization_config:
|
||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||
|
@ -93,7 +84,7 @@ data:
|
|||
params:
|
||||
size: 512
|
||||
set: train
|
||||
repeats: 5
|
||||
repeats: 3
|
||||
validation:
|
||||
target: ldm.data.personalized.PersonalizedBase
|
||||
params:
|
||||
|
@ -116,6 +107,6 @@ lightning:
|
|||
trainer:
|
||||
benchmark: True
|
||||
max_epochs: 3
|
||||
#precision: 16 # need lightning 1.6+
|
||||
#precision: 16 # need lightning 1.6+ ??
|
||||
#num_nodes: 2 # for multigpu
|
||||
#check_val_every_n_epoch: 1
|
||||
#check_val_every_n_epoch: 2
|
||||
|
|
|
@ -17,9 +17,9 @@ model:
|
|||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
embedding_reg_weight: 0.0
|
||||
#embedding_reg_weight: 0.0
|
||||
unfreeze_model: True
|
||||
model_lr: 5.0e-7
|
||||
model_lr: 6.0e-7
|
||||
# scheduler_config:
|
||||
# target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
# params:
|
||||
|
@ -30,14 +30,14 @@ model:
|
|||
# lr_max: 8.0e-7
|
||||
# lr_min: 1.0e-7
|
||||
|
||||
personalization_config:
|
||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||
params:
|
||||
placeholder_strings: ["*"]
|
||||
initializer_words: ["sculpture"]
|
||||
per_image_tokens: false
|
||||
num_vectors_per_token: 1
|
||||
progressive_words: False
|
||||
# personalization_config:
|
||||
# target: ldm.modules.embedding_manager.EmbeddingManager
|
||||
# params:
|
||||
# placeholder_strings: ["*"]
|
||||
# initializer_words: ["sculpture"]
|
||||
# per_image_tokens: false
|
||||
# num_vectors_per_token: 1
|
||||
# progressive_words: False
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
|
@ -109,13 +109,13 @@ lightning:
|
|||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 300
|
||||
max_images: 16
|
||||
batch_frequency: 500
|
||||
max_images: 12
|
||||
increase_log_steps: False
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
||||
max_epochs: 3
|
||||
max_epochs: 5
|
||||
#precision: 16 # need lightning 1.6+
|
||||
#num_nodes: 2 # for multigpu
|
||||
#check_val_every_n_epoch: 1
|
||||
|
|
27
main.py
27
main.py
|
@ -153,7 +153,7 @@ def get_parser(**parser_kwargs):
|
|||
"--max_training_steps",
|
||||
type=int,
|
||||
required=False,
|
||||
default=9400,
|
||||
default=35000,
|
||||
help="Number of iterations to run")
|
||||
|
||||
parser.add_argument("--actual_resume",
|
||||
|
@ -605,21 +605,22 @@ if __name__ == "__main__":
|
|||
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
||||
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
||||
|
||||
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
||||
# specify which metric is used to determine best models
|
||||
# default_modelckpt_cfg = {
|
||||
# "target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
||||
# "params": {
|
||||
# "dirpath": ckptdir,
|
||||
# "filename": "{epoch:03}",
|
||||
# "verbose": True,
|
||||
# "save_last": True,
|
||||
# }
|
||||
# }
|
||||
#modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
||||
#specify which metric is used to determine best models
|
||||
default_modelckpt_cfg = {
|
||||
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
||||
"params": {
|
||||
"dirpath": ckptdir,
|
||||
"filename": "{epoch:03}",
|
||||
"verbose": True,
|
||||
"save_last": True,
|
||||
}
|
||||
}
|
||||
|
||||
if hasattr(model, "monitor"):
|
||||
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
||||
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
||||
default_modelckpt_cfg["params"]["save_top_k"] = 1
|
||||
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
||||
|
||||
if "modelcheckpoint" in lightning_config:
|
||||
modelckpt_cfg = lightning_config.modelcheckpoint
|
||||
|
|
Loading…
Reference in New Issue