chaining and more lower resolutions

This commit is contained in:
Victor Hall 2023-01-08 18:52:39 -05:00
parent c816e25773
commit 3c921dbaa2
7 changed files with 276 additions and 57 deletions

View File

@ -6,7 +6,7 @@ Welcome to v2.0 of EveryDream trainer! Now with more diffusers and even more fea
Please join us on Discord! https://discord.gg/uheqxU6sXN Please join us on Discord! https://discord.gg/uheqxU6sXN
If you find this tool useful, please consider subscribing to the project on Patreon or buy me a Ko-fi. If you find this tool useful, please consider subscribing to the project on [Patreon](https://www.patreon.com/everydream) or a one-time donation at [Ko-fi](https://ko-fi.com/everydream).
## Video tutorials ## Video tutorials
@ -32,3 +32,5 @@ Behind the scenes look at how the trainer handles multiaspect and crop jitter
[Logging](doc/LOGGING.md) [Logging](doc/LOGGING.md)
[Advanced Tweaking](doc/ATWEAKING.md) [Advanced Tweaking](doc/ATWEAKING.md)
[Chaining training sessions](doc/CHAINING.md)

3
chain.bat Normal file
View File

@ -0,0 +1,3 @@
REM Runs three training sessions back to back, one per config file.
REM Each command starts only after the previous python process has exited.
python train.py --config chain0.json
python train.py --config chain1.json
python train.py --config chain2.json

38
chain0.json Normal file
View File

@ -0,0 +1,38 @@
{
"amp": false,
"batch_size": 12,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.00,
"data_root": "R:\\everydream-trainer\\training_samples\\ff7r",
"disable_textenc_training": false,
"disable_xformers": true,
"flip_p": 0.0,
"ed1_mode": true,
"gpuid": 0,
"gradient_checkpointing": true,
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 2.5e-6,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 15,
"project_name": "myproj_ch0",
"resolution": 384,
"resume_ckpt": "sd_v1-5_vae",
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_every_n_epochs": 99,
"save_optimizer": false,
"scale_lr": false,
"seed": -1,
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": true
}

38
chain1.json Normal file
View File

@ -0,0 +1,38 @@
{
"amp": false,
"batch_size": 7,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.05,
"data_root": "R:\\everydream-trainer\\training_samples\\ff7r",
"disable_textenc_training": false,
"disable_xformers": true,
"flip_p": 0.0,
"ed1_mode": true,
"gpuid": 0,
"gradient_checkpointing": true,
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 1.0e-6,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 10,
"project_name": "myproj_ch0",
"resolution": 512,
"resume_ckpt": "findlast",
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_every_n_epochs": 5,
"save_optimizer": false,
"scale_lr": false,
"seed": -1,
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": true
}

38
chain2.json Normal file
View File

@ -0,0 +1,38 @@
{
"amp": false,
"batch_size": 2,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.08,
"data_root": "R:\\everydream-trainer\\training_samples\\ff7r",
"disable_textenc_training": true,
"disable_xformers": true,
"flip_p": 0.0,
"ed1_mode": true,
"gpuid": 0,
"gradient_checkpointing": true,
"grad_accum": 5,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 1.5e-6,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 10,
"project_name": "myproj_ch0",
"resolution": 640,
"resume_ckpt": "findlast",
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_every_n_epochs": 5,
"save_optimizer": false,
"scale_lr": false,
"seed": -1,
"shuffle_tags": false,
"useadam8bit": true,
"wandb": false,
"write_schedule": true
}

View File

@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
ASPECTS11 = [[1152,1152], # 1327104 1:1 ASPECTS_1152 = [[1152,1152], # 1327104 1:1
#[1216,1088],[1088,1216], # 1323008 1.118:1 #[1216,1088],[1088,1216], # 1323008 1.118:1
[1280,1024],[1024,1280], # 1310720 1.25:1 [1280,1024],[1024,1280], # 1310720 1.25:1
[1344,960],[960,1344], # 1290240 1.4:1 [1344,960],[960,1344], # 1290240 1.4:1
@ -25,7 +25,7 @@ ASPECTS11 = [[1152,1152], # 1327104 1:1
[2304,576],[576,2304], # 1327104 4:1 [2304,576],[576,2304], # 1327104 4:1
] ]
ASPECTS10 = [[1088,1088], # 1183744 1:1 ASPECTS_1088 = [[1088,1088], # 1183744 1:1
[1152,1024],[1024,1152], # 1167360 1.267:1 [1152,1024],[1024,1152], # 1167360 1.267:1
[1216,896],[896,1216], # 1146880 1.429:1 [1216,896],[896,1216], # 1146880 1.429:1
[1408,832],[832,1408], # 1171456 1.692:1 [1408,832],[832,1408], # 1171456 1.692:1
@ -36,7 +36,7 @@ ASPECTS10 = [[1088,1088], # 1183744 1:1
[2304,512],[512,2304], # 1179648 4.5:1 [2304,512],[512,2304], # 1179648 4.5:1
] ]
ASPECTS9 = [[1024,1024], # 1048576 1:1 ASPECTS_1024 = [[1024,1024], # 1048576 1:1
#[1088,960],[960,1088], # 1044480 1.125:1 #[1088,960],[960,1088], # 1044480 1.125:1
[1152,896],[896,1152], # 1032192 1.286:1 [1152,896],[896,1152], # 1032192 1.286:1
[1216,832],[832,1216], # 1011712 1.462:1 [1216,832],[832,1216], # 1011712 1.462:1
@ -47,7 +47,7 @@ ASPECTS9 = [[1024,1024], # 1048576 1:1
[2048,512],[512,2048], # 1048576 4:1 [2048,512],[512,2048], # 1048576 4:1
] ]
ASPECTS8 = [[960,960], # 921600 1:1 ASPECTS_960 = [[960,960], # 921600 1:1
[1024,896],[896,1024], # 917504 1.143:1 [1024,896],[896,1024], # 917504 1.143:1
[1088,832],[832,1088], # 905216 1.308:1 [1088,832],[832,1088], # 905216 1.308:1
[1152,768],[768,1152], # 884736 1.5:1 [1152,768],[768,1152], # 884736 1.5:1
@ -59,7 +59,7 @@ ASPECTS8 = [[960,960], # 921600 1:1
[2048,448],[448,2048], # 917504 4.714:1 [2048,448],[448,2048], # 917504 4.714:1
] ]
ASPECTS7 = [[896,896], # 802816 1:1 ASPECTS_896 = [[896,896], # 802816 1:1
[960,832],[832,960], # 798720 1.153:1 [960,832],[832,960], # 798720 1.153:1
[1024,768],[768,1024], # 786432 1.333:1 [1024,768],[768,1024], # 786432 1.333:1
[1088,704],[704,1088], # 765952 1.545:1 [1088,704],[704,1088], # 765952 1.545:1
@ -69,7 +69,7 @@ ASPECTS7 = [[896,896], # 802816 1:1
[1792,448],[448,1792], # 802816 4:1 [1792,448],[448,1792], # 802816 4:1
] ]
ASPECTS6 = [[832,832], # 692224 1:1 ASPECTS_832 = [[832,832], # 692224 1:1
[896,768],[768,896], # 688128 1.167:1 [896,768],[768,896], # 688128 1.167:1
[960,704],[704,960], # 675840 1.364:1 [960,704],[704,960], # 675840 1.364:1
#[960,640],[640,960], # 614400 1.5:1 #[960,640],[640,960], # 614400 1.5:1
@ -82,7 +82,7 @@ ASPECTS6 = [[832,832], # 692224 1:1
[1600,384],[384,1600], # 614400 4.167:1 [1600,384],[384,1600], # 614400 4.167:1
] ]
ASPECTS5 = [[768,768], # 589824 1:1 ASPECTS_768 = [[768,768], # 589824 1:1
[832,704],[704,832], # 585728 1.181:1 [832,704],[704,832], # 585728 1.181:1
[896,640],[640,896], # 573440 1.4:1 [896,640],[640,896], # 573440 1.4:1
[960,576],[576,960], # 552960 1.6:1 [960,576],[576,960], # 552960 1.6:1
@ -96,7 +96,7 @@ ASPECTS5 = [[768,768], # 589824 1:1
[1472,320],[320,1472], # 470400 4.6:1 [1472,320],[320,1472], # 470400 4.6:1
] ]
ASPECTS4 = [[704,704], # 501,376 1:1 ASPECTS_704 = [[704,704], # 501,376 1:1
[768,640],[640,768], # 491,520 1.2:1 [768,640],[640,768], # 491,520 1.2:1
[832,576],[576,832], # 458,752 1.444:1 [832,576],[576,832], # 458,752 1.444:1
#[896,512],[512,896], # 458,752 1.75:1 #[896,512],[512,896], # 458,752 1.75:1
@ -109,7 +109,7 @@ ASPECTS4 = [[704,704], # 501,376 1:1
[1280,320],[320,1280], # 409,600 4:1 [1280,320],[320,1280], # 409,600 4:1
] ]
ASPECTS3 = [[640,640], # 409600 1:1 ASPECTS_640 = [[640,640], # 409600 1:1
[704,576],[576,704], # 405504 1.25:1 [704,576],[576,704], # 405504 1.25:1
[768,512],[512,768], # 393216 1.5:1 [768,512],[512,768], # 393216 1.5:1
[832,448],[448,832], # 372736 1.857:1 [832,448],[448,832], # 372736 1.857:1
@ -119,7 +119,7 @@ ASPECTS3 = [[640,640], # 409600 1:1
[1280,320],[320,1280], # 409600 4:1 [1280,320],[320,1280], # 409600 4:1
] ]
ASPECTS2 = [[576,576], # 331776 1:1 ASPECTS_576 = [[576,576], # 331776 1:1
[640,512],[512,640], # 327680 1.25:1 [640,512],[512,640], # 327680 1.25:1
#[640,448],[448,640], # 286720 1.4286:1 #[640,448],[448,640], # 286720 1.4286:1
[704,448],[448,704], # 314928 1.5625:1 [704,448],[448,704], # 314928 1.5625:1
@ -130,7 +130,7 @@ ASPECTS2 = [[576,576], # 331776 1:1
#[1280,256],[256,1280], # 327680 5:1 #[1280,256],[256,1280], # 327680 5:1
] ]
ASPECTS = [[512,512], # 262144 1:1 ASPECTS_512 = [[512,512], # 262144 1:1
[576,448],[448,576], # 258048 1.29:1 [576,448],[448,576], # 258048 1.29:1
[640,384],[384,640], # 245760 1.667:1 [640,384],[384,640], # 245760 1.667:1
[768,320],[320,768], # 245760 2.4:1 [768,320],[320,768], # 245760 2.4:1
@ -140,14 +140,25 @@ ASPECTS = [[512,512], # 262144 1:1
[1024,256],[256,1024], # 245760 4:1 [1024,256],[256,1024], # 245760 4:1
] ]
ASPECTS0 = [[448,448], # 200704 1:1 ASPECTS_448 = [[448,448], # 200704 1:1
[512,384],[384,512], # 196608 1.333:1 [512,384],[384,512], # 196608 1.333:1
[640,320],[320,640], # 204800 2:1 [640,320],[320,640], # 204800 2:1
[768,256],[256,768], # 196608 3:1 [768,256],[256,768], # 196608 3:1
] ]
# Aspect buckets whose square entry is 384x384. The trailing comment on each line gives
# the pixel count of the first pair on that line and its approximate aspect ratio;
# every non-square pair is listed in both orientations.
ASPECTS_384 = [[384,384], # 147456 1:1
[448,320],[320,448], # 143360 1.4:1
[512,256],[256,512], # 131072 2:1
[704,192],[192,704], # 135168 3.667:1
]
# Aspect buckets whose square entry is 256x256, annotated the same way
# (pixel count, then approximate aspect ratio).
ASPECTS_256 = [[256,256], # 65536 1:1
[384,192],[192,384], # 73728 2:1
[512,128],[128,512], # 65536 4:1
]
def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False): def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
if resolution < 512: if resolution < 256:
raise ValueError("Resolution must be at least 512") raise ValueError("Resolution must be at least 512")
try: try:
rounded_resolution = int(resolution / 64) * 64 rounded_resolution = int(resolution / 64) * 64
@ -164,4 +175,18 @@ def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
raise e raise e
def __get_all_aspects():
    """
    Returns every aspect bucket table, ordered from the smallest square entry (256) to the largest (1152)
    """
    return [
        ASPECTS_256,
        ASPECTS_384,
        ASPECTS_448,
        ASPECTS_512,
        ASPECTS_576,
        ASPECTS_640,
        ASPECTS_704,
        ASPECTS_768,
        ASPECTS_832,
        ASPECTS_896,
        ASPECTS_960,
        ASPECTS_1024,
        ASPECTS_1088,
        ASPECTS_1152,
    ]

