[Flax] Add test (#824)

This commit is contained in:
Patrick von Platen 2022-10-13 13:55:39 +02:00 committed by GitHub
parent 323a9e1f6d
commit f1d4289be8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 8 deletions

View File

@ -52,7 +52,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
dtype: jnp.dtype = jnp.float32, dtype: jnp.dtype = jnp.float32,
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("np")
self.dtype = dtype self.dtype = dtype
self.register_modules( self.register_modules(

View File

@ -7,19 +7,24 @@ from distutils.util import strtobool
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import torch
import PIL.Image import PIL.Image
import PIL.ImageOps import PIL.ImageOps
import requests import requests
from packaging import version from packaging import version
from .import_utils import is_flax_available from .import_utils import is_flax_available, is_torch_available
global_rng = random.Random() global_rng = random.Random()
if is_torch_available():
import torch
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12") is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
"1.12"
)
if is_torch_higher_equal_than_1_12: if is_torch_higher_equal_than_1_12:
torch_device = "mps" if torch.backends.mps.is_available() else torch_device torch_device = "mps" if torch.backends.mps.is_available() else torch_device

View File

@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
@require_flax
@slow
class FlaxPipelineTests(unittest.TestCase):
def test_dummy_all_tpus(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe"
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 4
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
assert len(images_pil) == 8