fix test_components (#928)
This commit is contained in:
parent
4bf675f465
commit
8be48507a0
|
@ -1405,7 +1405,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
inpaint = StableDiffusionInpaintPipeline(
|
||||
inpaint = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
|
@ -1413,9 +1413,9 @@ class PipelineFastTests(unittest.TestCase):
|
|||
tokenizer=tokenizer,
|
||||
safety_checker=self.dummy_safety_checker,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components)
|
||||
text2img = StableDiffusionPipeline(**inpaint.components)
|
||||
).to(torch_device)
|
||||
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
|
||||
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
|
Loading…
Reference in New Issue