[Flax] Complete tests (#828)
This commit is contained in:
parent
7c2262640b
commit
1d51224403
|
@ -24,7 +24,7 @@ from diffusers.utils.testing_utils import require_flax, slow
|
|||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from diffusers import FlaxStableDiffusionPipeline
|
||||
from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline
|
||||
from flax.jax_utils import replicate
|
||||
from flax.training.common_utils import shard
|
||||
from jax import pmap
|
||||
|
@ -61,7 +61,7 @@ class FlaxPipelineTests(unittest.TestCase):
|
|||
|
||||
assert images.shape == (8, 1, 64, 64, 3)
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 1e-2
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 5e-1
|
||||
|
||||
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
||||
|
||||
|
@ -93,13 +93,9 @@ class FlaxPipelineTests(unittest.TestCase):
|
|||
|
||||
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||
|
||||
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
||||
for i, image in enumerate(images_pil):
|
||||
image.save(f"/home/patrick/images/flax-test-{i}_fp32.png")
|
||||
|
||||
assert images.shape == (8, 1, 512, 512, 3)
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 1e-2
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
|
||||
|
||||
def test_stable_diffusion_v1_4_bfloat_16(self):
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
|
@ -129,7 +125,7 @@ class FlaxPipelineTests(unittest.TestCase):
|
|||
|
||||
assert images.shape == (8, 1, 512, 512, 3)
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
|
||||
|
||||
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
|
@ -157,4 +153,49 @@ class FlaxPipelineTests(unittest.TestCase):
|
|||
|
||||
assert images.shape == (8, 1, 512, 512, 3)
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
|
||||
|
||||
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
|
||||
scheduler = FlaxDDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1,
|
||||
)
|
||||
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
revision="bf16",
|
||||
dtype=jnp.bfloat16,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
)
|
||||
scheduler_state = scheduler.create_state()
|
||||
|
||||
params["scheduler"] = scheduler_state
|
||||
|
||||
prompt = (
|
||||
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
|
||||
" field, close up, split lighting, cinematic"
|
||||
)
|
||||
|
||||
prng_seed = jax.random.PRNGKey(0)
|
||||
num_inference_steps = 50
|
||||
|
||||
num_samples = jax.device_count()
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
|
||||
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
|
||||
|
||||
# shard inputs and rng
|
||||
params = replicate(params)
|
||||
prng_seed = jax.random.split(prng_seed, 8)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||
|
||||
assert images.shape == (8, 1, 512, 512, 3)
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
|
||||
|
|
Loading…
Reference in New Issue