159
train.py
View File

@ -105,6 +105,7 @@ def setup_local_logger(args):
configures logger with file and console logging, logs args, and returns the datestamp configures logger with file and console logging, logs args, and returns the datestamp
""" """
log_path = args.logdir log_path = args.logdir
if not os.path.exists(log_path): if not os.path.exists(log_path):
os.makedirs(log_path) os.makedirs(log_path)
@ -115,6 +116,7 @@ def setup_local_logger(args):
f.write(f"{json_config}") f.write(f"{json_config}")
logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log") logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log")
print(f" logging to {logfilename}")
logging.basicConfig(filename=logfilename, logging.basicConfig(filename=logfilename,
level=logging.INFO, level=logging.INFO,
format="%(asctime)s %(message)s", format="%(asctime)s %(message)s",
@ -122,6 +124,7 @@ def setup_local_logger(args):
) )
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
return datetimestamp return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon): def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
@ -190,30 +193,51 @@ def set_args_12gb(args):
logging.info(" Overiding adam8bit to True") logging.info(" Overiding adam8bit to True")
args.useadam8bit = True args.useadam8bit = True
def find_last_checkpoint(logdir):
    """
    Finds the last checkpoint in the logdir, recursively

    A folder counts as a checkpoint when it directly contains a file named
    "model_index.json"; candidates are ranked by that file's modification time.

    :param logdir: root directory to search
    :return: path of the folder whose model_index.json was modified most recently
    :raises AssertionError: if no checkpoint is found, or if the path of the newest
        checkpoint contains "errored"
    """
    last_ckpt = None
    last_date = None

    for root, _dirs, files in os.walk(logdir):
        # os.walk already yields bare file names, so a membership test is enough
        if "model_index.json" not in files:
            continue
        curr_date = os.path.getmtime(os.path.join(root, "model_index.json"))
        if last_date is None or curr_date > last_date:
            last_date = curr_date
            last_ckpt = root

    # Raised explicitly instead of using `assert` so the checks still run under
    # `python -O`; AssertionError is kept so callers see the same exception type.
    if not last_ckpt:
        raise AssertionError(f"Could not find last checkpoint in logdir: {logdir}")
    if "errored" in last_ckpt:
        raise AssertionError(f"Found last checkpoint: {last_ckpt}, but it was errored, cancelling")

    print(f" {Fore.LIGHTCYAN_EX}Found last checkpoint: {last_ckpt}, resuming{Style.RESET_ALL}")
    return last_ckpt
