Fix push_to_hub for dreambooth and textual_inversion (#748)

* Fix push_to_hub for dreambooth and textual_inversion

* Use repo.push_to_hub instead of push_to_hub
This commit is contained in:
YaYaB 2022-10-07 11:50:28 +02:00 committed by GitHub
parent 7258dc4943
commit 906e4105d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 8 deletions

View File

@ -575,9 +575,7 @@ def main():
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()

View File

@ -569,9 +569,7 @@ def main():
save_progress(text_encoder, placeholder_token_id, accelerator, args)
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()

View File

@ -9,7 +9,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.hub_utils import init_git_repo
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
@ -185,7 +185,7 @@ def main(args):
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()