parent bc2ad5a661
commit e8140304b9

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 import tempfile
 import unittest
 
@@ -22,6 +23,7 @@ import torch
 import PIL
 from datasets import load_dataset
 from diffusers import (
+    AutoencoderKL,
     DDIMPipeline,
     DDIMScheduler,
     DDPMPipeline,
@@ -38,10 +40,14 @@ from diffusers import (
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,
     StableDiffusionPipeline,
+    UNet2DConditionModel,
     UNet2DModel,
+    VQModel,
 )
 from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.testing_utils import slow, torch_device
+from diffusers.testing_utils import floats_tensor, slow, torch_device
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -70,6 +76,410 @@ def test_progress_bar(capsys):
     assert captured.err == "", "Progress bar should be disabled"
 
 
+class PipelineFastTests(unittest.TestCase):
+    @property
+    def dummy_image(self):
+        batch_size = 1
+        num_channels = 3
+        sizes = (32, 32)
+
+        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
+        return image
+
+    @property
+    def dummy_uncond_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        return model
+
+    @property
+    def dummy_cond_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        return model
+
+    @property
+    def dummy_vq_model(self):
+        torch.manual_seed(0)
+        model = VQModel(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=3,
+        )
+        return model
+
+    @property
+    def dummy_vae(self):
+        torch.manual_seed(0)
+        model = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        return model
+
+    @property
+    def dummy_text_encoder(self):
+        torch.manual_seed(0)
+        config = CLIPTextConfig(
+            bos_token_id=0,
+            chunk_size_feed_forward=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        return CLIPTextModel(config)
+
+    @property
+    def dummy_safety_checker(self):
+        def check(images, *args, **kwargs):
+            return images, False
+
+        return check
+
+    @property
+    def dummy_extractor(self):
+        def extract(*args, **kwargs):
+            class Out:
+                def __init__(self):
+                    self.pixel_values = torch.ones([0])
+
+                def to(self, device):
+                    self.pixel_values.to(device)
+                    return self
+
+            return Out()
+
+        return extract
+
+    def test_ddim(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDIMScheduler(tensor_format="pt")
+
+        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array(
+            [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
+        )
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_pndm_cifar10(self):
+        unet = self.dummy_uncond_unet
+        scheduler = PNDMScheduler(tensor_format="pt")
+
+        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
+        pndm.to(torch_device)
+        generator = torch.manual_seed(0)
+        image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_ldm_text2img(self):
+        unet = self.dummy_cond_unet
+        scheduler = DDIMScheduler(tensor_format="pt")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+        ldm.to(torch_device)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
+            "sample"
+        ]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_ddim(self):
+        unet = self.dummy_cond_unet
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast("cuda"):
+            output = sd_pipe(
+                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
+            )
+
+        image = output["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 128, 128, 3)
+        expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_pndm(self):
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast("cuda"):
+            output = sd_pipe(
+                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
+            )
+
+        image = output["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 128, 128, 3)
+        expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_k_lms(self):
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast("cuda"):
+            output = sd_pipe(
+                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
+            )
+
+        image = output["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 128, 128, 3)
+        expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_score_sde_ve_pipeline(self):
+        unet = self.dummy_uncond_unet
+        scheduler = ScoreSdeVeScheduler(tensor_format="pt")
+
+        sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
+        sde_ve.to(torch_device)
+
+        torch.manual_seed(0)
+        image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+
+        expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_ldm_uncond(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDIMScheduler(tensor_format="pt")
+        vae = self.dummy_vq_model
+
+        ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
+        ldm.to(torch_device)
+
+        generator = torch.manual_seed(0)
+        image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_karras_ve_pipeline(self):
+        unet = self.dummy_uncond_unet
+        scheduler = KarrasVeScheduler(tensor_format="pt")
+
+        pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
+        pipe.to(torch_device)
+
+        generator = torch.manual_seed(0)
+        image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_img2img(self):
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast("cuda"):
+            output = sd_pipe(
+                [prompt],
+                generator=generator,
+                guidance_scale=6.0,
+                num_inference_steps=2,
+                output_type="np",
+                init_image=init_image,
+            )
+
+        image = output["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_inpaint(self):
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast("cuda"):
+            output = sd_pipe(
+                [prompt],
+                generator=generator,
+                guidance_scale=6.0,
+                num_inference_steps=2,
+                output_type="np",
+                init_image=init_image,
+                mask_image=mask_image,
+            )
+
+        image = output["sample"]
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
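
Editor's note, not part of the commit: a minimal sketch of how the PipelineFastTests case added above could be run on its own with the standard library's unittest runner. The module path tests.test_pipelines is an assumption; the diff view does not show the file name.

# Usage sketch only. Assumes the diff above applies to tests/test_pipelines.py
# (hypothetical path; the commit view omits it).
import unittest

# Load and run just the fast, dummy-model pipeline tests added by this commit.
suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_pipelines.PipelineFastTests")
unittest.TextTestRunner(verbosity=2).run(suite)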