[examples] misc fixes (#1886)
* misc fixes

* more comments

* Update examples/textual_inversion/textual_inversion.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* set transformers verbosity to warning

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 423c3a4cc6
commit fa1f4701e8

examples/dreambooth/train_dreambooth.py
@@ -1,6 +1,7 @@
 import argparse
 import hashlib
 import itertools
+import logging
 import math
 import os
 import warnings
@@ -12,6 +13,9 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
 
+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -236,6 +240,24 @@ def parse_args(input_args=None):
             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
         ),
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+            "Only applicable when `--with_tracking` is passed."
+        ),
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
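
The new `--allow_tf32` flag is consumed further down in `main()` and maps to a single PyTorch switch. A minimal standalone sketch of what it toggles (plain PyTorch, nothing diffusers-specific):

    import torch

    # TF32 runs fp32 matmuls on Ampere tensor cores with a reduced mantissa:
    # a large speedup at slightly lower precision. Recent PyTorch releases
    # disable it for matmuls by default, hence the explicit opt-in.
    torch.backends.cuda.matmul.allow_tf32 = True
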
@@ -422,7 +444,7 @@ def main(args):
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.report_to,
         logging_dir=logging_dir,
     )
 
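
With `log_with` now driven by `--report_to`, the same script can report to any tracker accelerate supports. A sketch of the wiring (run name and values are illustrative; `logging_dir` is the keyword used by the accelerate releases these examples targeted):

    from accelerate import Accelerator

    # "tensorboard" (the default), "wandb", "comet_ml", or "all"
    accelerator = Accelerator(log_with="wandb", logging_dir="logs")
    accelerator.init_trackers("dreambooth")         # start the chosen tracker(s)
    accelerator.log({"train_loss": 0.123}, step=1)  # fanned out to each tracker
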
@@ -435,9 +457,27 @@ def main(args):
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
 
+    # Generate class images if prior preservation is enabled.
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
         if not class_images_dir.exists():
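
The logging block added here follows a standard multi-process pattern: informative logs on the local main process, errors only everywhere else, so N processes don't print N copies of every warning. The same idea in isolation (assumes an `Accelerator` has been constructed as above):

    import logging

    import diffusers
    import transformers

    logging.basicConfig(level=logging.INFO)
    if accelerator.is_local_main_process:
        # rank 0 on each node keeps useful library output
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        # all other ranks stay quiet unless something actually fails
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
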
@@ -502,11 +542,7 @@ def main(args):
 
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer_name,
-            revision=args.revision,
-            use_fast=False,
-        )
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
     elif args.pretrained_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrained_model_name_or_path,
@@ -518,38 +554,36 @@ def main(args):
     # import correct text encoder class
     text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
 
-    # Load models and create wrapper for stable diffusion
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
-    )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
 
+    vae.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
-    vae.requires_grad_(False)
-    if not args.train_text_encoder:
-        text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
             text_encoder.gradient_checkpointing_enable()
 
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
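
The reshuffled block above loads each pipeline piece from a subfolder of a single model repo, which is how diffusers checkpoints are laid out on the Hub. The same pattern outside the training script (the model id is illustrative):

    from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_id = "runwayml/stable-diffusion-v1-5"

    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

    # Frozen; only the unet (and optionally the text encoder) is trained
    # by these scripts.
    vae.requires_grad_(False)
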
@@ -568,6 +602,7 @@ def main(args):
     else:
         optimizer_class = torch.optim.AdamW
 
+    # Optimizer creation
     params_to_optimize = (
         itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
     )
@@ -579,8 +614,7 @@ def main(args):
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-
+    # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
@@ -615,6 +649,7 @@ def main(args):
         power=args.lr_power,
     )
 
+    # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -623,17 +658,16 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, optimizer, train_dataloader, lr_scheduler
         )
-    accelerator.register_for_checkpointing(lr_scheduler)
 
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move vae and text_encoder to device and cast to weight_dtype
     vae.to(accelerator.device, dtype=weight_dtype)
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
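
The relocated comments also state the rule of thumb this hunk encodes: frozen modules can be cast to half precision because they only run forward passes, while trained parameters stay fp32 for stable optimizer updates (accelerate's grad scaler handles their mixed-precision math). Condensed:

    import torch

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # inference-only modules: move and cast in one call
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    # note: the unet is deliberately NOT cast; it is being trained
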
@@ -664,6 +698,7 @@ def main(args):
     global_step = 0
     first_epoch = 0
 
+    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
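
The `"latest"` branch of this check (just past the hunk context) resolves to the newest `checkpoint-<step>` directory written by `accelerator.save_state`. A sketch of that resolution as these scripts implement it (reconstructed, so treat the details as indicative):

    import os

    dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))  # checkpoint-500, ...
    path = dirs[-1] if dirs else None

    if path is not None:
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])  # resume the step counter
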
@@ -772,9 +807,8 @@ def main(args):
             if global_step >= args.max_train_steps:
                 break
 
-    accelerator.wait_for_everyone()
-
     # Create the pipeline using using the trained modules and save it.
+    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
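
Moving `wait_for_everyone()` below the comment changes nothing functionally; what matters is that every rank hits the barrier before rank 0 serializes the weights. The general save pattern, sketched with the script's own names:

    from diffusers import DiffusionPipeline

    accelerator.wait_for_everyone()  # all ranks sync before saving
    if accelerator.is_main_process:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),  # strip the DDP wrapper
        )
        pipeline.save_pretrained(args.output_dir)
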

examples/text_to_image/train_text_to_image.py
@@ -411,6 +411,7 @@ def main():
         logging_dir=logging_dir,
     )
 
+    # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
@@ -419,7 +420,7 @@ def main():
     logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
-        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.set_verbosity_warning()
         diffusers.utils.logging.set_verbosity_info()
     else:
         datasets.utils.logging.set_verbosity_error()
@@ -577,6 +578,7 @@ def main():
         )
         return inputs.input_ids
 
+    # Preprocessing the datasets.
     train_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -605,6 +607,7 @@ def main():
         input_ids = torch.stack([example["input_ids"] for example in examples])
         return {"pixel_values": pixel_values, "input_ids": input_ids}
 
+    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
    )
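
The `collate_fn` passed to the DataLoader here simply stacks per-example tensors into one batch dict; for reference, a sketch of that contract using the field names from the context lines:

    import torch

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}
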
@@ -623,6 +626,7 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
+    # Prepare everything with our `accelerator`.
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
@@ -668,6 +672,7 @@ def main():
     global_step = 0
     first_epoch = 0
 
+    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)

examples/textual_inversion/textual_inversion.py
@@ -1,4 +1,5 @@
 import argparse
+import logging
 import math
 import os
 import random
@@ -11,7 +12,10 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.utils.data import Dataset
 
+import datasets
+import diffusers
 import PIL
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -207,6 +211,24 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+            "Only applicable when `--with_tracking` is passed."
+        ),
+    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--checkpointing_steps",
@@ -394,10 +416,26 @@ def main():
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.report_to,
         logging_dir=logging_dir,
     )
 
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -419,12 +457,22 @@ def main():
     elif args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)
 
-    # Load the tokenizer and add the placeholder token as a additional special token
+    # Load tokenizer
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
 
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+    )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+    )
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
@@ -442,33 +490,6 @@ def main():
     initializer_token_id = token_ids[0]
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
-    )
-
-    if args.gradient_checkpointing:
-        text_encoder.gradient_checkpointing_enable()
-        unet.enable_gradient_checkpointing()
-
-    if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
-            unet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
-
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
@@ -484,6 +505,24 @@ def main():
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
+    if args.gradient_checkpointing:
+        # Keep unet in train mode if we are using gradient checkpointing to save memory.
+        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
+        unet.train()
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
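
The `unet.train()` call is the subtle part of the block above: Hugging Face models only take the checkpointed code path while `self.training` is true, so even a frozen unet must be flipped to train mode for `enable_gradient_checkpointing()` to save memory, and with dropout at 0 the outputs are identical either way. The underlying primitive, in isolation:

    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)

    # Activations inside `layer` are not stored; they are recomputed during
    # backward, trading extra compute for lower peak memory.
    y = checkpoint(layer, x)
    y.sum().backward()
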
@@ -498,8 +537,7 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-
+    # Dataset and DataLoaders creation:
     train_dataset = TextualInversionDataset(
         data_root=args.train_data_dir,
         tokenizer=tokenizer,
@@ -526,26 +564,23 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
+    # Prepare everything with our `accelerator`.
     text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)
 
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    # Move vae and unet to device
+    # Move vae and unet to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep unet in train mode if we are using gradient checkpointing to save memory.
-    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
-    if args.gradient_checkpointing:
-        unet.train()
-
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
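
Dropping `accelerator.register_for_checkpointing(lr_scheduler)` should lose nothing: anything returned by `prepare()` (model, optimizer, dataloader, scheduler) is already tracked by accelerate's state checkpointing, and registering the scheduler a second time would save it twice; explicit registration is only for objects `prepare` never sees. The pair these scripts rely on, sketched with illustrative paths:

    accelerator.save_state("output/checkpoint-500")  # in the loop, every N steps
    # ...later, or in a fresh run with --resume_from_checkpoint...
    accelerator.load_state("output/checkpoint-500")
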
@@ -571,6 +606,7 @@ def main():
     global_step = 0
     first_epoch = 0
 
+    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
@@ -670,8 +706,8 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
+    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds:
             logger.warn("Enabling full model saving because --push_to_hub=True was specified.")