adds `xformers` support to `train_unconditional.py` (#2520)
This commit is contained in:
parent 7f0f7e1e91
commit 5e5ce13e2f
@@ -24,6 +24,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -259,6 +260,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
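Because the new flag uses argparse's `store_true` action, it defaults to `False` and flips to `True` only when passed on the command line. A minimal standalone sketch of that behavior (the reduced parser below is illustrative, not the script's full argument parser):

```python
import argparse

# Illustrative parser reduced to the new flag; the real script defines many more arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)

# store_true flags default to False and flip to True when the flag is present.
args = parser.parse_args([])
print(args.enable_xformers_memory_efficient_attention)  # False

args = parser.parse_args(["--enable_xformers_memory_efficient_attention"])
print(args.enable_xformers_memory_efficient_attention)  # True
```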
@@ -410,6 +414,19 @@ def main(args):
             model_config=model.config,
         )

+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            model.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
     # Initialize the scheduler
     accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
     if accepts_prediction_type:
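The same guarded-enable pattern can be exercised outside the training script. A minimal sketch, assuming xFormers is installed and a CUDA device is available; the small UNet configuration here is illustrative and not the model the script actually builds:

```python
import torch
from diffusers import UNet2DModel
from diffusers.utils.import_utils import is_xformers_available

# Small illustrative UNet; the training script constructs its model from CLI args instead.
model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3).to("cuda")

if is_xformers_available():
    # Swaps the model's attention processors for xFormers' memory-efficient kernels.
    model.enable_xformers_memory_efficient_attention()
else:
    raise ValueError("xformers is not available. Make sure it is installed correctly")

# Attention now runs through xFormers during forward passes (a CUDA device is required).
noise = torch.randn(1, 3, 64, 64, device="cuda")
out = model(noise, timestep=1).sample
print(out.shape)  # torch.Size([1, 3, 64, 64])
```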