Remove the last of ["sample"] (#842)

This commit is contained in:
Anton Lozhkov 2022-10-14 14:45:43 +02:00 committed by GitHub
parent 52394b53e2
commit 1d3234cbca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 6 deletions

View File

@ -1693,9 +1693,9 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
"sample"
]
ddim_images = ddim(
batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
).images
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
@ -1729,9 +1729,9 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "a photograph of an astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
"sample"
]
image = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)