From f1d4289be80c5acfc8a1404c01fd324d8011e319 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 13 Oct 2022 13:55:39 +0200
Subject: [PATCH] [Flax] Add test (#824)

---
 .../pipeline_flax_stable_diffusion.py |  1 -
 src/diffusers/utils/testing_utils.py  | 19 +++---
 tests/test_pipelines_flax.py          | 62 +++++++++++++++++++
 3 files changed, 74 insertions(+), 8 deletions(-)
 create mode 100644 tests/test_pipelines_flax.py

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 102a2658..4a252ab6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -52,7 +52,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         dtype: jnp.dtype = jnp.float32,
     ):
         super().__init__()
-        scheduler = scheduler.set_format("np")
         self.dtype = dtype
 
         self.register_modules(
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index f44b9cd3..682a7471 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -7,22 +7,27 @@
 from distutils.util import strtobool
 from pathlib import Path
 from typing import Union
 
-import torch
-
 import PIL.Image
 import PIL.ImageOps
 import requests
 from packaging import version
 
-from .import_utils import is_flax_available
+from .import_utils import is_flax_available, is_torch_available
 
 
 global_rng = random.Random()
-torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
-if is_torch_higher_equal_than_1_12:
-    torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+
+if is_torch_available():
+    import torch
+
+    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+    is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
+        "1.12"
+    )
+
+    if is_torch_higher_equal_than_1_12:
+        torch_device = "mps" if torch.backends.mps.is_available() else torch_device
 
 
 def get_tests_dir(append_path=None):
diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py
new file mode 100644
index 00000000..18dce171
--- /dev/null
+++ b/tests/test_pipelines_flax.py
@@ -0,0 +1,62 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import require_flax, slow
+
+
+if is_flax_available():
+    import jax
+    from diffusers import FlaxStableDiffusionPipeline
+    from flax.jax_utils import replicate
+    from flax.training.common_utils import shard
+    from jax import pmap
+
+
+@require_flax
+@slow
+class FlaxPipelineTests(unittest.TestCase):
+    def test_dummy_all_tpus(self):
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe"
+        )
+
+        prompt = (
+            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
+            " field, close up, split lighting, cinematic"
+        )
+
+        prng_seed = jax.random.PRNGKey(0)
+        num_inference_steps = 4
+
+        num_samples = jax.device_count()
+        prompt = num_samples * [prompt]
+        prompt_ids = pipeline.prepare_inputs(prompt)
+
+        p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
+
+        # shard inputs and rng
+        params = replicate(params)
+        prng_seed = jax.random.split(prng_seed, 8)
+        prompt_ids = shard(prompt_ids)
+
+        images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
+        images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+
+        assert len(images_pil) == 8
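
For reference, the multi-device sampling pattern that test_dummy_all_tpus exercises can also be used outside the test harness. The sketch below is illustrative only: it reuses the calls that appear in the patch (from_pretrained, prepare_inputs, pmap with static_broadcasted_argnums, replicate, shard, numpy_to_pil) but derives the shard count from jax.device_count() instead of assuming the 8 TPU cores the test is written for; the prompt string is a placeholder.

    import jax
    import numpy as np
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard

    from diffusers import FlaxStableDiffusionPipeline

    # Tiny test checkpoint from the patch; any Flax Stable Diffusion checkpoint follows the same pattern.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")

    num_devices = jax.device_count()
    # One prompt per device; prepare_inputs tokenizes the list into an id array.
    prompt_ids = pipeline.prepare_inputs(num_devices * ["a photo of an astronaut riding a horse"])

    # Replicate the weights, split the RNG, and shard the token ids across devices.
    params = replicate(params)
    rng = jax.random.split(jax.random.PRNGKey(0), num_devices)
    prompt_ids = shard(prompt_ids)

    # num_inference_steps (argument index 3) is broadcast statically, as in the test.
    p_sample = jax.pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
    images = p_sample(prompt_ids, params, rng, 4).images

    # Fold the (devices, per-device batch, H, W, C) output back into one batch of PIL images.
    images = np.asarray(images.reshape((num_devices,) + images.shape[-3:]))
    print(len(pipeline.numpy_to_pil(images)))  # == num_devices

On a single-device setup, jax.device_count() is 1 and the same code produces a single image.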