adds `xformers` support to `train_unconditional.py` (#2520)
This commit is contained in:
parent
7f0f7e1e91
commit
5e5ce13e2f
|
@ -24,6 +24,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.training_utils import EMAModel
|
from diffusers.training_utils import EMAModel
|
||||||
from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
|
from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
|
||||||
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
|
@ -259,6 +260,9 @@ def parse_args():
|
||||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||||
|
@ -410,6 +414,19 @@ def main(args):
|
||||||
model_config=model.config,
|
model_config=model.config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.enable_xformers_memory_efficient_attention:
|
||||||
|
if is_xformers_available():
|
||||||
|
import xformers
|
||||||
|
|
||||||
|
xformers_version = version.parse(xformers.__version__)
|
||||||
|
if xformers_version == version.parse("0.0.16"):
|
||||||
|
logger.warn(
|
||||||
|
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||||
|
)
|
||||||
|
model.enable_xformers_memory_efficient_attention()
|
||||||
|
else:
|
||||||
|
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||||
|
|
||||||
# Initialize the scheduler
|
# Initialize the scheduler
|
||||||
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||||
if accepts_prediction_type:
|
if accepts_prediction_type:
|
||||||
|
|
Loading…
Reference in New Issue