fix log_writer bug and move logs into specific project log folder
This commit is contained in:
parent
bf3c022489
commit
6727b6d61f
50
train.py
50
train.py
|
@ -189,7 +189,7 @@ def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, s
|
||||||
pipeline_ema.save_pretrained(diffusers_model_path)
|
pipeline_ema.save_pretrained(diffusers_model_path)
|
||||||
|
|
||||||
if save_ckpt:
|
if save_ckpt:
|
||||||
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
|
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.safetensors"
|
||||||
|
|
||||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
|
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, s
|
||||||
pipeline.save_pretrained(diffusers_model_path)
|
pipeline.save_pretrained(diffusers_model_path)
|
||||||
|
|
||||||
if save_ckpt:
|
if save_ckpt:
|
||||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
sd_ckpt_path = f"{os.path.basename(save_path)}.safetensors"
|
||||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
|
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
|
||||||
|
|
||||||
if save_optimizer_flag:
|
if save_optimizer_flag:
|
||||||
|
@ -223,17 +223,15 @@ def setup_local_logger(args):
|
||||||
configures logger with file and console logging, logs args, and returns the datestamp
|
configures logger with file and console logging, logs args, and returns the datestamp
|
||||||
"""
|
"""
|
||||||
log_path = args.logdir
|
log_path = args.logdir
|
||||||
|
os.makedirs(log_path, exist_ok=True)
|
||||||
|
|
||||||
if not os.path.exists(log_path):
|
|
||||||
os.makedirs(log_path)
|
|
||||||
|
|
||||||
json_config = json.dumps(vars(args), indent=2)
|
|
||||||
datetimestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
datetimestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
|
||||||
with open(os.path.join(log_path, f"{args.project_name}-{datetimestamp}_cfg.json"), "w") as f:
|
log_folder = os.path.join(log_path, f"{args.project_name}-{datetimestamp}")
|
||||||
f.write(f"{json_config}")
|
os.makedirs(log_folder, exist_ok=True)
|
||||||
|
|
||||||
|
logfilename = os.path.join(log_folder, f"{args.project_name}-{datetimestamp}.log")
|
||||||
|
|
||||||
logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log")
|
|
||||||
print(f" logging to {logfilename}")
|
print(f" logging to {logfilename}")
|
||||||
logging.basicConfig(filename=logfilename,
|
logging.basicConfig(filename=logfilename,
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
|
@ -247,7 +245,7 @@ def setup_local_logger(args):
|
||||||
warnings.filterwarnings("ignore", message="UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images")
|
warnings.filterwarnings("ignore", message="UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images")
|
||||||
#from PIL import Image
|
#from PIL import Image
|
||||||
|
|
||||||
return datetimestamp
|
return datetimestamp, log_folder
|
||||||
|
|
||||||
# def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
# def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||||
# """
|
# """
|
||||||
|
@ -462,8 +460,7 @@ def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
|
||||||
|
|
||||||
return image_train_items
|
return image_train_items
|
||||||
|
|
||||||
def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
|
def write_batch_schedule(log_folder: str, train_batch: EveryDreamBatch, epoch: int):
|
||||||
if args.write_schedule:
|
|
||||||
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
|
||||||
for i in range(len(train_batch.image_train_items)):
|
for i in range(len(train_batch.image_train_items)):
|
||||||
try:
|
try:
|
||||||
|
@ -480,12 +477,22 @@ def read_sample_prompts(sample_prompts_file_path: str):
|
||||||
sample_prompts.append(line.strip())
|
sample_prompts.append(line.strip())
|
||||||
return sample_prompts
|
return sample_prompts
|
||||||
|
|
||||||
def log_args(log_writer, args):
|
|
||||||
|
def log_args(log_writer, args, optimizer_config, log_folder, log_time):
|
||||||
arglog = "args:\n"
|
arglog = "args:\n"
|
||||||
for arg, value in sorted(vars(args).items()):
|
for arg, value in sorted(vars(args).items()):
|
||||||
arglog += f"{arg}={value}, "
|
arglog += f"{arg}={value}, "
|
||||||
log_writer.add_text("config", arglog)
|
log_writer.add_text("config", arglog)
|
||||||
|
|
||||||
|
args_as_json = json.dumps(vars(args), indent=2)
|
||||||
|
with open(os.path.join(log_folder, f"{args.project_name}-{log_time}_main.json"), "w") as f:
|
||||||
|
f.write(args_as_json)
|
||||||
|
|
||||||
|
optimizer_config_as_json = json.dumps(optimizer_config, indent=2)
|
||||||
|
with open(os.path.join(log_folder, f"{args.project_name}-{log_time}_opt.json"), "w") as f:
|
||||||
|
f.write(optimizer_config_as_json)
|
||||||
|
|
||||||
|
|
||||||
def update_ema(model, ema_model, decay, default_device, ema_device):
|
def update_ema(model, ema_model, decay, default_device, ema_device):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
original_model_on_proper_device = model
|
original_model_on_proper_device = model
|
||||||
|
@ -563,7 +570,7 @@ def main(args):
|
||||||
print(" * Windows detected, disabling Triton")
|
print(" * Windows detected, disabling Triton")
|
||||||
os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = "1"
|
os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = "1"
|
||||||
|
|
||||||
log_time = setup_local_logger(args)
|
log_time, log_folder = setup_local_logger(args)
|
||||||
args = setup_args(args)
|
args = setup_args(args)
|
||||||
print(f" Args:")
|
print(f" Args:")
|
||||||
pprint.pprint(vars(args))
|
pprint.pprint(vars(args))
|
||||||
|
@ -582,8 +589,7 @@ def main(args):
|
||||||
device = 'cpu'
|
device = 'cpu'
|
||||||
gpu = None
|
gpu = None
|
||||||
|
|
||||||
|
#log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
|
||||||
|
|
||||||
if not os.path.exists(log_folder):
|
if not os.path.exists(log_folder):
|
||||||
os.makedirs(log_folder)
|
os.makedirs(log_folder)
|
||||||
|
@ -706,8 +712,6 @@ def main(args):
|
||||||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if use_ema_dacay_training:
|
if use_ema_dacay_training:
|
||||||
if not ema_model_loaded_from_file:
|
if not ema_model_loaded_from_file:
|
||||||
logging.info(f"EMA decay enabled, creating EMA model.")
|
logging.info(f"EMA decay enabled, creating EMA model.")
|
||||||
|
@ -821,9 +825,10 @@ def main(args):
|
||||||
optimizer_config,
|
optimizer_config,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
unet,
|
unet,
|
||||||
epoch_len)
|
epoch_len,
|
||||||
|
log_writer)
|
||||||
|
|
||||||
log_args(log_writer, args)
|
log_args(log_writer, args, optimizer_config, log_folder, log_time)
|
||||||
|
|
||||||
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
|
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
|
||||||
default_resolution=args.resolution, default_seed=args.seed,
|
default_resolution=args.resolution, default_seed=args.seed,
|
||||||
|
@ -857,7 +862,6 @@ def main(args):
|
||||||
if not interrupted:
|
if not interrupted:
|
||||||
interrupted=True
|
interrupted=True
|
||||||
global global_step
|
global global_step
|
||||||
#TODO: save model on ctrl-c
|
|
||||||
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
|
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
|
||||||
print()
|
print()
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||||
|
@ -1104,11 +1108,10 @@ def main(args):
|
||||||
|
|
||||||
epoch = None
|
epoch = None
|
||||||
try:
|
try:
|
||||||
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
|
||||||
plugin_runner.run_on_training_start(log_folder=log_folder, project_name=args.project_name)
|
plugin_runner.run_on_training_start(log_folder=log_folder, project_name=args.project_name)
|
||||||
|
|
||||||
for epoch in range(args.max_epochs):
|
for epoch in range(args.max_epochs):
|
||||||
|
write_batch_schedule(log_folder, train_batch, epoch) if args.write_schedule else None
|
||||||
if args.load_settings_every_epoch:
|
if args.load_settings_every_epoch:
|
||||||
load_train_json_from_file(args)
|
load_train_json_from_file(args)
|
||||||
|
|
||||||
|
@ -1269,7 +1272,6 @@ def main(args):
|
||||||
epoch_pbar.update(1)
|
epoch_pbar.update(1)
|
||||||
if epoch < args.max_epochs - 1:
|
if epoch < args.max_epochs - 1:
|
||||||
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
||||||
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
|
|
||||||
|
|
||||||
if len(loss_epoch) > 0:
|
if len(loss_epoch) > 0:
|
||||||
loss_epoch = sum(loss_epoch) / len(loss_epoch)
|
loss_epoch = sum(loss_epoch) / len(loss_epoch)
|
||||||
|
|
Loading…
Reference in New Issue