additional hub arguments

anton-l 2022-06-21 11:21:10 +02:00
parent 9c82c32ba7
commit 0417baf23d
2 changed files with 28 additions and 19 deletions

View File

@@ -74,7 +74,8 @@ def main(args):
         repo = init_git_repo(args, at_init=True)
 
     # Train!
-    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
+    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
+    world_size = torch.distributed.get_world_size() if is_distributed else 1
     total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
     max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
     logger.info("***** Running training *****")
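
Note: the old guard (args.local_rank != -1) assumed a process group was initialized whenever a rank was passed on the command line, while the new guard asks torch directly, so the same script also runs as a plain single-process job. A minimal standalone sketch of the new check (illustrative, not taken from this commit):

import torch

def get_world_size() -> int:
    # Only query the default process group when one has actually been initialized;
    # otherwise fall back to single-process training.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1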
@@ -120,17 +121,14 @@ def main(args):
             pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
             optimizer.step()
 
+        if is_distributed:
+            torch.distributed.barrier()
         # Generate a sample image for visual inspection
-        torch.distributed.barrier()
 
         if args.local_rank in [-1, 0]:
             model.eval()
             with torch.no_grad():
                 pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
-                if args.push_to_hub:
-                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
-                else:
-                    pipeline.save_pretrained(args.output_path)
 
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
@@ -143,9 +141,16 @@ def main(args):
                 image_pil = PIL.Image.fromarray(image_processed[0])
 
                 # save image
-                test_dir = os.path.join(args.output_path, "test_samples")
+                test_dir = os.path.join(args.output_dir, "test_samples")
                 os.makedirs(test_dir, exist_ok=True)
                 image_pil.save(f"{test_dir}/{epoch}.png")
 
+                # save the model
+                if args.push_to_hub:
+                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                else:
+                    pipeline.save_pretrained(args.output_dir)
+
+        if is_distributed:
+            torch.distributed.barrier()
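
Note: torch.distributed.barrier() raises a RuntimeError when no default process group exists, so each synchronization point is now guarded by is_distributed, and the push/save step moves after the sample image is written so the epoch's test sample ends up in the same output directory that gets saved or pushed. A hedged sketch of the guard on its own (illustrative helper, not from the script):

import torch

def maybe_barrier() -> None:
    # Skip the collective entirely on single-process runs, where barrier() would fail.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()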
@@ -153,14 +158,18 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
+    parser.add_argument("--output_dir", type=str, default="ddpm-model")
+    parser.add_argument("--overwrite_output_dir", action="store_true")
     parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--output_path", type=str, default="ddpm-model")
-    parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument("--batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
     parser.add_argument("--push_to_hub", action="store_true")
+    parser.add_argument("--hub_token", type=str, default=None)
+    parser.add_argument("--hub_model_id", type=str, default=None)
+    parser.add_argument("--hub_private_repo", action="store_true")
     parser.add_argument(
         "--mixed_precision",
         type=str,
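
Note: the new flags mirror the hub options consumed by the push helpers in the second file. A self-contained sketch of the hub-related flags, including the three added here (hypothetical parser, values chosen only for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--push_to_hub", action="store_true")       # upload the pipeline instead of only saving locally
parser.add_argument("--hub_token", type=str, default=None)      # token overriding the cached login, if any
parser.add_argument("--hub_model_id", type=str, default=None)   # full repo id, e.g. some-user/ddpm-model
parser.add_argument("--hub_private_repo", action="store_true")  # create the Hub repo as private

args = parser.parse_args(["--push_to_hub", "--hub_model_id", "some-user/ddpm-model"])
print(args.push_to_hub, args.hub_model_id, args.hub_private_repo)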

View File

@@ -70,7 +70,7 @@ def init_git_repo(args, at_init: bool = False):
     repo.git_pull()
 
     # By default, ignore the checkpoint folders
-    if not os.path.exists(os.path.join(args.output_dir, ".gitignore")) and args.hub_strategy != "all_checkpoints":
+    if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
         with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
             writer.writelines(["checkpoint-*/"])
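
Note: with the hub_strategy check gone, checkpoint folders are always added to .gitignore when the repository is first set up. A standalone sketch of what the simplified branch does on a fresh output directory (paths chosen for illustration):

import os

output_dir = "ddpm-model"  # illustrative path
os.makedirs(output_dir, exist_ok=True)
gitignore_path = os.path.join(output_dir, ".gitignore")
if not os.path.exists(gitignore_path):
    # checkpoint-*/ folders are now always ignored, regardless of any hub strategy
    with open(gitignore_path, "w", encoding="utf-8") as writer:
        writer.writelines(["checkpoint-*/"])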