def setup_args(args):
"""
Sets defaults for missing args (possible if missing from json config)
Forces some args to be set based on others for compatibility reasons
"""
if args.resume_ckpt == "findlast":
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if args.ed1_mode and not args.disable_xformers:
args.disable_xformers = True
logging.info(" ED1 mode: Overiding disable_xformers to True")
if args.lowvram: if args.lowvram:
set_args_12gb(args) set_args_12gb(args)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = False
if args.ed1_mode:
args.disable_xformers = True
if not args.shuffle_tags: if not args.shuffle_tags:
args.shuffle_tags = False args.shuffle_tags = False
args.clip_skip = max(min(4, args.clip_skip), 0) args.clip_skip = max(min(4, args.clip_skip), 0)
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None: if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}") logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
args.ckpt_every_n_minutes = 20 args.ckpt_every_n_minutes = 20
@ -231,16 +255,32 @@ def main(args):
if args.cond_dropout > 0.26: if args.cond_dropout > 0.26:
logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}") logging.warning(f"{Fore.LIGHTYELLOW_EX}** cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
total_batch_size = args.batch_size * args.grad_accum
if args.grad_accum > 1: if args.grad_accum > 1:
logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}")
total_batch_size = args.batch_size * args.grad_accum
if args.scale_lr is not None and args.scale_lr: if args.scale_lr is not None and args.scale_lr:
tmp_lr = args.lr tmp_lr = args.lr
args.lr = args.lr * (total_batch_size**0.55) args.lr = args.lr * (total_batch_size**0.55)
logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
return args
def main(args):
"""
Main entry point
"""
log_time = setup_local_logger(args)
args = setup_args(args)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = True
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
logging.info(f"Logging to {log_folder}") logging.info(f"Logging to {log_folder}")
if not os.path.exists(log_folder): if not os.path.exists(log_folder):
@ -409,9 +449,12 @@ def main(args):
default_lr = 3e-6 default_lr = 3e-6
curr_lr = args.lr if args.lr is not None else default_lr curr_lr = args.lr if args.lr is not None else default_lr
# vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
# text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16) vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16) unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16) text_encoder = text_encoder.to(device, dtype=torch.float32)
if args.disable_textenc_training: if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@ -537,15 +580,7 @@ def main(args):
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes") logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
# scaler = torch.cuda.amp.GradScaler(
# #enabled=False,
# enabled=True if args.amp else False,
# init_scale=2**1,
# growth_factor=1.000001,
# backoff_factor=0.9999999,
# growth_interval=50,
# )
#logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
def collate_fn(batch): def collate_fn(batch):
""" """
@ -607,8 +642,22 @@ def main(args):
#loss = torch.tensor(0.0, device=device, dtype=torch.float32) #loss = torch.tensor(0.0, device=device, dtype=torch.float32)
try: if args.amp:
#scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler(
#enabled=False,
enabled=True,
init_scale=1024.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
loss_log_step = []
try:
for epoch in range(args.max_epochs): for epoch in range(args.max_epochs):
loss_epoch = []
epoch_start_time = time.time() epoch_start_time = time.time()
steps_pbar.reset() steps_pbar.reset()
images_per_sec_log_step = [] images_per_sec_log_step = []
@ -619,8 +668,8 @@ def main(args):
with torch.no_grad(): with torch.no_grad():
#with autocast(): #with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device) pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
with autocast(enabled=args.amp): #with autocast(enabled=args.amp):
latents = vae.encode(pixel_values, return_dict=False) latents = vae.encode(pixel_values, return_dict=False)
del pixel_values del pixel_values
latents = latents[0].sample() * 0.18215 latents = latents[0].sample() * 0.18215
@ -650,8 +699,8 @@ def main(args):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption del noise, latents, cuda_caption
with autocast(enabled=args.amp): #with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp): #with autocast(enabled=args.amp):
@ -663,7 +712,10 @@ def main(args):
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm) torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm) torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
loss.backward() if args.amp:
scaler.scale(loss).backward()
else:
loss.backward()
if batch["runt_size"] > 0: if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size grad_scale = batch["runt_size"] / args.batch_size
@ -677,28 +729,37 @@ def main(args):
param.grad *= grad_scale param.grad *= grad_scale
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1): if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
optimizer.step() if args.amp:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
lr_scheduler.step() lr_scheduler.step()
steps_pbar.set_postfix({"gs": global_step}) loss_step = loss.detach().item()
steps_pbar.set_postfix({"loss/step": loss_step},{"gs": global_step})
steps_pbar.update(1) steps_pbar.update(1)
global_step += 1
images_per_sec = args.batch_size / (time.time() - step_start_time) images_per_sec = args.batch_size / (time.time() - step_start_time)
images_per_sec_log_step.append(images_per_sec) images_per_sec_log_step.append(images_per_sec)
loss_log_step.append(loss_step)
loss_epoch.append(loss_step)
if (global_step + 1) % args.log_step == 0: if (global_step + 1) % args.log_step == 0:
curr_lr = lr_scheduler.get_last_lr()[0] curr_lr = lr_scheduler.get_last_lr()[0]
loss_local = loss.detach().item() loss_local = sum(loss_log_step) / len(loss_log_step)
logs = {"loss/step": loss_local, "lr": curr_lr, "img/s": images_per_sec} loss_log_step = []
log_writer.add_scalar(tag="loss/step", scalar_value=loss_local, global_step=global_step) logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step) log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
sum_img = sum(images_per_sec_log_step) sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step) avg = sum_img / len(images_per_sec_log_step)
images_per_sec_log_step = [] images_per_sec_log_step = []
#log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step) if args.amp:
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step) log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs) append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -732,7 +793,8 @@ def main(args):
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir) __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, args.save_ckpt_dir)
del loss, batch del batch
global_step += 1
# end of step # end of step
elapsed_epoch_time = (time.time() - epoch_start_time) / 60 elapsed_epoch_time = (time.time() - epoch_start_time) / 60
@ -742,6 +804,9 @@ def main(args):
epoch_pbar.update(1) epoch_pbar.update(1)
if epoch < args.max_epochs - 1: if epoch < args.max_epochs - 1:
train_batch.shuffle(epoch_n=epoch+1) train_batch.shuffle(epoch_n=epoch+1)
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# end of epoch # end of epoch
# end of training # end of training
@ -765,6 +830,14 @@ def main(args):
logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}") logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")
def update_old_args(t_args):
    """
    Update old args to new args to deal with json config loading and missing args for compatibility

    Any option listed below that the loaded config does not define is reported on
    stdout and set to its default.

    :param t_args: namespace populated from the config json; modified in place
    :return: None
    """
    # Options that older config files may lack, with the value to assume when absent
    compat_defaults = {"shuffle_tags": False}

    for name, default in compat_defaults.items():
        if not hasattr(t_args, name):
            print(f" Config json is missing '{name}'")
            setattr(t_args, name, default)
if __name__ == "__main__": if __name__ == "__main__":
supported_resolutions = [448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152] supported_resolutions = [448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
argparser = argparse.ArgumentParser(description="EveryDream2 Training options") argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
@ -776,6 +849,8 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
t_args = argparse.Namespace() t_args = argparse.Namespace()
t_args.__dict__.update(json.load(f)) t_args.__dict__.update(json.load(f))
update_old_args(t_args) # update args to support older configs
print(t_args.__dict__)
args = argparser.parse_args(namespace=t_args) args = argparser.parse_args(namespace=t_args)
else: else:
print("No config file specified, using command line args") print("No config file specified, using command line args")