fix test_components (#928)
This commit is contained in:
parent
4bf675f465
commit
8be48507a0
|
@ -1405,7 +1405,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||||
|
|
||||||
# make sure here that pndm scheduler skips prk
|
# make sure here that pndm scheduler skips prk
|
||||||
inpaint = StableDiffusionInpaintPipeline(
|
inpaint = StableDiffusionInpaintPipelineLegacy(
|
||||||
unet=unet,
|
unet=unet,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
|
@ -1413,9 +1413,9 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
safety_checker=self.dummy_safety_checker,
|
safety_checker=self.dummy_safety_checker,
|
||||||
feature_extractor=self.dummy_extractor,
|
feature_extractor=self.dummy_extractor,
|
||||||
)
|
).to(torch_device)
|
||||||
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components)
|
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
|
||||||
text2img = StableDiffusionPipeline(**inpaint.components)
|
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
|
|
Loading…
Reference in New Issue