Fix push_to_hub for dreambooth and textual_inversion (#748)
* Fix push_to_hub for dreambooth and textual_inversion
* Use repo.push_to_hub instead of push_to_hub
commit 906e4105d7 (parent 7258dc4943)
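Why the fix is needed: `Repository.push_to_hub` from huggingface_hub commits and pushes whatever is already in the local clone and accepts only commit/push options as keyword arguments, so the old-style call that still passed `args, pipeline, repo` positionally (a leftover from the removed module-level `push_to_hub` helper) cannot work. A minimal sketch of the corrected call, assuming a Hub repo already exists; the repo id "user/my-model" and the folder "model-out" are placeholders, not values from the training scripts:

from huggingface_hub import Repository

# Placeholder repo id and folder, not taken from the scripts.
repo = Repository(local_dir="model-out", clone_from="user/my-model")

# Repository.push_to_hub takes no model or args objects, only commit/push options.
repo.push_to_hub(
    commit_message="End of training",
    blocking=False,       # return immediately; the upload continues in the background
    auto_lfs_prune=True,  # drop local LFS blobs once they are on the Hub
)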
@@ -575,9 +575,7 @@ def main():
         pipeline.save_pretrained(args.output_dir)

         if args.push_to_hub:
-            repo.push_to_hub(
-                args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
-            )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

     accelerator.end_training()

@@ -569,9 +569,7 @@ def main():
         save_progress(text_encoder, placeholder_token_id, accelerator, args)

         if args.push_to_hub:
-            repo.push_to_hub(
-                args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
-            )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

     accelerator.end_training()

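Both end-of-training hunks keep the same flow: the artifacts are written into the repo's working tree first (pipeline.save_pretrained(args.output_dir), plus save_progress(...) for textual inversion), and only then pushed, because repo.push_to_hub commits whatever is already on disk. A hedged sketch of that flow; the repo id, folder, and file name below are placeholders standing in for the scripts' real outputs:

import os

from huggingface_hub import Repository

# Placeholder repo id and folder.
repo = Repository(local_dir="output", clone_from="user/textual-inversion-demo")

# The scripts call pipeline.save_pretrained(args.output_dir) (and save_progress(...)
# for textual inversion) at this point; any file written into the clone is picked up
# the same way, since push_to_hub only commits the working tree.
with open(os.path.join(repo.local_dir, "learned_embeds.bin"), "wb") as f:
    f.write(b"placeholder")  # stand-in for the real embedding payload

repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)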
@@ -9,7 +9,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.hub_utils import init_git_repo, push_to_hub
+from diffusers.hub_utils import init_git_repo
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
 from torchvision.transforms import (
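The import hunk drops push_to_hub from diffusers.hub_utils because pushing now goes through the huggingface_hub Repository object; only init_git_repo, which sets up the local clone, is still imported. A rough, hedged equivalent of that setup built from huggingface_hub primitives (not the helper's actual implementation; the repo id and folder are placeholders):

from huggingface_hub import Repository, create_repo

# Placeholder repo id and folder; this only approximates what init_git_repo does.
repo_url = create_repo("user/ddpm-demo", exist_ok=True)        # make sure the Hub repo exists
repo = Repository(local_dir="ddpm-demo", clone_from=repo_url)  # clone it for later pushes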
@@ -185,7 +185,7 @@ def main(args):
             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
                 if args.push_to_hub:
-                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
                 else:
                     pipeline.save_pretrained(args.output_dir)
         accelerator.wait_for_everyone()
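In the unconditional training loop the checkpoint is pushed every few epochs with blocking=False, so the upload runs in the background and does not stall training. A toy sketch of that pattern; the repo id, folder, and epoch counts are placeholders and the training/saving step is elided:

from huggingface_hub import Repository

# Placeholder repo id, folder, and epoch counts.
repo = Repository(local_dir="ddpm-demo", clone_from="user/ddpm-demo")
num_epochs, save_model_epochs = 10, 5

for epoch in range(num_epochs):
    ...  # train for one epoch, then save the pipeline into repo.local_dir
    if epoch % save_model_epochs == 0 or epoch == num_epochs - 1:
        # blocking=False queues the upload and returns, so the next epoch starts right away
        repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)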