[examples] misc fixes (#1886)

* misc fixes

* more comments

* Update examples/textual_inversion/textual_inversion.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* set transformers verbosity to warning

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Suraj Patil 2023-01-02 14:09:01 +01:00 committed by GitHub
parent 423c3a4cc6
commit fa1f4701e8
3 changed files with 145 additions and 70 deletions

View File

@@ -1,6 +1,7 @@
 import argparse
 import hashlib
 import itertools
+import logging
 import math
 import os
 import warnings
@@ -12,6 +13,9 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.utils.data import Dataset

+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -236,6 +240,24 @@ def parse_args(input_args=None):
             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
         ),
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+            "Only applicable when `--with_tracking` is passed."
+        ),
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -422,7 +444,7 @@ def main(args):
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.report_to,
         logging_dir=logging_dir,
     )
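For context, a minimal sketch of how a `--report_to` value typically feeds Accelerate's tracking API; the project name and logged values below are illustrative, and `init_trackers`/`log` are internally restricted to the main process:

```python
import argparse

from accelerate import Accelerator

parser = argparse.ArgumentParser()
# Mirrors the flag added above; "tensorboard" stays the default backend.
parser.add_argument("--report_to", type=str, default="tensorboard")
args = parser.parse_args()

# `log_with` accepts a single backend name ("tensorboard", "wandb", ...) or a list.
# Newer accelerate versions take `project_dir`; the scripts' era used `logging_dir`.
accelerator = Accelerator(log_with=args.report_to, project_dir="logs")

# Trackers are created once; the config dict is recorded by every active backend.
accelerator.init_trackers("example_project", config={"learning_rate": 1e-4})
accelerator.log({"train_loss": 0.123}, step=0)
accelerator.end_training()
```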
@@ -435,9 +457,27 @@ def main(args):
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )

+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)

+    # Generate class images if prior preservation is enabled.
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
         if not class_images_dir.exists():
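The logging block added in the hunk above pairs a plain `logging.basicConfig` with Accelerate's multi-process logger; a minimal standalone sketch of how `get_logger` and its `main_process_only` flag behave (messages are illustrative):

```python
import logging

from accelerate import Accelerator
from accelerate.logging import get_logger

logger = get_logger(__name__)  # a MultiProcessAdapter around the stdlib logger
accelerator = Accelerator()

# basicConfig runs on every process so that even worker ranks surface errors.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.INFO,
)

logger.info("printed once, by the main process only")
logger.info(f"hello from process {accelerator.process_index}", main_process_only=False)
```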
@@ -502,11 +542,7 @@ def main(args):

     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            args.tokenizer_name,
-            revision=args.revision,
-            use_fast=False,
-        )
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
     elif args.pretrained_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
             args.pretrained_model_name_or_path,
@@ -518,38 +554,36 @@ def main(args):
     # import correct text encoder class
     text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

-    # Load models and create wrapper for stable diffusion
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
-    )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )

+    vae.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

-    vae.requires_grad_(False)
-    if not args.train_text_encoder:
-        text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
             text_encoder.gradient_checkpointing_enable()

+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
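As a quick illustration of what `--allow_tf32` toggles: the flag just flips PyTorch's TF32 switch for matmuls before training starts. A small sketch; the cuDNN line is an extra that the scripts themselves do not set:

```python
import torch

# TF32 trades a little matmul precision for a large speedup on Ampere (and newer) GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
# Optionally do the same for cuDNN convolutions (not part of the diff above).
torch.backends.cudnn.allow_tf32 = True

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
c = a @ b  # runs in TF32 on supported hardware, plain float32 elsewhere
print(c.dtype)  # still torch.float32; TF32 only changes the internal math
```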
@@ -568,6 +602,7 @@ def main(args):
     else:
         optimizer_class = torch.optim.AdamW

+    # Optimizer creation
     params_to_optimize = (
         itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
     )
@@ -579,8 +614,7 @@ def main(args):
         eps=args.adam_epsilon,
     )

-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
@@ -615,6 +649,7 @@ def main(args):
         power=args.lr_power,
     )

+    # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -623,17 +658,16 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, optimizer, train_dataloader, lr_scheduler
         )
-    accelerator.register_for_checkpointing(lr_scheduler)

+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    # Move text_encode and vae to gpu.
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # Move vae and text_encoder to device and cast to weight_dtype
     vae.to(accelerator.device, dtype=weight_dtype)
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
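The `weight_dtype` pattern above keeps the trainable model in full precision while frozen, inference-only modules are cast down. A hedged sketch of the same idea with stand-in `nn.Linear` modules (names and shapes are illustrative):

```python
import torch
from torch import nn

mixed_precision = "bf16"  # stand-in for accelerator.mixed_precision

weight_dtype = torch.float32
if mixed_precision == "fp16":
    weight_dtype = torch.float16
elif mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

device = "cuda" if torch.cuda.is_available() else "cpu"

trainable = nn.Linear(8, 8).to(device)  # stays float32, like the unet being trained
frozen = nn.Linear(8, 8).requires_grad_(False).to(device, dtype=weight_dtype)  # like the frozen vae

x = torch.randn(2, 8, device=device, dtype=weight_dtype)
with torch.no_grad():
    features = frozen(x)  # half-precision inference only
loss = trainable(features.to(torch.float32)).sum()  # trainable path stays in fp32
loss.backward()
print(trainable.weight.grad.dtype)  # torch.float32
```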
@@ -664,6 +698,7 @@ def main(args):
     global_step = 0
     first_epoch = 0

+    # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
             path = os.path.basename(args.resume_from_checkpoint)
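For reference, the resume logic commented here builds on Accelerate's state checkpointing; a condensed, hedged sketch of the save/resume round trip using the scripts' `checkpoint-<step>` folder convention (paths and the step counter are illustrative, the training loop is elided):

```python
import os

from accelerate import Accelerator

accelerator = Accelerator()
output_dir = "output"  # stand-in for args.output_dir
global_step = 500      # stand-in for the real training counter

# Saving: model, optimizer, LR scheduler and RNG states of prepared objects go into one folder.
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)

# Resuming "latest": pick the checkpoint folder with the highest step number.
dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
if dirs:
    path = dirs[-1]
    accelerator.load_state(os.path.join(output_dir, path))
    global_step = int(path.split("-")[1])
```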
@@ -772,9 +807,8 @@ def main(args):
             if global_step >= args.max_train_steps:
                 break

-    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
+    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
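After training, the example rebuilds a full pipeline from the fine-tuned components before saving. A hedged sketch of that final step; the model id and output path are stand-ins, and the modules here are freshly loaded rather than actually fine-tuned:

```python
from accelerate import Accelerator
from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "CompVis/stable-diffusion-v1-4"  # stand-in for args.pretrained_model_name_or_path
accelerator = Accelerator()

# Stand-ins for the fine-tuned modules; in the script these come out of the training loop.
unet = accelerator.prepare(UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"))
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        unet=accelerator.unwrap_model(unet),  # strip DDP/AMP wrappers before saving
        text_encoder=accelerator.unwrap_model(text_encoder),
    )
    pipeline.save_pretrained("dreambooth-model")  # stand-in for args.output_dir
```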

View File

@@ -411,6 +411,7 @@ def main():
         logging_dir=logging_dir,
     )

+    # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
@@ -419,7 +420,7 @@ def main():
     logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
-        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.set_verbosity_warning()
         diffusers.utils.logging.set_verbosity_info()
     else:
         datasets.utils.logging.set_verbosity_error()
@@ -577,6 +578,7 @@ def main():
         )
         return inputs.input_ids

+    # Preprocessing the datasets.
     train_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -605,6 +607,7 @@ def main():
         input_ids = torch.stack([example["input_ids"] for example in examples])
         return {"pixel_values": pixel_values, "input_ids": input_ids}

+    # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
     )
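The `collate_fn` referenced in this hunk just stacks per-example tensors into batch tensors; a self-contained sketch with dummy data (image size and the 77-token caption length are illustrative):

```python
import torch
from torch.utils.data import DataLoader

# Dummy "dataset": each example carries an image tensor and tokenized caption ids.
examples = [
    {"pixel_values": torch.randn(3, 64, 64), "input_ids": torch.randint(0, 1000, (77,))}
    for _ in range(8)
]

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([example["input_ids"] for example in examples])
    return {"pixel_values": pixel_values, "input_ids": input_ids}

train_dataloader = DataLoader(examples, shuffle=True, collate_fn=collate_fn, batch_size=4)
batch = next(iter(train_dataloader))
print(batch["pixel_values"].shape, batch["input_ids"].shape)  # [4, 3, 64, 64] and [4, 77]
```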
@@ -623,6 +626,7 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

+    # Prepare everything with our `accelerator`.
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
@@ -668,6 +672,7 @@ def main():
     global_step = 0
     first_epoch = 0

+    # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
             path = os.path.basename(args.resume_from_checkpoint)

View File

@@ -1,4 +1,5 @@
 import argparse
+import logging
 import math
 import os
 import random
@@ -11,7 +12,10 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.utils.data import Dataset

+import datasets
+import diffusers
 import PIL
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -207,6 +211,24 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+            "Only applicable when `--with_tracking` is passed."
+        ),
+    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--checkpointing_steps",
@@ -394,10 +416,26 @@ def main():
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.report_to,
         logging_dir=logging_dir,
     )

+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -419,12 +457,22 @@ def main():
     elif args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

-    # Load the tokenizer and add the placeholder token as a additional special token
+    # Load tokenizer
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+    )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+    )
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
@@ -442,33 +490,6 @@ def main():
     initializer_token_id = token_ids[0]
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
-    )
-
-    if args.gradient_checkpointing:
-        text_encoder.gradient_checkpointing_enable()
-        unet.enable_gradient_checkpointing()
-
-    if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
-            unet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
-
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
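For readers new to textual inversion: the resize call above makes room for the placeholder token, and the script then seeds the new embedding row from the initializer token so training starts from a meaningful vector. A hedged sketch of that pattern using a small CLIP checkpoint as a stand-in for the Stable Diffusion text encoder:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "openai/clip-vit-base-patch32"  # stand-in; the script uses the SD text encoder subfolder
tokenizer = CLIPTokenizer.from_pretrained(model_id)
text_encoder = CLIPTextModel.from_pretrained(model_id)

placeholder_token, initializer_token = "<cat-toy>", "toy"
num_added = tokenizer.add_tokens(placeholder_token)
assert num_added == 1, "placeholder already exists in the vocabulary"

initializer_token_id = tokenizer.encode(initializer_token, add_special_tokens=False)[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Grow the embedding matrix, then seed the new row with the initializer's vector.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
```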
@@ -484,6 +505,24 @@ def main():
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

+    if args.gradient_checkpointing:
+        # Keep unet in train mode if we are using gradient checkpointing to save memory.
+        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
+        unet.train()
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
@@ -498,8 +537,7 @@ def main():
         eps=args.adam_epsilon,
     )

-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    # Dataset and DataLoaders creation:
     train_dataset = TextualInversionDataset(
         data_root=args.train_data_dir,
         tokenizer=tokenizer,
@@ -526,26 +564,23 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

+    # Prepare everything with our `accelerator`.
     text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)

+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    # Move vae and unet to device
+    # Move vae and unet to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Keep unet in train mode if we are using gradient checkpointing to save memory.
-    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
-    if args.gradient_checkpointing:
-        unet.train()
-
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
@@ -571,6 +606,7 @@ def main():
     global_step = 0
     first_epoch = 0

+    # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
             path = os.path.basename(args.resume_from_checkpoint)
@@ -670,8 +706,8 @@ def main():
             if global_step >= args.max_train_steps:
                 break

-    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
+    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds:
             logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
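The `--only_save_embeds` branch above exists because, in textual inversion, the only weights that actually change are the placeholder token's embedding row. A hedged sketch of saving and reloading just that vector, roughly what the script's embedding-saving helper does (file name and token are illustrative):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "openai/clip-vit-base-patch32"  # stand-in for the SD text encoder
tokenizer = CLIPTokenizer.from_pretrained(model_id)
text_encoder = CLIPTextModel.from_pretrained(model_id)

placeholder_token = "<cat-toy>"
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Save only the learned embedding vector instead of the whole text encoder.
learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
torch.save({placeholder_token: learned_embeds.detach().cpu()}, "learned_embeds.bin")

# Reloading later: copy the saved vector back into a freshly resized embedding matrix.
saved = torch.load("learned_embeds.bin")
with torch.no_grad():
    text_encoder.get_input_embeddings().weight[placeholder_token_id] = saved[placeholder_token]
```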