Here, let's fix this while we're at it.
parent d1eb3ace3f
commit 189f621a1e
@@ -48,6 +48,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # defaults should be good for everyone
 # TODO: add custom VAE support. should be simple with diffusers
+# use action='store_true' when looking for boolean values so the arguments are treated like flags (as expected)
+# just keep in mind that 'store_false' is logically flipped from 'store_true':
+# ('--foo', action='store_false') yields False when the flag is present, and True when it is not.
 parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
 parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
 parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
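
For reference, a minimal sketch (not part of the commit) of the flag semantics the new comments describe: a `store_true` argument defaults to `False` and becomes `True` only when the flag is passed, while `store_false` is its logical mirror. The argument names below are placeholders.

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--fp16', action='store_true')          # defaults to False
p.add_argument('--keep_color', action='store_false')   # defaults to True

print(p.parse_args([]))                # Namespace(fp16=False, keep_color=True)
print(p.parse_args(['--fp16']))        # Namespace(fp16=True, keep_color=True)
print(p.parse_args(['--keep_color']))  # Namespace(fp16=False, keep_color=False)
```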

@@ -59,10 +62,10 @@ parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximu
 parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
 parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
 parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
-parser.add_argument('--use_ema', type=str, default='False', help='Use EMA for finetuning')
+parser.add_argument('--use_ema', action='store_true', help='Use EMA for finetuning')
 parser.add_argument('--ucg', type=float, default=0.1, help='Probability of dropping out the text condition per batch. Ranges from 0.0 to 1.0, where 1.0 means 100% text condition dropout.') # 10% dropout probability
-parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=str, default='False', help='Enable gradient checkpointing')
+parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', action='store_true', help='Enable gradient checkpointing')
-parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=str, default='False', help='Use 8-bit Adam optimizer')
+parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', action='store_true', help='Use 8-bit Adam optimizer')
 parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
 parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
 parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')

@@ -73,31 +76,24 @@ parser.add_argument('--seed', type=int, default=42, help='Seed for random number
 parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
 parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
 parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower res images will be scaled up to this resolution and higher res images will be scaled down.')
-parser.add_argument('--shuffle', dest='shuffle', type=str, default='True', help='Shuffle dataset')
+parser.add_argument('--shuffle', dest='shuffle', action='store_true', help='Shuffle dataset')
 parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
 parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
-parser.add_argument('--fp16', dest='fp16', type=str, default='False', help='Train in mixed precision')
+parser.add_argument('--fp16', dest='fp16', action='store_true', help='Train in mixed precision')
 parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
 parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Scheduler to use when logging images.')
-parser.add_argument('--clip_penultimate', type=str, default='False', help='Use penultimate CLIP layer for text embedding')
+parser.add_argument('--clip_penultimate', action='store_true', help='Use penultimate CLIP layer for text embedding')
-parser.add_argument('--output_bucket_info', type=str, default='False', help='Outputs bucket information and exits')
+parser.add_argument('--output_bucket_info', action='store_true', help='Outputs bucket information and exits')
-parser.add_argument('--resize', type=str, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
+parser.add_argument('--resize', action='store_true', help="Resizes dataset's images to the appropriate bucket dimensions.")
-parser.add_argument('--use_xformers', type=str, default='False', help='Use memory efficient attention')
+parser.add_argument('--use_xformers', action='store_true', help='Use memory efficient attention')
-parser.add_argument('--extended_validation', type=str, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
+parser.add_argument('--extended_validation', action='store_true', help='Perform extended validation of images to catch truncated or corrupt images.')
-parser.add_argument('--data_migration', type=str, default='True', help='Perform migration of resized images into a directory relative to the dataset path. Saves into `<dataset_directory_name>_cropped`.')
+parser.add_argument('--no_migration', action='store_true', help='Do not migrate resized images into a directory relative to the dataset path. By default they are saved into `<dataset_directory_name>_cropped`.')
-parser.add_argument('--skip_validation', type=str, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
+parser.add_argument('--skip_validation', action='store_true', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
 
 args = parser.parse_args()
 
-for arg in vars(args):
-    if type(getattr(args, arg)) == str:
-        if getattr(args, arg).lower() == 'true':
-            setattr(args, arg, True)
-        elif getattr(args, arg).lower() == 'false':
-            setattr(args, arg, False)
-
 def setup():
     torch.distributed.init_process_group("nccl", init_method="env://")
 
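
The deleted coercion loop is what motivated this change: with `type=str` booleans, any value other than an exact 'true'/'false' slipped through as a non-empty, and therefore truthy, string. A small sketch of the failure mode, using a hypothetical parser:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--fp16', type=str, default='False')
args = p.parse_args(['--fp16', 'no'])  # user writes 'no' instead of 'False'

# The removed loop only rewrote the exact strings 'true' and 'false',
# so args.fp16 stays the string 'no' here...
if args.fp16:
    print('mixed precision enabled')   # ...and any non-empty string is truthy
```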

@@ -194,12 +190,12 @@ class Validation():
         return True
 
 class Resize():
-    def __init__(self, is_resizing: bool, is_migrating: bool) -> None:
+    def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None:
         if not is_resizing:
             self.resize = self.__no_op
             return
 
-        if is_migrating:
+        if not is_not_migrating:
             self.resize = self.__migration
             dataset_path = os.path.split(args.dataset)
             self.__directory = os.path.join(
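
The `Resize` hunk keeps the class's existing dispatch pattern: the constructor binds `self.resize` to the right handler once, so per-image calls never re-check the flags. A condensed sketch of that pattern (handler bodies are hypothetical stand-ins, and the in-place resize branch is elided, as in the hunk above):

```python
class ResizeSketch:
    def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None:
        if not is_resizing:
            self.resize = self.__no_op       # resizing disabled entirely
            return
        if not is_not_migrating:
            self.resize = self.__migration   # resize and copy into the migration dir

    def __no_op(self, image_path):
        return image_path                    # leave the image untouched

    def __migration(self, image_path):
        # stand-in for the real crop-and-save logic
        return image_path
```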

@@ -267,7 +263,7 @@ class ImageStore:
             args.extended_validation
         ).validate
 
-        self.resizer = Resize(args.resize, args.data_migration).resize
+        self.resizer = Resize(args.resize, args.no_migration).resize
 
         self.image_files = [x for x in self.image_files if self.validator(x)]
 

@@ -718,7 +714,7 @@ def main():
     )
 
     # Migrate dataset
-    if args.resize and args.data_migration:
+    if args.resize and not args.no_migration:
         for _, batch in enumerate(train_dataloader):
             continue
         print(f"Completed resize and migration to '{args.dataset}_cropped'. Please relaunch the trainer without the --resize argument and train on the migrated dataset.")
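
A self-contained check (a sketch, not part of the commit) that the inverted flag preserves the old default: migration runs with `--resize` alone and is skipped only when `--no_migration` is also passed.

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--resize', action='store_true')
p.add_argument('--no_migration', action='store_true')

a = p.parse_args(['--resize'])
assert a.resize and not a.no_migration           # migration path taken

a = p.parse_args(['--resize', '--no_migration'])
assert not (a.resize and not a.no_migration)     # migration skipped
```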