fix test_components (#928)

This commit is contained in:
Suraj Patil 2022-10-20 16:25:12 +02:00 committed by GitHub
parent 4bf675f465
commit 8be48507a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 4 deletions

View File

@ -1405,7 +1405,7 @@ class PipelineFastTests(unittest.TestCase):
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk # make sure here that pndm scheduler skips prk
inpaint = StableDiffusionInpaintPipeline( inpaint = StableDiffusionInpaintPipelineLegacy(
unet=unet, unet=unet,
scheduler=scheduler, scheduler=scheduler,
vae=vae, vae=vae,
@ -1413,9 +1413,9 @@ class PipelineFastTests(unittest.TestCase):
tokenizer=tokenizer, tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker, safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor, feature_extractor=self.dummy_extractor,
) ).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components) img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components) text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)