Add --pretrained_model_name_revision option to train_dreambooth.py (#933)

* Add --pretrained_model_name_revision option to train_dreambooth.py

* Renamed --pretrained_model_name_revision to --revision
Yuta Hayashibe 2022-10-26 04:38:23 +09:00 committed by GitHub
parent e2243de5f2
commit 4b9f58952a
1 changed file with 36 additions and 6 deletions

examples/dreambooth/train_dreambooth.py

@@ -35,6 +35,13 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
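
Note: a minimal sketch of the new flag in isolation (not the full parser from train_dreambooth.py). The flag is optional and defaults to None, in which case from_pretrained falls back to the repository's default branch; the "fp16" value below is only an illustration.

```python
import argparse

# Minimal sketch of the new flag on its own; the real script registers
# it alongside all the other training options.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--revision",
    type=str,
    default=None,
    required=False,
    help="Revision of pretrained model identifier from huggingface.co/models.",
)

args = parser.parse_args(["--revision", "fp16"])
print(args.revision)  # "fp16"; omitting the flag yields None
```
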
@@ -344,7 +351,10 @@ def main():
         if cur_class_images < args.num_class_images:
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
             pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
+                args.pretrained_model_name_or_path,
+                torch_dtype=torch_dtype,
+                safety_checker=None,
+                revision=args.revision,
             )
             pipeline.set_progress_bar_config(disable=True)
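
Note: read standalone, the new argument selects which snapshot of the model repository the class-image pipeline is built from. A sketch with placeholder values; the model id and the fp16 branch are assumptions, not part of this commit:

```python
import torch
from diffusers import StableDiffusionPipeline

# revision names a branch, tag, or commit hash on the Hub. "fp16" is a
# convention some repos use for half-precision weights; it is assumed
# here for illustration and exists only on repos that publish it.
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative model id
    torch_dtype=torch.float16,
    safety_checker=None,
    revision="fp16",
)
```
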
@@ -390,14 +400,33 @@ def main():
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = CLIPTokenizer.from_pretrained(
+            args.tokenizer_name,
+            revision=args.revision,
+        )
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+        tokenizer = CLIPTokenizer.from_pretrained(
+            args.pretrained_model_name_or_path,
+            subfolder="tokenizer",
+            revision=args.revision,
+        )

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        revision=args.revision,
+    )

     vae.requires_grad_(False)
     if not args.train_text_encoder:
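
Note: the effect of threading args.revision through every loader is that all four components come from the same repository snapshot. A self-contained sketch under assumed names (the repo id and branch are placeholders):

```python
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"  # illustrative repo id
revision = "main"                            # branch, tag, or commit hash

# Pinning every component to one revision keeps them mutually
# consistent even if the Hub repo is updated mid-training.
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", revision=revision)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision=revision)
```
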
@@ -613,6 +642,7 @@ def main():
             args.pretrained_model_name_or_path,
             unet=accelerator.unwrap_model(unet),
             text_encoder=accelerator.unwrap_model(text_encoder),
+            revision=args.revision,
         )
         pipeline.save_pretrained(args.output_dir)
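
Note: the final hunk pins the components that were not fine-tuned (VAE, scheduler, tokenizer) at export time as well. A self-contained sketch of that step with placeholder names; the stand-in modules and output directory are assumptions:

```python
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "runwayml/stable-diffusion-v1-5"  # illustrative repo id
revision = "main"                            # same pin used during training

# Stand-ins for the fine-tuned modules; in the script these come out of
# accelerator.unwrap_model(...) after the training loop.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision=revision)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision)

# Every component not overridden here (vae, scheduler, tokenizer, ...)
# is fetched from the pinned revision before the pipeline is exported.
pipeline = StableDiffusionPipeline.from_pretrained(
    model_id,
    unet=unet,
    text_encoder=text_encoder,
    revision=revision,
)
pipeline.save_pretrained("dreambooth-output")  # hypothetical output dir
```
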