Remove the last of ["sample"] (#842)
This commit is contained in:
parent
52394b53e2
commit
1d3234cbca
|
@ -1693,9 +1693,9 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
|
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
|
ddim_images = ddim(
|
||||||
"sample"
|
batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
|
||||||
]
|
).images
|
||||||
|
|
||||||
# the values aren't exactly equal, but the images look the same visually
|
# the values aren't exactly equal, but the images look the same visually
|
||||||
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
||||||
|
@ -1729,9 +1729,9 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "a photograph of an astronaut riding a horse"
|
prompt = "a photograph of an astronaut riding a horse"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
|
image = pipe(
|
||||||
"sample"
|
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||||
]
|
).images
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
|
|
Loading…
Reference in New Issue