[dreambooth] low precision guard (#1916)
* [dreambooth] low precision guard
* fix
* add docs to cli args
* Update examples/dreambooth/train_dreambooth.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* style
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
7101c7316b
commit
247b5feea1
|
@ -70,7 +70,10 @@ def parse_args(input_args=None):
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
help=(
|
||||||
|
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
||||||
|
" float32 precision."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer_name",
|
"--tokenizer_name",
|
||||||
|
@ -140,7 +143,11 @@ def parse_args(input_args=None):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
||||||
)
|
)
|
||||||
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
|
parser.add_argument(
|
||||||
|
"--train_text_encoder",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||||
)
|
)
|
||||||
|
@ -671,6 +678,17 @@ def main(args):
|
||||||
if not args.train_text_encoder:
|
if not args.train_text_encoder:
|
||||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
low_precision_error_string = (
|
||||||
|
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
|
||||||
|
" doing mixed precision training. copy of the weights should still be float32."
|
||||||
|
)
|
||||||
|
|
||||||
|
if unet.dtype != torch.float32:
|
||||||
|
raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}")
|
||||||
|
|
||||||
|
if args.train_text_encoder and text_encoder.dtype != torch.float32:
|
||||||
|
raise ValueError(f"Text encoder loaded as datatype {text_encoder.dtype}. {low_precision_error_string}")
|
||||||
|
|
||||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
if overrode_max_train_steps:
|
if overrode_max_train_steps:
|
||||||
|
|
Loading…
Reference in New Issue