parent
2d35f6733a
commit
2c82e0c4eb
|
@ -14,7 +14,7 @@ import PIL.ImageOps
|
|||
import requests
|
||||
from packaging import version
|
||||
|
||||
from .import_utils import is_flax_available, is_torch_available
|
||||
from .import_utils import is_flax_available, is_onnx_available, is_torch_available
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
@ -100,6 +100,20 @@ def slow(test_case):
|
|||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_flax(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||
|
@ -107,6 +121,13 @@ def require_flax(test_case):
|
|||
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||
|
||||
|
||||
def require_onnxruntime(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
|
||||
|
||||
|
||||
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
"""
|
||||
Args:
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@property
|
||||
def dummy_uncond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
return model
|
||||
|
||||
def test_inference(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array(
|
||||
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
|
||||
)
|
||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_ema_bedroom(self):
|
||||
model_id = "google/ddpm-ema-bedroom-256"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddim.to(torch_device)
|
||||
ddim.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
@ -0,0 +1,55 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
# FIXME: add fast tests
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDPMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
@ -0,0 +1,86 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@property
|
||||
def dummy_uncond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
return model
|
||||
|
||||
def test_inference(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
model_id = "google/ncsnpp-celebahq-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
@ -0,0 +1,153 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@property
|
||||
def dummy_cond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
def test_inference_text2img(self):
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
generator = torch.manual_seed(0)
|
||||
_ = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
|
||||
).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
|
||||
).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = ldm(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="numpy",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_text2img(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_text2img_fast(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
@ -0,0 +1,119 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, LDMPipeline, UNet2DModel, VQModel
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
from transformers import CLIPTextConfig, CLIPTextModel
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class LDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@property
|
||||
def dummy_uncond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vq_model(self):
|
||||
torch.manual_seed(0)
|
||||
model = VQModel(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=3,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
def test_inference_uncond(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
vae = self.dummy_vq_model
|
||||
|
||||
ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
generator = torch.manual_seed(0)
|
||||
_ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class LDMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_uncond(self):
|
||||
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
@ -0,0 +1,87 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import PNDMPipeline, PNDMScheduler, UNet2DModel
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class PNDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@property
|
||||
def dummy_uncond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
return model
|
||||
|
||||
def test_inference(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = PNDMScheduler()
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
pndm.to(torch_device)
|
||||
pndm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class PNDMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = PNDMScheduler()
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
pndm.to(torch_device)
|
||||
pndm.set_progress_bar_config(disable=None)
|
||||
generator = torch.manual_seed(0)
|
||||
image = pndm(generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
@ -0,0 +1,91 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@property
|
||||
def dummy_uncond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
return model
|
||||
|
||||
def test_inference(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = ScoreSdeVeScheduler()
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
|
||||
sde_ve.to(torch_device)
|
||||
sde_ve.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[
|
||||
0
|
||||
]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
model_id = "google/ncsnpp-church-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
|
||||
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
|
||||
sde_ve.to(torch_device)
|
||||
sde_ve.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
|
||||
expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
@ -0,0 +1,84 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from diffusers import OnnxStableDiffusionPipeline
|
||||
from diffusers.utils.testing_utils import require_onnxruntime, slow
|
||||
|
||||
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
|
||||
|
||||
|
||||
class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
|
||||
# FIXME: add fast tests
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_onnxruntime
|
||||
class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
np.random.seed(0)
|
||||
output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=5, output_type="np")
|
||||
image = output.images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
elif step == 5:
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
np.random.seed(0)
|
||||
pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 6
|
|
@ -0,0 +1,62 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||
from diffusers.utils.testing_utils import load_image, require_onnxruntime, slow
|
||||
|
||||
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
|
||||
|
||||
|
||||
class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
|
||||
# FIXME: add fast tests
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_onnxruntime
|
||||
class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
np.random.seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
init_image=init_image,
|
||||
strength=0.75,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=8,
|
||||
output_type="np",
|
||||
)
|
||||
images = output.images
|
||||
image_slice = images[0, 255:258, 383:386, -1]
|
||||
|
||||
assert images.shape == (1, 512, 768, 3)
|
||||
expected_slice = np.array([0.4830, 0.5242, 0.5603, 0.5016, 0.5131, 0.5111, 0.4928, 0.5025, 0.5055])
|
||||
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
|
@ -0,0 +1,65 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from diffusers import OnnxStableDiffusionInpaintPipeline
|
||||
from diffusers.utils.testing_utils import load_image, require_onnxruntime, slow
|
||||
|
||||
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
|
||||
|
||||
|
||||
class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
|
||||
# FIXME: add fast tests
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_onnxruntime
|
||||
class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_stable_diffusion_inpaint_onnx(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
)
|
||||
|
||||
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A red cat sitting on a park bench"
|
||||
|
||||
np.random.seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=8,
|
||||
output_type="np",
|
||||
)
|
||||
images = output.images
|
||||
image_slice = images[0, 255:258, 255:258, -1]
|
||||
|
||||
assert images.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
|
@ -31,13 +31,16 @@ from diffusers import (
|
|||
VQModel,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class PipelineFastTests(unittest.TestCase):
|
||||
class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
@ -514,8 +517,8 @@ class PipelineFastTests(unittest.TestCase):
|
|||
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
class PipelineIntegrationTests(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
|
|
@ -30,13 +30,16 @@ from diffusers import (
|
|||
VQModel,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class PipelineFastTests(unittest.TestCase):
|
||||
class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
@ -461,8 +464,8 @@ class PipelineFastTests(unittest.TestCase):
|
|||
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
class PipelineIntegrationTests(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
|
|
@ -29,14 +29,17 @@ from diffusers import (
|
|||
VQModel,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class PipelineFastTests(unittest.TestCase):
|
||||
class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
@ -258,8 +261,8 @@ class PipelineFastTests(unittest.TestCase):
|
|||
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
class PipelineIntegrationTests(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
|
|
@ -31,14 +31,17 @@ from diffusers import (
|
|||
VQModel,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class PipelineFastTests(unittest.TestCase):
|
||||
class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
@ -338,8 +341,8 @@ class PipelineFastTests(unittest.TestCase):
|
|||
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
class PipelineIntegrationTests(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
|
|
@ -32,17 +32,7 @@ from diffusers import (
|
|||
DDIMScheduler,
|
||||
DDPMPipeline,
|
||||
DDPMScheduler,
|
||||
KarrasVePipeline,
|
||||
KarrasVeScheduler,
|
||||
LDMPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionPipeline,
|
||||
PNDMPipeline,
|
||||
PNDMScheduler,
|
||||
ScoreSdeVePipeline,
|
||||
ScoreSdeVeScheduler,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
|
@ -53,8 +43,8 @@ from diffusers import (
|
|||
)
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
|
||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
@ -145,12 +135,6 @@ class CustomPipelineTests(unittest.TestCase):
|
|||
|
||||
|
||||
class PipelineFastTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
|
@ -261,174 +245,6 @@ class PipelineFastTests(unittest.TestCase):
|
|||
|
||||
return extract
|
||||
|
||||
def test_ddim(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array(
|
||||
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
|
||||
)
|
||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
|
||||
|
||||
def test_pndm_cifar10(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = PNDMScheduler()
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
pndm.to(torch_device)
|
||||
pndm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_ldm_text2img(self):
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
generator = torch.manual_seed(0)
|
||||
_ = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
|
||||
).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
|
||||
).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = ldm(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="numpy",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_score_sde_ve_pipeline(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = ScoreSdeVeScheduler()
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
|
||||
sde_ve.to(torch_device)
|
||||
sde_ve.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[
|
||||
0
|
||||
]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_ldm_uncond(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
vae = self.dummy_vq_model
|
||||
|
||||
ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
generator = torch.manual_seed(0)
|
||||
_ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_karras_ve_pipeline(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_components(self):
|
||||
"""Test that components property works correctly"""
|
||||
unet = self.dummy_cond_unet
|
||||
|
@ -489,7 +305,8 @@ class PipelineFastTests(unittest.TestCase):
|
|||
assert image_text2img.shape == (1, 128, 128, 3)
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
@slow
|
||||
class PipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
|
@ -556,7 +373,6 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
|
||||
|
@ -577,7 +393,6 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
@slow
|
||||
def test_from_pretrained_hub_pass_model(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
|
||||
|
@ -601,7 +416,6 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
@slow
|
||||
def test_output_format(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
|
||||
|
@ -624,156 +438,6 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
assert isinstance(images, list)
|
||||
assert isinstance(images[0], PIL.Image.Image)
|
||||
|
||||
@slow
|
||||
def test_ddpm_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDPMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ddim_lsun(self):
|
||||
model_id = "google/ddpm-ema-bedroom-256"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ddim_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddim.to(torch_device)
|
||||
ddim.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_pndm_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = PNDMScheduler()
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
pndm.to(torch_device)
|
||||
pndm.set_progress_bar_config(disable=None)
|
||||
generator = torch.manual_seed(0)
|
||||
image = pndm(generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ldm_text2img(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ldm_text2img_fast(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_score_sde_ve_pipeline(self):
|
||||
model_id = "google/ncsnpp-church-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
|
||||
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
|
||||
sde_ve.to(torch_device)
|
||||
sde_ve.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
|
||||
expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ldm_uncond(self):
|
||||
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_ddpm_ddim_equality(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
|
@ -824,145 +488,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
# the values aren't exactly equal, but the images look the same visually
|
||||
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
||||
|
||||
@slow
|
||||
def test_karras_ve_pipeline(self):
|
||||
model_id = "google/ncsnpp-celebahq-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_stable_diffusion_onnx(self):
|
||||
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
np.random.seed(0)
|
||||
output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=5, output_type="np")
|
||||
image = output.images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
@slow
|
||||
def test_stable_diffusion_img2img_onnx(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
np.random.seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
init_image=init_image,
|
||||
strength=0.75,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=8,
|
||||
output_type="np",
|
||||
)
|
||||
images = output.images
|
||||
image_slice = images[0, 255:258, 383:386, -1]
|
||||
|
||||
assert images.shape == (1, 512, 768, 3)
|
||||
expected_slice = np.array([0.4830, 0.5242, 0.5603, 0.5016, 0.5131, 0.5111, 0.4928, 0.5025, 0.5055])
|
||||
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
||||
|
||||
@slow
|
||||
def test_stable_diffusion_inpaint_onnx(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
)
|
||||
|
||||
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A red cat sitting on a park bench"
|
||||
|
||||
np.random.seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=8,
|
||||
output_type="np",
|
||||
)
|
||||
images = output.images
|
||||
image_slice = images[0, 255:258, 255:258, -1]
|
||||
|
||||
assert images.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
@slow
|
||||
def test_stable_diffusion_onnx_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
elif step == 5:
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
np.random.seed(0)
|
||||
pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 6
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_accelerate_load_works(self):
|
||||
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
|
||||
return
|
||||
|
@ -975,8 +501,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
||||
).to(torch_device)
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
|
||||
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
|
||||
return
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
from diffusers.utils.testing_utils import require_torch
|
||||
|
||||
|
||||
@require_torch
|
||||
class PipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline,
|
||||
equivalence of dict and tuple outputs, etc.
|
||||
"""
|
||||
|
||||
pass
|
|
@ -30,8 +30,8 @@ if is_flax_available():
|
|||
from jax import pmap
|
||||
|
||||
|
||||
@require_flax
|
||||
@slow
|
||||
@require_flax
|
||||
class FlaxPipelineTests(unittest.TestCase):
|
||||
def test_dummy_all_tpus(self):
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
from diffusers.utils.testing_utils import require_onnxruntime
|
||||
|
||||
|
||||
@require_onnxruntime
|
||||
class OnnxPipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each ONNXRuntime pipeline, e.g. saving and loading the pipeline,
|
||||
equivalence of dict and tuple outputs, etc.
|
||||
"""
|
||||
|
||||
pass
|
Loading…
Reference in New Issue