[examples] misc fixes (#1886)
* misc fixes
* more comments
* Update examples/textual_inversion/textual_inversion.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* set transformers verbosity to warning

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 423c3a4cc6
commit fa1f4701e8
examples/dreambooth/train_dreambooth.py

@@ -1,6 +1,7 @@
import argparse
import hashlib
import itertools
import logging
import math
import os
import warnings

@@ -12,6 +13,9 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import datasets
import diffusers
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
@@ -236,6 +240,24 @@ def parse_args(input_args=None):
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
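For context on the new `--allow_tf32` switch: later in the script it only flips PyTorch's matmul TF32 setting. A minimal sketch of what the flag amounts to at runtime (not part of this diff; note that `torch.backends.cudnn.allow_tf32` already defaults to True, so only the matmul side needs toggling):

```python
import torch

def maybe_enable_tf32(allow_tf32: bool) -> None:
    # TF32 trades a little matmul precision for a sizeable speedup on Ampere GPUs.
    if allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
```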
@@ -422,7 +444,7 @@ def main(args):
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        log_with=args.report_to,
        logging_dir=logging_dir,
    )

@@ -435,9 +457,27 @@ def main(args):
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
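The `logger.info(accelerator.state, main_process_only=False)` call above assumes the module-level `logger` is the rank-aware wrapper from accelerate rather than a plain `logging.Logger`. A minimal sketch of that setup (the script's own definition sits outside this hunk):

```python
from accelerate.logging import get_logger

# Wraps logging.getLogger(__name__) so individual log calls can opt in or out
# of main-process-only filtering via the main_process_only keyword.
logger = get_logger(__name__)
```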
@@ -502,11 +542,7 @@ def main(args):

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
            use_fast=False,
        )
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,

@@ -518,38 +554,36 @@ def main(args):
    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load models and create wrapper for stable diffusion
    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
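A worked example of the `--scale_lr` rule that closes the hunk above, with illustrative numbers that are not taken from the diff:

```python
# Effective LR = base LR x grad-accumulation steps x per-device batch size x number of processes.
base_lr, grad_accum, batch_size, num_processes = 5e-6, 2, 4, 2
scaled_lr = base_lr * grad_accum * batch_size * num_processes  # 8e-05
```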
@@ -568,6 +602,7 @@ def main(args):
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )

@@ -579,8 +614,7 @@ def main(args):
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,

@@ -615,6 +649,7 @@ def main(args):
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler

@@ -623,17 +658,16 @@ def main(args):
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )
    accelerator.register_for_checkpointing(lr_scheduler)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    # Move vae and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
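The `weight_dtype` block above can be read as a simple mapping from accelerate's mixed-precision mode to the dtype used for the frozen, inference-only components. A sketch with a hypothetical helper name (not the script's code):

```python
import torch

def inference_dtype(mixed_precision: str) -> torch.dtype:
    # fp16/bf16 only affect the frozen vae/text_encoder; trainable weights stay in fp32.
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)
```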
@@ -664,6 +698,7 @@ def main(args):
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)

@@ -772,9 +807,8 @@ def main(args):
            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
examples/text_to_image/train_text_to_image.py
@@ -411,6 +411,7 @@ def main():
        logging_dir=logging_dir,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",

@@ -419,7 +420,7 @@ def main():
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()

@@ -577,6 +578,7 @@ def main():
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),

@@ -605,6 +607,7 @@ def main():
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
    )
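The `collate_fn` referenced by the DataLoader call above is only partially visible in this hunk. A plausible reconstruction, with field names taken from the visible lines (treat the contiguous/float cast as an assumption, not something shown in this diff):

```python
import torch

def collate_fn(examples):
    # Batch images and token ids produced by the dataset's transforms.
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([example["input_ids"] for example in examples])
    return {"pixel_values": pixel_values, "input_ids": input_ids}
```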
@@ -623,6 +626,7 @@ def main():
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

@@ -668,6 +672,7 @@ def main():
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
examples/textual_inversion/textual_inversion.py
@@ -1,4 +1,5 @@
import argparse
import logging
import math
import os
import random

@@ -11,7 +12,10 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import datasets
import diffusers
import PIL
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
@@ -207,6 +211,24 @@ def parse_args():
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
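How `--report_to` is consumed: it is passed straight to `Accelerator(log_with=...)` in the hunks below, and the chosen tracker is then initialized by the script. A hedged sketch of that wiring; the tracker-initialization call and run name are assumptions, since they sit outside the visible diff:

```python
from accelerate import Accelerator

def make_accelerator(args, logging_dir):
    # --report_to replaces the previously hardcoded "tensorboard" tracker.
    accelerator = Accelerator(log_with=args.report_to, logging_dir=logging_dir)
    if accelerator.is_main_process:
        # Illustrative only: register the run with whichever tracker was selected.
        accelerator.init_trackers("textual_inversion", config=vars(args))
    return accelerator
```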
@@ -394,10 +416,26 @@ def main():
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        log_with=args.report_to,
        logging_dir=logging_dir,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

@@ -419,12 +457,22 @@ def main():
    elif args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer and add the placeholder token as a additional special token
    # Load tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # Add the placeholder token in tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
@@ -442,33 +490,6 @@ def main():
    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )

    if args.gradient_checkpointing:
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))
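For readers following the placeholder-token plumbing above: after `resize_token_embeddings`, textual inversion seeds the new token's embedding from the initializer token. A sketch using the names defined in this hunk (`placeholder_token_id`, `initializer_token_id`); the exact lines sit outside the visible diff:

```python
import torch

def init_placeholder_embedding(text_encoder, placeholder_token_id, initializer_token_id):
    # Copy the initializer token's embedding into the freshly added row.
    token_embeds = text_encoder.get_input_embeddings().weight.data
    with torch.no_grad():
        token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
```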
@@ -484,6 +505,24 @@ def main():
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    if args.gradient_checkpointing:
        # Keep unet in train mode if we are using gradient checkpointing to save memory.
        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
        unet.train()
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes

@@ -498,8 +537,7 @@ def main():
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # Dataset and DataLoaders creation:
    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,

@@ -526,26 +564,23 @@ def main():
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )
    accelerator.register_for_checkpointing(lr_scheduler)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device
    # Move vae and unet to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # Keep unet in train mode if we are using gradient checkpointing to save memory.
    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
    if args.gradient_checkpointing:
        unet.train()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
@@ -571,6 +606,7 @@ def main():
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)

@@ -670,8 +706,8 @@ def main():
            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()
    # Create the pipeline using using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
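The `--only_save_embeds` branch above ultimately persists just the learned placeholder embedding rather than a full pipeline. A minimal sketch of that idea; the helper name and file name are illustrative, not the script's exact code:

```python
import torch

def save_learned_embeds(text_encoder, placeholder_token, placeholder_token_id, path="learned_embeds.bin"):
    # Extract only the trained row of the token embedding matrix and save it.
    learned = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    torch.save({placeholder_token: learned.detach().cpu()}, path)
```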