[Examples] Test all examples on CPU (#2289)

* [Examples] Test all examples on CPU

* add

* correct

* Apply suggestions from code review
This commit is contained in:
Patrick von Platen 2023-02-08 16:59:13 +02:00 committed by GitHub
parent 9d0d070996
commit 1ed6b77781
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 123 additions and 39 deletions

View File

@ -36,6 +36,11 @@ jobs:
runner: docker-cpu
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
name: ${{ matrix.config.name }}
@ -90,6 +95,13 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples/test_examples.py
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt

View File

@ -25,8 +25,6 @@ from typing import List
from accelerate.utils import write_basic_config
from diffusers.utils import slow
logging.basicConfig(level=logging.DEBUG)
@ -74,51 +72,94 @@ class ExamplesTestsAccelerate(unittest.TestCase):
super().tearDownClass()
shutil.rmtree(cls._tmpdir)
@slow
def test_train_unconditional(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name huggan/few-shot-aurora
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 4
--train_batch_size 2
--num_epochs 1
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
--mixed_precision fp16
""".split()
run_command(self._launch_args + test_args, return_stdout=True)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
# logging test
self.assertTrue(len(os.listdir(os.path.join(tmpdir, "logs", "train_unconditional"))) > 0)
@slow
def test_textual_inversion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token toy
--initializer_token a
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 2
--max_train_steps 10
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--mixed_precision fp16
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin")))
def test_dreambooth(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_text_to_image(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

View File

@ -22,7 +22,7 @@ import diffusers
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from diffusers.utils import check_min_version, is_tensorboard_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@ -67,6 +67,12 @@ def parse_args():
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--model_config_name_or_path",
type=str,
default=None,
help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
)
parser.add_argument(
"--train_data_dir",
type=str,
@ -222,6 +228,7 @@ def parse_args():
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
parser.add_argument(
"--checkpointing_steps",
@ -340,29 +347,33 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)
# Initialize the model
model = UNet2DModel(
sample_size=args.resolution,
in_channels=3,
out_channels=3,
layers_per_block=2,
block_out_channels=(128, 128, 256, 256, 512, 512),
down_block_types=(
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D",
"DownBlock2D",
),
up_block_types=(
"UpBlock2D",
"AttnUpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
if args.model_config_name_or_path is None:
model = UNet2DModel(
sample_size=args.resolution,
in_channels=3,
out_channels=3,
layers_per_block=2,
block_out_channels=(128, 128, 256, 256, 512, 512),
down_block_types=(
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D",
"DownBlock2D",
),
up_block_types=(
"UpBlock2D",
"AttnUpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
else:
config = UNet2DModel.load_config(args.model_config_name_or_path)
model = UNet2DModel.from_config(config)
# Create EMA for the model.
if args.use_ema:
@ -586,13 +597,14 @@ def main(args):
images = pipeline(
generator=generator,
batch_size=args.eval_batch_size,
num_inference_steps=args.ddpm_num_inference_steps,
output_type="numpy",
).images
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
if args.logger == "tensorboard":
if args.logger == "tensorboard" and is_tensorboard_available():
accelerator.get_tracker("tensorboard").add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)

View File

@ -52,6 +52,7 @@ from .import_utils import (
is_onnx_available,
is_safetensors_available,
is_scipy_available,
is_tensorboard_available,
is_tf_available,
is_torch_available,
is_torch_version,

View File

@ -224,6 +224,13 @@ try:
except importlib_metadata.PackageNotFoundError:
_omegaconf_available = False
_tensorboard_available = importlib.util.find_spec("tensorboard")
try:
_tensorboard_version = importlib_metadata.version("tensorboard")
logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
_tensorboard_available = False
def is_torch_available():
return _torch_available
@ -285,6 +292,10 @@ def is_omegaconf_available():
return _omegaconf_available
def is_tensorboard_available():
return _tensorboard_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@ -351,6 +362,12 @@ OMEGACONF_IMPORT_ERROR = """
install omegaconf`
"""
# docstyle-ignore
TENSORBOARD_IMPORT_ERROR = """
{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
install tensorboard`
"""
BACKENDS_MAPPING = OrderedDict(
[
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
@ -364,6 +381,7 @@ BACKENDS_MAPPING = OrderedDict(
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
]
)