Reproducibility 3/3 (#1924)
* make tests deterministic * run slow tests * prepare for testing * finish * refactor * add print statements * finish more * correct some test failures * more fixes * set up to correct tests * more corrections * up * fix more * more prints * add * up * up * up * uP * uP * more fixes * uP * up * up * up * up * fix more * up * up * clean tests * up * up * up * more fixes * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * make * correct * finish * finish Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent
008c22d334
commit
6ba2231d72
|
@ -32,6 +32,8 @@
|
||||||
title: Text-Guided Depth-to-Image
|
title: Text-Guided Depth-to-Image
|
||||||
- local: using-diffusers/reusing_seeds
|
- local: using-diffusers/reusing_seeds
|
||||||
title: Reusing seeds for deterministic generation
|
title: Reusing seeds for deterministic generation
|
||||||
|
- local: using-diffusers/reproducibility
|
||||||
|
title: Reproducibility
|
||||||
- local: using-diffusers/custom_pipeline_examples
|
- local: using-diffusers/custom_pipeline_examples
|
||||||
title: Community Pipelines
|
title: Community Pipelines
|
||||||
- local: using-diffusers/contribute_pipeline
|
- local: using-diffusers/contribute_pipeline
|
||||||
|
|
|
@ -0,0 +1,159 @@
|
||||||
|
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
|
||||||
|
Before reading about reproducibility for Diffusers, it is strongly recommended to take a look at
|
||||||
|
[PyTorch's statement about reproducibility](https://pytorch.org/docs/stable/notes/randomness.html).
|
||||||
|
|
||||||
|
PyTorch states that
|
||||||
|
> *completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms.*
|
||||||
|
While one can never expect the same results across platforms, one can expect results to be reproducible
|
||||||
|
across releases, platforms, etc... within a certain tolerance. However, this tolerance strongly varies
|
||||||
|
depending on the diffusion pipeline and checkpoint.
|
||||||
|
|
||||||
|
In the following, we show how to best control sources of randomness for diffusion models.
|
||||||
|
|
||||||
|
## Inference
|
||||||
|
|
||||||
|
During inference, diffusion pipelines heavily rely on random sampling operations, such as the creating the
|
||||||
|
gaussian noise tensors to be denoised and adding noise to the scheduling step.
|
||||||
|
|
||||||
|
Let's have a look at an example. We run the [DDIM pipeline](./api/pipelines/ddim.mdx)
|
||||||
|
for just two inference steps and return a numpy tensor to look into the numerical values of the output.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import DDIMPipeline
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
|
# load model and scheduler
|
||||||
|
ddim = DDIMPipeline.from_pretrained(model_id)
|
||||||
|
|
||||||
|
# run pipeline for just two steps and return numpy tensor
|
||||||
|
image = ddim(num_inference_steps=2, output_type="np").images
|
||||||
|
print(np.abs(image).sum())
|
||||||
|
```
|
||||||
|
|
||||||
|
Running the above prints a value of 1464.2076, but running it again prints a different
|
||||||
|
value of 1495.1768. What is going on here? Every time the pipeline is run, gaussian noise
|
||||||
|
is created and step-wise denoised. To create the gaussian noise with [`torch.randn`](https://pytorch.org/docs/stable/generated/torch.randn.html), a different random seed is taken every time, thus leading to a different result.
|
||||||
|
This is a desired property of diffusion pipelines, as it means that the pipeline can create a different random image every time it is run. In many cases, one would like to generate the exact same image of a certain
|
||||||
|
run, for which case an instance of a [PyTorch generator](https://pytorch.org/docs/stable/generated/torch.randn.html) has to be passed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import DDIMPipeline
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
|
# load model and scheduler
|
||||||
|
ddim = DDIMPipeline.from_pretrained(model_id)
|
||||||
|
|
||||||
|
# create a generator for reproducibility
|
||||||
|
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||||
|
|
||||||
|
# run pipeline for just two steps and return numpy tensor
|
||||||
|
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
|
||||||
|
print(np.abs(image).sum())
|
||||||
|
```
|
||||||
|
|
||||||
|
Running the above always prints a value of 1491.1711 - also upon running it again because we
|
||||||
|
define the generator object to be passed to all random functions of the pipeline.
|
||||||
|
|
||||||
|
If you run this code snippet on your specific hardware and version, you should get a similar, if not the same, result.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
It might be a bit unintuitive at first to pass `generator` objects to the pipelines instead of
|
||||||
|
just integer values representing the seed, but this is the recommended design when dealing with
|
||||||
|
probabilistic models in PyTorch as generators are *random states* that are advanced and can thus be
|
||||||
|
passed to multiple pipelines in a sequence.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Great! Now, we know how to write reproducible pipelines, but it gets a bit trickier since the above example only runs on the CPU. How do we also achieve reproducibility on GPU?
|
||||||
|
In short, one should not expect full reproducibility across different hardware when running pipelines on GPU
|
||||||
|
as matrix multiplications are less deterministic on GPU than on CPU and diffusion pipelines tend to require
|
||||||
|
a lot of matrix multiplications. Let's see what we can do to keep the randomness within limits across
|
||||||
|
different GPU hardware.
|
||||||
|
|
||||||
|
To achieve maximum speed performance, it is recommended to create the generator directly on GPU when running
|
||||||
|
the pipeline on GPU:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import DDIMPipeline
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
|
# load model and scheduler
|
||||||
|
ddim = DDIMPipeline.from_pretrained(model_id)
|
||||||
|
ddim.to("cuda")
|
||||||
|
|
||||||
|
# create a generator for reproducibility
|
||||||
|
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||||
|
|
||||||
|
# run pipeline for just two steps and return numpy tensor
|
||||||
|
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
|
||||||
|
print(np.abs(image).sum())
|
||||||
|
```
|
||||||
|
|
||||||
|
Running the above now prints a value of 1389.8634 - even though we're using the exact same seed!
|
||||||
|
This is unfortunate as it means we cannot reproduce the results we achieved on GPU, also on CPU.
|
||||||
|
Nevertheless, it should be expected since the GPU uses a different random number generator than the CPU.
|
||||||
|
|
||||||
|
To circumvent this problem, we created a [`randn_tensor`](#diffusers.utils.randn_tensor) function, which can create random noise
|
||||||
|
on the CPU and then move the tensor to GPU if necessary. The function is used everywhere inside the pipelines allowing the user to **always** pass a CPU generator even if the pipeline is run on GPU:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import DDIMPipeline
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
|
# load model and scheduler
|
||||||
|
ddim = DDIMPipeline.from_pretrained(model_id)
|
||||||
|
ddim.to("cuda")
|
||||||
|
|
||||||
|
# create a generator for reproducibility
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
|
||||||
|
# run pipeline for just two steps and return numpy tensor
|
||||||
|
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
|
||||||
|
print(np.abs(image).sum())
|
||||||
|
```
|
||||||
|
|
||||||
|
Running the above now prints a value of 1491.1713, much closer to the value of 1491.1711 when
|
||||||
|
the pipeline is fully run on the CPU.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
As a consequence, we recommend always passing a CPU generator if Reproducibility is important.
|
||||||
|
The loss of performance is often neglectable, but one can be sure to generate much more similar
|
||||||
|
values than if the pipeline would have been run on CPU.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Finally, we noticed that more complex pipelines, such as [`UnCLIPPipeline`] are often extremely
|
||||||
|
susceptible to precision error propagation and thus one cannot expect even similar results across
|
||||||
|
different GPU hardware or PyTorch versions. In such cases, one has to make sure to run
|
||||||
|
exactly the same hardware and PyTorch version for full Reproducibility.
|
||||||
|
|
||||||
|
## Randomness utilities
|
||||||
|
|
||||||
|
### randn_tensor
|
||||||
|
[[autodoc]] diffusers.utils.randn_tensor
|
|
@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...schedulers import DDIMScheduler
|
from ...schedulers import DDIMScheduler
|
||||||
from ...utils import deprecate, randn_tensor
|
from ...utils import randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,24 +78,6 @@ class DDIMPipeline(DiffusionPipeline):
|
||||||
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (
|
|
||||||
generator is not None
|
|
||||||
and isinstance(generator, torch.Generator)
|
|
||||||
and generator.device.type != self.device.type
|
|
||||||
and self.device.type != "mps"
|
|
||||||
):
|
|
||||||
message = (
|
|
||||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
|
||||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
|
||||||
f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
|
|
||||||
)
|
|
||||||
deprecate(
|
|
||||||
"generator.device == 'cpu'",
|
|
||||||
"0.13.0",
|
|
||||||
message,
|
|
||||||
)
|
|
||||||
generator = None
|
|
||||||
|
|
||||||
# Sample gaussian noise to begin loop
|
# Sample gaussian noise to begin loop
|
||||||
if isinstance(self.unet.sample_size, int):
|
if isinstance(self.unet.sample_size, int):
|
||||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||||
|
|
|
@ -76,6 +76,7 @@ if is_torch_available():
|
||||||
load_numpy,
|
load_numpy,
|
||||||
nightly,
|
nightly,
|
||||||
parse_flag_from_env,
|
parse_flag_from_env,
|
||||||
|
print_tensor_test,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_all_close,
|
torch_all_close,
|
||||||
|
|
|
@ -8,7 +8,7 @@ import urllib.parse
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
from io import BytesIO, StringIO
|
from io import BytesIO, StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -45,6 +45,21 @@ def torch_all_close(a, b, *args, **kwargs):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"):
|
||||||
|
test_name = os.environ.get("PYTEST_CURRENT_TEST")
|
||||||
|
if not torch.is_tensor(tensor):
|
||||||
|
tensor = torch.from_numpy(tensor)
|
||||||
|
|
||||||
|
tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
|
||||||
|
# format is usually:
|
||||||
|
# expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
|
||||||
|
output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
|
||||||
|
test_file, test_class, test_fn = test_name.split("::")
|
||||||
|
test_fn = test_fn.split()[0]
|
||||||
|
with open(filename, "a") as f:
|
||||||
|
print(";".join([test_file, test_class, test_fn, output_str]), file=f)
|
||||||
|
|
||||||
|
|
||||||
def get_tests_dir(append_path=None):
|
def get_tests_dir(append_path=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -150,9 +165,13 @@ def require_onnxruntime(test_case):
|
||||||
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
|
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray:
|
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
|
||||||
if isinstance(arry, str):
|
if isinstance(arry, str):
|
||||||
if arry.startswith("http://") or arry.startswith("https://"):
|
# local_path = "/home/patrick_huggingface_co/"
|
||||||
|
if local_path is not None:
|
||||||
|
# local_path can be passed to correct images of tests
|
||||||
|
return os.path.join(local_path, "/".join([arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]]))
|
||||||
|
elif arry.startswith("http://") or arry.startswith("https://"):
|
||||||
response = requests.get(arry)
|
response = requests.get(arry)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
arry = np.load(BytesIO(response.content))
|
arry = np.load(BytesIO(response.content))
|
||||||
|
|
|
@ -166,7 +166,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
def get_generator(self, seed=0):
|
def get_generator(self, seed=0):
|
||||||
if torch_device == "mps":
|
if torch_device == "mps":
|
||||||
return torch.Generator().manual_seed(seed)
|
return torch.manual_seed(seed)
|
||||||
return torch.Generator(device=torch_device).manual_seed(seed)
|
return torch.Generator(device=torch_device).manual_seed(seed)
|
||||||
|
|
||||||
@parameterized.expand(
|
@parameterized.expand(
|
||||||
|
|
|
@ -188,6 +188,7 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
expected_slice = np.array(
|
expected_slice = np.array(
|
||||||
[0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
|
[0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
|
||||||
|
@ -207,20 +208,16 @@ class AltDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
alt_pipe.set_progress_bar_config(disable=None)
|
alt_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
with torch.autocast("cuda"):
|
output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")
|
||||||
output = alt_pipe(
|
|
||||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
|
|
||||||
)
|
|
||||||
|
|
||||||
image = output.images
|
image = output.images
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array(
|
expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])
|
||||||
[0.8720703, 0.87109375, 0.87402344, 0.87109375, 0.8779297, 0.8925781, 0.8823242, 0.8808594, 0.8613281]
|
|
||||||
)
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_alt_diffusion_fast_ddim(self):
|
def test_alt_diffusion_fast_ddim(self):
|
||||||
|
@ -231,44 +228,14 @@ class AltDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
alt_pipe.set_progress_bar_config(disable=None)
|
alt_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
|
||||||
output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
|
|
||||||
image = output.images
|
image = output.images
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array(
|
expected_slice = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])
|
||||||
[0.9267578, 0.9301758, 0.9013672, 0.9345703, 0.92578125, 0.94433594, 0.9423828, 0.9423828, 0.9160156]
|
|
||||||
)
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_alt_diffusion_text2img_pipeline_fp16(self):
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
model_id = "BAAI/AltDiffusion"
|
|
||||||
pipe = AltDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
|
|
||||||
pipe = pipe.to(torch_device)
|
|
||||||
pipe.set_progress_bar_config(disable=None)
|
|
||||||
|
|
||||||
prompt = "a photograph of an astronaut riding a horse"
|
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
output_chunked = pipe(
|
|
||||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
|
||||||
)
|
|
||||||
image_chunked = output_chunked.images
|
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
with torch.autocast(torch_device):
|
|
||||||
output = pipe(
|
|
||||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
|
||||||
)
|
|
||||||
image = output.images
|
|
||||||
|
|
||||||
# Make sure results are close enough
|
|
||||||
diff = np.abs(image_chunked.flatten() - image.flatten())
|
|
||||||
# They ARE different since ops are not run always at the same precision
|
|
||||||
# however, they should be extremely close.
|
|
||||||
assert diff.mean() < 2e-2
|
|
||||||
|
|
|
@ -162,6 +162,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
|
||||||
expected_slice = np.array(
|
expected_slice = np.array(
|
||||||
[0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448]
|
[0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3
|
||||||
|
|
||||||
|
@ -196,7 +197,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
|
||||||
alt_pipe.set_progress_bar_config(disable=None)
|
alt_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = alt_pipe(
|
image = alt_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -227,7 +228,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "A fantasy landscape, trending on artstation"
|
prompt = "A fantasy landscape, trending on artstation"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=init_image,
|
image=init_image,
|
||||||
|
@ -241,7 +242,8 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
|
||||||
image_slice = image[255:258, 383:386, -1]
|
image_slice = image[255:258, 383:386, -1]
|
||||||
|
|
||||||
assert image.shape == (504, 760, 3)
|
assert image.shape == (504, 760, 3)
|
||||||
expected_slice = np.array([0.3252, 0.3340, 0.3418, 0.3263, 0.3346, 0.3300, 0.3163, 0.3470, 0.3427])
|
expected_slice = np.array([0.9358, 0.9397, 0.9599, 0.9901, 1.0000, 1.0000, 0.9882, 1.0000, 1.0000])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
@ -275,7 +277,7 @@ class AltDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "A fantasy landscape, trending on artstation"
|
prompt = "A fantasy landscape, trending on artstation"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=init_image,
|
image=init_image,
|
||||||
|
|
|
@ -119,6 +119,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
||||||
image_from_tuple_slice = np.frombuffer(image_from_tuple.tobytes(), dtype="uint8")[:10]
|
image_from_tuple_slice = np.frombuffer(image_from_tuple.tobytes(), dtype="uint8")[:10]
|
||||||
expected_slice = np.array([255, 255, 255, 0, 181, 0, 124, 0, 15, 255])
|
expected_slice = np.array([255, 255, 255, 0, 181, 0, 124, 0, 15, 255])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() == 0
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() == 0
|
||||||
|
|
||||||
|
@ -142,6 +143,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
||||||
expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
|
expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
||||||
|
|
||||||
dummy_unet_condition = self.dummy_unet_condition
|
dummy_unet_condition = self.dummy_unet_condition
|
||||||
|
@ -155,6 +157,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
image = output.images[0]
|
image = output.images[0]
|
||||||
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
||||||
expected_slice = np.array([120, 139, 147, 123, 124, 96, 115, 121, 126, 144])
|
expected_slice = np.array([120, 139, 147, 123, 124, 96, 115, 121, 126, 144])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -183,4 +186,5 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||||
assert image.height == pipe.unet.sample_size[0] and image.width == pipe.unet.sample_size[1]
|
assert image.height == pipe.unet.sample_size[0] and image.width == pipe.unet.sample_size[1]
|
||||||
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
|
||||||
expected_slice = np.array([151, 167, 154, 144, 122, 134, 121, 105, 70, 26])
|
expected_slice = np.array([151, 167, 154, 144, 122, 134, 121, 105, 70, 26])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
assert np.abs(image_slice.flatten() - expected_slice).max() == 0
|
||||||
|
|
|
@ -104,14 +104,15 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
|
output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
|
||||||
audio = output.audios
|
audio = output.audios
|
||||||
|
|
||||||
audio_slice = audio[0, -3:, -3:]
|
audio_slice = audio[0, -3:, -3:]
|
||||||
|
|
||||||
assert audio.shape == (1, 2, pipe.unet.sample_size)
|
assert audio.shape == (1, 2, pipe.unet.sample_size)
|
||||||
expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487])
|
expected_slice = np.array([-0.0192, -0.0231, -0.0318, -0.0059, 0.0002, -0.0020])
|
||||||
|
|
||||||
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_dance_diffusion_fp16(self):
|
def test_dance_diffusion_fp16(self):
|
||||||
|
@ -121,12 +122,13 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.Generator(device=device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
|
output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
|
||||||
audio = output.audios
|
audio = output.audios
|
||||||
|
|
||||||
audio_slice = audio[0, -3:, -3:]
|
audio_slice = audio[0, -3:, -3:]
|
||||||
|
|
||||||
assert audio.shape == (1, 2, pipe.unet.sample_size)
|
assert audio.shape == (1, 2, pipe.unet.sample_size)
|
||||||
expected_slice = np.array([-0.1693, -0.1698, -0.1447, -0.3044, -0.3203, -0.2937])
|
expected_slice = np.array([-0.0367, -0.0488, -0.0771, -0.0525, -0.0444, -0.0341])
|
||||||
|
|
||||||
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -82,25 +82,6 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class DDIMPipelineIntegrationTests(unittest.TestCase):
|
class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_ema_bedroom(self):
|
|
||||||
model_id = "google/ddpm-ema-bedroom-256"
|
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id)
|
|
||||||
scheduler = DDIMScheduler.from_pretrained(model_id)
|
|
||||||
|
|
||||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
|
||||||
ddpm.to(torch_device)
|
|
||||||
ddpm.set_progress_bar_config(disable=None)
|
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
image = ddpm(generator=generator, output_type="numpy").images
|
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
|
||||||
|
|
||||||
assert image.shape == (1, 256, 256, 3)
|
|
||||||
expected_slice = np.array([0.1546, 0.1561, 0.1595, 0.1564, 0.1569, 0.1585, 0.1554, 0.1550, 0.1575])
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
|
|
||||||
def test_inference_cifar10(self):
|
def test_inference_cifar10(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
|
@ -111,11 +92,32 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||||
ddim.to(torch_device)
|
ddim.to(torch_device)
|
||||||
ddim.set_progress_bar_config(disable=None)
|
ddim.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
|
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.2060, 0.2042, 0.2022, 0.2193, 0.2146, 0.2110, 0.2471, 0.2446, 0.2388])
|
expected_slice = np.array([0.1723, 0.1617, 0.1600, 0.1626, 0.1497, 0.1513, 0.1505, 0.1442, 0.1453])
|
||||||
|
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
def test_inference_ema_bedroom(self):
|
||||||
|
model_id = "google/ddpm-ema-bedroom-256"
|
||||||
|
|
||||||
|
unet = UNet2DModel.from_pretrained(model_id)
|
||||||
|
scheduler = DDIMScheduler.from_pretrained(model_id)
|
||||||
|
|
||||||
|
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
ddpm.to(torch_device)
|
||||||
|
ddpm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
image = ddpm(generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert image.shape == (1, 256, 256, 3)
|
||||||
|
expected_slice = np.array([0.0060, 0.0201, 0.0344, 0.0024, 0.0018, 0.0002, 0.0022, 0.0000, 0.0069])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -63,6 +63,7 @@ class DDPMPipelineFastTests(unittest.TestCase):
|
||||||
expected_slice = np.array(
|
expected_slice = np.array(
|
||||||
[5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
|
[5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
@ -79,14 +80,10 @@ class DDPMPipelineFastTests(unittest.TestCase):
|
||||||
if torch_device == "mps":
|
if torch_device == "mps":
|
||||||
_ = ddpm(num_inference_steps=1)
|
_ = ddpm(num_inference_steps=1)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||||
|
|
||||||
generator = generator.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
|
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
@ -108,14 +105,10 @@ class DDPMPipelineFastTests(unittest.TestCase):
|
||||||
if torch_device == "mps":
|
if torch_device == "mps":
|
||||||
_ = ddpm(num_inference_steps=1)
|
_ = ddpm(num_inference_steps=1)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||||
|
|
||||||
generator = generator.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0]
|
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0]
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
@ -139,11 +132,12 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||||
ddpm.to(torch_device)
|
ddpm.to(torch_device)
|
||||||
ddpm.set_progress_bar_config(disable=None)
|
ddpm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = ddpm(generator=generator, output_type="numpy").images
|
image = ddpm(generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
|
expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -114,15 +114,14 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
|
||||||
assert np.abs((expected_image - image).max()) < 1e-3
|
assert np.abs((expected_image - image).max()) < 1e-3
|
||||||
|
|
||||||
def test_dit_512_fp16(self):
|
def test_dit_512_fp16(self):
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
|
|
||||||
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16)
|
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16)
|
||||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||||
pipe.to("cuda")
|
pipe.to("cuda")
|
||||||
|
|
||||||
words = ["vase", "umbrella", "white shark", "white wolf"]
|
words = ["vase", "umbrella"]
|
||||||
ids = pipe.get_label_ids(words)
|
ids = pipe.get_label_ids(words)
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images
|
images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images
|
||||||
|
|
||||||
for word, image in zip(words, images):
|
for word, image in zip(words, images):
|
||||||
|
@ -130,4 +129,5 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
|
||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||||
f"/dit/{word}_fp16.npy"
|
f"/dit/{word}_fp16.npy"
|
||||||
)
|
)
|
||||||
assert np.abs((expected_image - image).max()) < 1e-2
|
|
||||||
|
assert np.abs((expected_image - image).max()) < 7.5e-1
|
||||||
|
|
|
@ -59,6 +59,7 @@ class KarrasVePipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
@ -81,4 +82,5 @@ class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
assert image.shape == (1, 256, 256, 3)
|
assert image.shape == (1, 256, 256, 3)
|
||||||
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
|
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -126,7 +126,7 @@ class LDMTextToImagePipelineSlowTests(unittest.TestCase):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
|
latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
|
||||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||||
inputs = {
|
inputs = {
|
||||||
|
@ -162,7 +162,7 @@ class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
|
latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
|
||||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||||
inputs = {
|
inputs = {
|
||||||
|
|
|
@ -83,6 +83,7 @@ class LDMSuperResolutionPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
|
expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||||
|
@ -101,8 +102,7 @@ class LDMSuperResolutionPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
init_image = self.dummy_image.to(torch_device)
|
init_image = self.dummy_image.to(torch_device)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
image = ldm(init_image, num_inference_steps=2, output_type="numpy").images
|
||||||
image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images
|
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
|
|
||||||
|
@ -121,11 +121,12 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
|
||||||
ldm.to(torch_device)
|
ldm.to(torch_device)
|
||||||
ldm.set_progress_bar_config(disable=None)
|
ldm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = ldm(image=init_image, generator=generator, num_inference_steps=20, output_type="numpy").images
|
image = ldm(image=init_image, generator=generator, num_inference_steps=20, output_type="numpy").images
|
||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 256, 256, 3)
|
assert image.shape == (1, 256, 256, 3)
|
||||||
expected_slice = np.array([0.7418, 0.7472, 0.7424, 0.7422, 0.7463, 0.726, 0.7382, 0.7248, 0.6828])
|
expected_slice = np.array([0.7644, 0.7679, 0.7642, 0.7633, 0.7666, 0.7560, 0.7425, 0.7257, 0.6907])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -96,6 +96,7 @@ class LDMPipelineFastTests(unittest.TestCase):
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
|
expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
|
||||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
|
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
|
||||||
|
|
||||||
|
@ -116,4 +117,5 @@ class LDMPipelineIntegrationTests(unittest.TestCase):
|
||||||
assert image.shape == (1, 256, 256, 3)
|
assert image.shape == (1, 256, 256, 3)
|
||||||
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
|
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
|
||||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
|
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
|
||||||
|
|
|
@ -205,7 +205,7 @@ class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(321)
|
generator = torch.manual_seed(321)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
image=init_image,
|
image=init_image,
|
||||||
mask_image=mask_image,
|
mask_image=mask_image,
|
||||||
|
@ -221,7 +221,6 @@ class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array(
|
expected_slice = np.array([0.4834, 0.4811, 0.4874, 0.5122, 0.5081, 0.5144, 0.5291, 0.5290, 0.5374])
|
||||||
[0.47455794, 0.47086594, 0.47683704, 0.51024145, 0.5064255, 0.5123164, 0.532502, 0.5328063, 0.5428694]
|
|
||||||
)
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -59,6 +59,7 @@ class PNDMPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
@ -82,4 +83,5 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
|
expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -81,6 +81,7 @@ class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
|
expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,7 +114,7 @@ class RepaintPipelineNightlyTests(unittest.TestCase):
|
||||||
repaint.set_progress_bar_config(disable=None)
|
repaint.set_progress_bar_config(disable=None)
|
||||||
repaint.enable_attention_slicing()
|
repaint.enable_attention_slicing()
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = repaint(
|
output = repaint(
|
||||||
original_image,
|
original_image,
|
||||||
mask_image,
|
mask_image,
|
||||||
|
|
|
@ -61,6 +61,7 @@ class ScoreSdeVeipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
@ -86,4 +87,5 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
||||||
assert image.shape == (1, 256, 256, 3)
|
assert image.shape == (1, 256, 256, 3)
|
||||||
|
|
||||||
expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
|
expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -182,7 +182,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
source_prompt = "A black colored car"
|
source_prompt = "A black colored car"
|
||||||
prompt = "A blue colored car"
|
prompt = "A blue colored car"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
source_prompt=source_prompt,
|
source_prompt=source_prompt,
|
||||||
|
@ -221,7 +221,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
source_prompt = "A black colored car"
|
source_prompt = "A black colored car"
|
||||||
prompt = "A blue colored car"
|
prompt = "A blue colored car"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
source_prompt=source_prompt,
|
source_prompt=source_prompt,
|
||||||
|
|
|
@ -60,6 +60,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.65072, 0.58492, 0.48219, 0.55521, 0.53180, 0.55939, 0.50697, 0.39800, 0.46455])
|
expected_slice = np.array([0.65072, 0.58492, 0.48219, 0.55521, 0.53180, 0.55939, 0.50697, 0.39800, 0.46455])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_pipeline_pndm(self):
|
def test_pipeline_pndm(self):
|
||||||
|
@ -73,6 +74,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.65863, 0.59425, 0.49326, 0.56313, 0.53875, 0.56627, 0.51065, 0.39777, 0.46330])
|
expected_slice = np.array([0.65863, 0.59425, 0.49326, 0.56313, 0.53875, 0.56627, 0.51065, 0.39777, 0.46330])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_pipeline_lms(self):
|
def test_pipeline_lms(self):
|
||||||
|
@ -86,6 +88,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.53755, 0.60786, 0.47402, 0.49488, 0.51869, 0.49819, 0.47985, 0.38957, 0.44279])
|
expected_slice = np.array([0.53755, 0.60786, 0.47402, 0.49488, 0.51869, 0.49819, 0.47985, 0.38957, 0.44279])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_pipeline_euler(self):
|
def test_pipeline_euler(self):
|
||||||
|
@ -99,6 +102,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.53755, 0.60786, 0.47402, 0.49488, 0.51869, 0.49819, 0.47985, 0.38957, 0.44279])
|
expected_slice = np.array([0.53755, 0.60786, 0.47402, 0.49488, 0.51869, 0.49819, 0.47985, 0.38957, 0.44279])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_pipeline_euler_ancestral(self):
|
def test_pipeline_euler_ancestral(self):
|
||||||
|
@ -112,6 +116,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.53817, 0.60812, 0.47384, 0.49530, 0.51894, 0.49814, 0.47984, 0.38958, 0.44271])
|
expected_slice = np.array([0.53817, 0.60812, 0.47384, 0.49530, 0.51894, 0.49814, 0.47984, 0.38958, 0.44271])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_pipeline_dpm_multistep(self):
|
def test_pipeline_dpm_multistep(self):
|
||||||
|
@ -125,6 +130,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.53895, 0.60808, 0.47933, 0.49608, 0.51886, 0.49950, 0.48053, 0.38957, 0.44200])
|
expected_slice = np.array([0.53895, 0.60808, 0.47933, 0.49608, 0.51886, 0.49950, 0.48053, 0.38957, 0.44200])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
|
||||||
|
@ -169,6 +175,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.0452, 0.0390, 0.0087, 0.0350, 0.0617, 0.0364, 0.0544, 0.0523, 0.0720])
|
expected_slice = np.array([0.0452, 0.0390, 0.0087, 0.0350, 0.0617, 0.0364, 0.0544, 0.0523, 0.0720])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_inference_ddim(self):
|
def test_inference_ddim(self):
|
||||||
|
@ -194,6 +201,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.2867, 0.1974, 0.1481, 0.7294, 0.7251, 0.6667, 0.4194, 0.5642, 0.6486])
|
expected_slice = np.array([0.2867, 0.1974, 0.1481, 0.7294, 0.7251, 0.6667, 0.4194, 0.5642, 0.6486])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_inference_k_lms(self):
|
def test_inference_k_lms(self):
|
||||||
|
@ -219,6 +227,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.2306, 0.1959, 0.1593, 0.6549, 0.6394, 0.5408, 0.5065, 0.6010, 0.6161])
|
expected_slice = np.array([0.2306, 0.1959, 0.1593, 0.6549, 0.6394, 0.5408, 0.5065, 0.6010, 0.6161])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_intermediate_state(self):
|
def test_intermediate_state(self):
|
||||||
|
@ -234,6 +243,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_slice = np.array(
|
expected_slice = np.array(
|
||||||
[-0.6772, -0.3835, -1.2456, 0.1905, -1.0974, 0.6967, -1.9353, 0.0178, 1.0167]
|
[-0.6772, -0.3835, -1.2456, 0.1905, -1.0974, 0.6967, -1.9353, 0.0178, 1.0167]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
elif step == 5:
|
elif step == 5:
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
|
@ -241,6 +251,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_slice = np.array(
|
expected_slice = np.array(
|
||||||
[-0.3351, 0.2241, -0.1837, -0.2325, -0.6577, 0.3393, -0.0241, 0.5899, 1.3875]
|
[-0.3351, 0.2241, -0.1837, -0.2325, -0.6577, 0.3393, -0.0241, 0.5899, 1.3875]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
test_callback_fn.has_been_called = False
|
test_callback_fn.has_been_called = False
|
||||||
|
|
|
@ -82,6 +82,7 @@ class OnnxStableDiffusionImg2ImgPipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.61710, 0.53390, 0.49310, 0.55622, 0.50982, 0.58240, 0.50716, 0.38629, 0.46856])
|
expected_slice = np.array([0.61710, 0.53390, 0.49310, 0.55622, 0.50982, 0.58240, 0.50716, 0.38629, 0.46856])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
||||||
def test_pipeline_lms(self):
|
def test_pipeline_lms(self):
|
||||||
|
@ -98,6 +99,7 @@ class OnnxStableDiffusionImg2ImgPipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.52761, 0.59977, 0.49033, 0.49619, 0.54282, 0.50311, 0.47600, 0.40918, 0.45203])
|
expected_slice = np.array([0.52761, 0.59977, 0.49033, 0.49619, 0.54282, 0.50311, 0.47600, 0.40918, 0.45203])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
||||||
def test_pipeline_euler(self):
|
def test_pipeline_euler(self):
|
||||||
|
@ -111,6 +113,7 @@ class OnnxStableDiffusionImg2ImgPipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.52911, 0.60004, 0.49229, 0.49805, 0.54502, 0.50680, 0.47777, 0.41028, 0.45304])
|
expected_slice = np.array([0.52911, 0.60004, 0.49229, 0.49805, 0.54502, 0.50680, 0.47777, 0.41028, 0.45304])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
||||||
def test_pipeline_euler_ancestral(self):
|
def test_pipeline_euler_ancestral(self):
|
||||||
|
@ -124,6 +127,7 @@ class OnnxStableDiffusionImg2ImgPipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.52911, 0.60004, 0.49229, 0.49805, 0.54502, 0.50680, 0.47777, 0.41028, 0.45304])
|
expected_slice = np.array([0.52911, 0.60004, 0.49229, 0.49805, 0.54502, 0.50680, 0.47777, 0.41028, 0.45304])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
||||||
def test_pipeline_dpm_multistep(self):
|
def test_pipeline_dpm_multistep(self):
|
||||||
|
@ -137,6 +141,7 @@ class OnnxStableDiffusionImg2ImgPipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (1, 128, 128, 3)
|
assert image.shape == (1, 128, 128, 3)
|
||||||
expected_slice = np.array([0.65331, 0.58277, 0.48204, 0.56059, 0.53665, 0.56235, 0.50969, 0.40009, 0.46552])
|
expected_slice = np.array([0.65331, 0.58277, 0.48204, 0.56059, 0.53665, 0.56235, 0.50969, 0.40009, 0.46552])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
||||||
|
|
||||||
|
@ -195,6 +200,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
assert images.shape == (1, 512, 768, 3)
|
assert images.shape == (1, 512, 768, 3)
|
||||||
expected_slice = np.array([0.4909, 0.5059, 0.5372, 0.4623, 0.4876, 0.5049, 0.4820, 0.4956, 0.5019])
|
expected_slice = np.array([0.4909, 0.5059, 0.5372, 0.4623, 0.4876, 0.5049, 0.4820, 0.4956, 0.5019])
|
||||||
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
||||||
|
|
||||||
def test_inference_k_lms(self):
|
def test_inference_k_lms(self):
|
||||||
|
@ -235,4 +241,5 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
assert images.shape == (1, 512, 768, 3)
|
assert images.shape == (1, 512, 768, 3)
|
||||||
expected_slice = np.array([0.8043, 0.926, 0.9581, 0.8119, 0.8954, 0.913, 0.7209, 0.7463, 0.7431])
|
expected_slice = np.array([0.8043, 0.926, 0.9581, 0.8119, 0.8954, 0.913, 0.7209, 0.7463, 0.7431])
|
||||||
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
||||||
|
|
|
@ -94,6 +94,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
assert images.shape == (1, 512, 512, 3)
|
assert images.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.2514, 0.3007, 0.3517, 0.1790, 0.2382, 0.3167, 0.1944, 0.2273, 0.2464])
|
expected_slice = np.array([0.2514, 0.3007, 0.3517, 0.1790, 0.2382, 0.3167, 0.1944, 0.2273, 0.2464])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_inference_k_lms(self):
|
def test_inference_k_lms(self):
|
||||||
|
@ -136,4 +137,5 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
assert images.shape == (1, 512, 512, 3)
|
assert images.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.0086, 0.0077, 0.0083, 0.0093, 0.0107, 0.0139, 0.0094, 0.0097, 0.0125])
|
expected_slice = np.array([0.0086, 0.0077, 0.0083, 0.0093, 0.0107, 0.0139, 0.0094, 0.0097, 0.0125])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
|
@ -244,6 +244,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.5094, 0.5674, 0.4667, 0.5125, 0.5696, 0.4674, 0.5277, 0.4964, 0.4945])
|
expected_slice = np.array([0.5094, 0.5674, 0.4667, 0.5125, 0.5696, 0.4674, 0.5277, 0.4964, 0.4945])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_no_safety_checker(self):
|
def test_stable_diffusion_no_safety_checker(self):
|
||||||
|
@ -295,6 +296,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
0.5042197108268738,
|
0.5042197108268738,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_k_euler_ancestral(self):
|
def test_stable_diffusion_k_euler_ancestral(self):
|
||||||
|
@ -325,6 +327,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
0.504422664642334,
|
0.504422664642334,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_k_euler(self):
|
def test_stable_diffusion_k_euler(self):
|
||||||
|
@ -355,6 +358,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
0.5042197108268738,
|
0.5042197108268738,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_vae_slicing(self):
|
def test_stable_diffusion_vae_slicing(self):
|
||||||
|
@ -409,6 +413,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
0.4899061322212219,
|
0.4899061322212219,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_num_images_per_prompt(self):
|
def test_stable_diffusion_num_images_per_prompt(self):
|
||||||
|
@ -519,8 +524,8 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
||||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||||
inputs = {
|
inputs = {
|
||||||
|
@ -657,9 +662,11 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||||
mem_bytes = torch.cuda.max_memory_allocated()
|
mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
assert mem_bytes > 4e9
|
assert mem_bytes > 4e9
|
||||||
# There is a small discrepancy at the image borders vs. a fully batched version.
|
# There is a small discrepancy at the image borders vs. a fully batched version.
|
||||||
assert np.abs(image_sliced - image).max() < 4e-3
|
assert np.abs(image_sliced - image).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_fp16_vs_autocast(self):
|
def test_stable_diffusion_fp16_vs_autocast(self):
|
||||||
|
# this test makes sure that the original model with autocast
|
||||||
|
# and the new model with fp16 yield the same result
|
||||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
|
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -688,14 +695,20 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
|
expected_slice = np.array(
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
[-0.5693, -0.3018, -0.9746, 0.0518, -0.8770, 0.7559, -1.7402, 0.1022, 1.1582]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-0.1885, -0.3022, -1.012, -0.514, -0.477, 0.6143, -0.9336, 0.6553, 1.453])
|
expected_slice = np.array(
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
[-0.1958, -0.2993, -1.0166, -0.5005, -0.4810, 0.6162, -0.9492, 0.6621, 1.4492]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
|
|
||||||
callback_fn.has_been_called = False
|
callback_fn.has_been_called = False
|
||||||
|
|
||||||
|
@ -750,8 +763,8 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
||||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||||
inputs = {
|
inputs = {
|
||||||
|
|
|
@ -117,6 +117,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.5167, 0.5746, 0.4835, 0.4914, 0.5605, 0.4691, 0.5201, 0.4898, 0.4958])
|
expected_slice = np.array([0.5167, 0.5746, 0.4835, 0.4914, 0.5605, 0.4691, 0.5201, 0.4898, 0.4958])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img_variation_multiple_images(self):
|
def test_stable_diffusion_img_variation_multiple_images(self):
|
||||||
|
@ -136,6 +137,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
|
||||||
|
|
||||||
assert image.shape == (2, 64, 64, 3)
|
assert image.shape == (2, 64, 64, 3)
|
||||||
expected_slice = np.array([0.6568, 0.5470, 0.5684, 0.5444, 0.5945, 0.6221, 0.5508, 0.5531, 0.5263])
|
expected_slice = np.array([0.6568, 0.5470, 0.5684, 0.5444, 0.5945, 0.6221, 0.5508, 0.5531, 0.5263])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img_variation_num_images_per_prompt(self):
|
def test_stable_diffusion_img_variation_num_images_per_prompt(self):
|
||||||
|
@ -183,8 +185,8 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_imgvar/input_image_vermeer.png"
|
"/stable_diffusion_imgvar/input_image_vermeer.png"
|
||||||
|
@ -227,13 +229,17 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-0.1572, 0.2837, -0.798, -0.1201, -1.304, 0.7754, -2.12, 0.0443, 1.627])
|
expected_slice = np.array(
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
[-0.1621, 0.2837, -0.7979, -0.1221, -1.3057, 0.7681, -2.1191, 0.0464, 1.6309]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([0.6143, 1.734, 1.158, -2.145, -1.926, 0.748, -0.7246, 0.994, 1.539])
|
expected_slice = np.array([0.6299, 1.7500, 1.1992, -2.1582, -1.8994, 0.7334, -0.7090, 1.0137, 1.5273])
|
||||||
|
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
|
|
||||||
callback_fn.has_been_called = False
|
callback_fn.has_been_called = False
|
||||||
|
@ -282,8 +288,8 @@ class StableDiffusionImageVariationPipelineNightlyTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_imgvar/input_image_vermeer.png"
|
"/stable_diffusion_imgvar/input_image_vermeer.png"
|
||||||
|
|
|
@ -119,6 +119,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
|
expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_negative_prompt(self):
|
def test_stable_diffusion_img2img_negative_prompt(self):
|
||||||
|
@ -136,6 +137,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])
|
expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_multiple_init_images(self):
|
def test_stable_diffusion_img2img_multiple_init_images(self):
|
||||||
|
@ -153,6 +155,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||||
|
|
||||||
assert image.shape == (2, 32, 32, 3)
|
assert image.shape == (2, 32, 32, 3)
|
||||||
expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689])
|
expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_k_lms(self):
|
def test_stable_diffusion_img2img_k_lms(self):
|
||||||
|
@ -171,6 +174,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
|
expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_num_images_per_prompt(self):
|
def test_stable_diffusion_img2img_num_images_per_prompt(self):
|
||||||
|
@ -218,8 +222,8 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
||||||
|
@ -246,7 +250,8 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 768, 3)
|
assert image.shape == (1, 512, 768, 3)
|
||||||
expected_slice = np.array([0.27150, 0.14849, 0.15605, 0.26740, 0.16954, 0.18204, 0.31470, 0.26311, 0.24525])
|
expected_slice = np.array([0.4300, 0.4662, 0.4930, 0.3990, 0.4307, 0.4525, 0.3719, 0.4064, 0.3923])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_k_lms(self):
|
def test_stable_diffusion_img2img_k_lms(self):
|
||||||
|
@ -261,7 +266,8 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 768, 3)
|
assert image.shape == (1, 512, 768, 3)
|
||||||
expected_slice = np.array([0.04890, 0.04862, 0.06422, 0.04655, 0.05108, 0.05307, 0.05926, 0.08759, 0.06852])
|
expected_slice = np.array([0.0389, 0.0346, 0.0415, 0.0290, 0.0218, 0.0210, 0.0408, 0.0567, 0.0271])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_ddim(self):
|
def test_stable_diffusion_img2img_ddim(self):
|
||||||
|
@ -276,7 +282,8 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 768, 3)
|
assert image.shape == (1, 512, 768, 3)
|
||||||
expected_slice = np.array([0.06069, 0.05703, 0.08054, 0.05797, 0.06286, 0.06234, 0.08438, 0.11151, 0.08068])
|
expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_intermediate_state(self):
|
def test_stable_diffusion_img2img_intermediate_state(self):
|
||||||
|
@ -290,14 +297,16 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 96)
|
assert latents.shape == (1, 4, 64, 96)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([0.7705, 0.1045, 0.5, 3.393, 3.723, 4.273, 2.467, 3.486, 1.758])
|
expected_slice = np.array([-0.4958, 0.5107, 1.1045, 2.7539, 4.6680, 3.8320, 1.5049, 1.8633, 2.6523])
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 96)
|
assert latents.shape == (1, 4, 64, 96)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([0.765, 0.1047, 0.4973, 3.375, 3.709, 4.258, 2.451, 3.46, 1.755])
|
expected_slice = np.array([-0.4956, 0.5078, 1.0918, 2.7520, 4.6484, 3.8125, 1.5146, 1.8633, 2.6367])
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
|
|
||||||
callback_fn.has_been_called = False
|
callback_fn.has_been_called = False
|
||||||
|
|
||||||
|
@ -352,7 +361,7 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "A fantasy landscape, trending on artstation"
|
prompt = "A fantasy landscape, trending on artstation"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=init_image,
|
image=init_image,
|
||||||
|
@ -366,8 +375,9 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[255:258, 383:386, -1]
|
image_slice = image[255:258, 383:386, -1]
|
||||||
|
|
||||||
assert image.shape == (504, 760, 3)
|
assert image.shape == (504, 760, 3)
|
||||||
expected_slice = np.array([0.7124, 0.7105, 0.6993, 0.7140, 0.7106, 0.6945, 0.7198, 0.7172, 0.7031])
|
expected_slice = np.array([0.9393, 0.9500, 0.9399, 0.9438, 0.9458, 0.9400, 0.9455, 0.9414, 0.9423])
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||||
|
|
||||||
|
|
||||||
@nightly
|
@nightly
|
||||||
|
@ -378,8 +388,8 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
||||||
|
|
|
@ -125,6 +125,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
|
expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_image_tensor(self):
|
def test_stable_diffusion_inpaint_image_tensor(self):
|
||||||
|
@ -172,8 +173,8 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||||
|
@ -206,7 +207,8 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.05978, 0.10983, 0.10514, 0.07922, 0.08483, 0.08587, 0.05302, 0.03218, 0.01636])
|
expected_slice = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_fp16(self):
|
def test_stable_diffusion_inpaint_fp16(self):
|
||||||
|
@ -222,8 +224,9 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.06152, 0.11060, 0.10449, 0.07959, 0.08643, 0.08496, 0.05420, 0.03247, 0.01831])
|
expected_slice = np.array([0.1443, 0.1218, 0.1587, 0.1594, 0.1411, 0.1284, 0.1370, 0.1506, 0.2339])
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-2
|
|
||||||
|
assert np.abs(expected_slice - image_slice).max() < 5e-2
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_pndm(self):
|
def test_stable_diffusion_inpaint_pndm(self):
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
|
@ -239,7 +242,8 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.06892, 0.06994, 0.07905, 0.05366, 0.04709, 0.04890, 0.04107, 0.05083, 0.04180])
|
expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_k_lms(self):
|
def test_stable_diffusion_inpaint_k_lms(self):
|
||||||
|
@ -256,7 +260,8 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.23513, 0.22413, 0.29442, 0.24243, 0.26214, 0.30329, 0.26431, 0.25025, 0.25197])
|
expected_slice = np.array([0.9314, 0.7575, 0.9432, 0.8885, 0.9028, 0.7298, 0.9811, 0.9667, 0.7633])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
|
def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
|
||||||
|
@ -288,8 +293,8 @@ class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||||
|
|
|
@ -213,6 +213,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
|
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
@ -260,6 +261,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
|
expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self):
|
def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self):
|
||||||
|
@ -347,8 +349,8 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||||
|
@ -382,7 +384,8 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.27200, 0.29103, 0.34405, 0.21418, 0.26317, 0.34281, 0.18033, 0.24911, 0.32028])
|
expected_slice = np.array([0.5669, 0.6124, 0.6431, 0.4073, 0.4614, 0.5670, 0.1609, 0.3128, 0.4330])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_legacy_k_lms(self):
|
def test_stable_diffusion_inpaint_legacy_k_lms(self):
|
||||||
|
@ -399,7 +402,8 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.29014, 0.28882, 0.32835, 0.26502, 0.28182, 0.31162, 0.29297, 0.29534, 0.28214])
|
expected_slice = np.array([0.4533, 0.4465, 0.4327, 0.4329, 0.4339, 0.4219, 0.4243, 0.4332, 0.4426])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
|
def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
|
||||||
|
@ -413,13 +417,15 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-0.103, 1.415, -0.02197, -0.5107, -0.5903, 0.1953, 0.75, 0.3477, -1.356])
|
expected_slice = np.array([0.5977, 1.5449, 1.0586, -0.3250, 0.7383, -0.0862, 0.4631, -0.2571, -1.1289])
|
||||||
|
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([0.4802, 1.154, 0.628, 0.2319, 0.2593, -0.1455, 0.7075, -0.1617, -0.5615])
|
expected_slice = np.array([0.5190, 1.1621, 0.6885, 0.2424, 0.3337, -0.1617, 0.6914, -0.1957, -0.5474])
|
||||||
|
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
callback_fn.has_been_called = False
|
callback_fn.has_been_called = False
|
||||||
|
@ -445,8 +451,8 @@ class StableDiffusionInpaintLegacyPipelineNightlyTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||||
|
|
|
@ -122,6 +122,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unitt
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.7318, 0.3723, 0.4662, 0.623, 0.5770, 0.5014, 0.4281, 0.5550, 0.4813])
|
expected_slice = np.array([0.7318, 0.3723, 0.4662, 0.623, 0.5770, 0.5014, 0.4281, 0.5550, 0.4813])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_negative_prompt(self):
|
def test_stable_diffusion_pix2pix_negative_prompt(self):
|
||||||
|
@ -139,6 +140,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.7323, 0.3688, 0.4611, 0.6255, 0.5746, 0.5017, 0.433, 0.5553, 0.4827])
|
expected_slice = np.array([0.7323, 0.3688, 0.4611, 0.6255, 0.5746, 0.5017, 0.433, 0.5553, 0.4827])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_multiple_init_images(self):
|
def test_stable_diffusion_pix2pix_multiple_init_images(self):
|
||||||
|
@ -161,6 +163,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (2, 32, 32, 3)
|
assert image.shape == (2, 32, 32, 3)
|
||||||
expected_slice = np.array([0.606, 0.5712, 0.5099, 0.598, 0.5805, 0.7205, 0.6793, 0.554, 0.5607])
|
expected_slice = np.array([0.606, 0.5712, 0.5099, 0.598, 0.5805, 0.7205, 0.6793, 0.554, 0.5607])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_euler(self):
|
def test_stable_diffusion_pix2pix_euler(self):
|
||||||
|
@ -182,6 +185,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unitt
|
||||||
|
|
||||||
assert image.shape == (1, 32, 32, 3)
|
assert image.shape == (1, 32, 32, 3)
|
||||||
expected_slice = np.array([0.726, 0.3902, 0.4868, 0.585, 0.5672, 0.511, 0.3906, 0.551, 0.4846])
|
expected_slice = np.array([0.726, 0.3902, 0.4868, 0.585, 0.5672, 0.511, 0.3906, 0.551, 0.4846])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_num_images_per_prompt(self):
|
def test_stable_diffusion_pix2pix_num_images_per_prompt(self):
|
||||||
|
@ -259,6 +263,7 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.5902, 0.6015, 0.6027, 0.5983, 0.6092, 0.6061, 0.5765, 0.5785, 0.5555])
|
expected_slice = np.array([0.5902, 0.6015, 0.6027, 0.5983, 0.6092, 0.6061, 0.5765, 0.5785, 0.5555])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_k_lms(self):
|
def test_stable_diffusion_pix2pix_k_lms(self):
|
||||||
|
@ -276,6 +281,7 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.6578, 0.6817, 0.6972, 0.6761, 0.6856, 0.6916, 0.6428, 0.6516, 0.6301])
|
expected_slice = np.array([0.6578, 0.6817, 0.6972, 0.6761, 0.6856, 0.6916, 0.6428, 0.6516, 0.6301])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_ddim(self):
|
def test_stable_diffusion_pix2pix_ddim(self):
|
||||||
|
@ -293,6 +299,7 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.3828, 0.3834, 0.3818, 0.3792, 0.3865, 0.3752, 0.3792, 0.3847, 0.3753])
|
expected_slice = np.array([0.3828, 0.3834, 0.3818, 0.3792, 0.3865, 0.3752, 0.3792, 0.3847, 0.3753])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_pix2pix_intermediate_state(self):
|
def test_stable_diffusion_pix2pix_intermediate_state(self):
|
||||||
|
@ -306,14 +313,16 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-0.2388, -0.4673, -0.9775, 1.5127, 1.4414, 0.7778, 0.9907, 0.8472, 0.7788])
|
expected_slice = np.array([-0.2463, -0.4644, -0.9756, 1.5176, 1.4414, 0.7866, 0.9897, 0.8521, 0.7983])
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-0.2568, -0.4648, -0.9639, 1.5137, 1.4609, 0.7603, 0.9795, 0.8403, 0.7949])
|
expected_slice = np.array([-0.2644, -0.4626, -0.9653, 1.5176, 1.4551, 0.7686, 0.9805, 0.8452, 0.8115])
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
|
|
||||||
callback_fn.has_been_called = False
|
callback_fn.has_been_called = False
|
||||||
|
|
||||||
|
@ -369,5 +378,6 @@ class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
|
||||||
image_slice = image[255:258, 383:386, -1]
|
image_slice = image[255:258, 383:386, -1]
|
||||||
|
|
||||||
assert image.shape == (504, 504, 3)
|
assert image.shape == (504, 504, 3)
|
||||||
expected_slice = np.array([0.2726, 0.2529, 0.2664, 0.2655, 0.2641, 0.2642, 0.2591, 0.2649, 0.259])
|
expected_slice = np.array([0.2726, 0.2529, 0.2664, 0.2655, 0.2641, 0.2642, 0.2591, 0.2649, 0.2590])
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||||
|
|
|
@ -44,7 +44,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
sd_pipe.set_scheduler("sample_euler")
|
sd_pipe.set_scheduler("sample_euler")
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
|
output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
|
||||||
|
|
||||||
image = output.images
|
image = output.images
|
||||||
|
@ -52,7 +52,8 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
|
expected_slice = np.array([0.0447, 0.0492, 0.0468, 0.0408, 0.0383, 0.0408, 0.0354, 0.0380, 0.0339])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_2(self):
|
def test_stable_diffusion_2(self):
|
||||||
|
@ -63,7 +64,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
sd_pipe.set_scheduler("sample_euler")
|
sd_pipe.set_scheduler("sample_euler")
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
|
output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
|
||||||
|
|
||||||
image = output.images
|
image = output.images
|
||||||
|
@ -71,7 +72,6 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array(
|
expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])
|
||||||
[0.826810, 0.81958747, 0.8510199, 0.8376758, 0.83958465, 0.8682068, 0.84370345, 0.85251087, 0.85884345]
|
|
||||||
)
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
|
|
|
@ -149,6 +149,7 @@ class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.5099, 0.5677, 0.4671, 0.5128, 0.5697, 0.4676, 0.5277, 0.4964, 0.4946])
|
expected_slice = np.array([0.5099, 0.5677, 0.4671, 0.5128, 0.5697, 0.4676, 0.5277, 0.4964, 0.4946])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_k_lms(self):
|
def test_stable_diffusion_k_lms(self):
|
||||||
|
@ -165,6 +166,7 @@ class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043])
|
expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_k_euler_ancestral(self):
|
def test_stable_diffusion_k_euler_ancestral(self):
|
||||||
|
@ -181,6 +183,7 @@ class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.4715, 0.5376, 0.4569, 0.5224, 0.5734, 0.4797, 0.5465, 0.5074, 0.5046])
|
expected_slice = np.array([0.4715, 0.5376, 0.4569, 0.5224, 0.5734, 0.4797, 0.5465, 0.5074, 0.5046])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_k_euler(self):
|
def test_stable_diffusion_k_euler(self):
|
||||||
|
@ -197,6 +200,7 @@ class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043])
|
expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_long_prompt(self):
|
def test_stable_diffusion_long_prompt(self):
|
||||||
|
@ -246,8 +250,8 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
||||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||||
inputs = {
|
inputs = {
|
||||||
|
@ -340,14 +344,20 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-0.3857, -0.4507, -1.167, 0.074, -1.108, 0.7183, -1.822, 0.1915, 1.283])
|
expected_slice = np.array(
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
[-0.3862, -0.4507, -1.1729, 0.0686, -1.1045, 0.7124, -1.8301, 0.1903, 1.2773]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 64, 64)
|
assert latents.shape == (1, 4, 64, 64)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([0.268, -0.2095, -0.7744, -0.541, -0.79, 0.3926, -0.7754, 0.465, 1.291])
|
expected_slice = np.array(
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
[0.2720, -0.1863, -0.7383, -0.5029, -0.7534, 0.3970, -0.7646, 0.4468, 1.2686]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
|
|
||||||
callback_fn.has_been_called = False
|
callback_fn.has_been_called = False
|
||||||
|
|
||||||
|
@ -392,8 +402,8 @@ class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
|
||||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||||
inputs = {
|
inputs = {
|
||||||
|
|
|
@ -289,6 +289,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||||
expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
|
expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
|
||||||
else:
|
else:
|
||||||
expected_slice = np.array([0.6854, 0.3740, 0.4857, 0.7130, 0.7403, 0.5536, 0.4829, 0.6182, 0.5053])
|
expected_slice = np.array([0.6854, 0.3740, 0.4857, 0.7130, 0.7403, 0.5536, 0.4829, 0.6182, 0.5053])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_depth2img_negative_prompt(self):
|
def test_stable_diffusion_depth2img_negative_prompt(self):
|
||||||
|
@ -309,6 +310,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||||
expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
|
expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
|
||||||
else:
|
else:
|
||||||
expected_slice = np.array([0.6074, 0.3096, 0.4802, 0.7463, 0.7388, 0.5393, 0.4531, 0.5928, 0.4972])
|
expected_slice = np.array([0.6074, 0.3096, 0.4802, 0.7463, 0.7388, 0.5393, 0.4531, 0.5928, 0.4972])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_depth2img_multiple_init_images(self):
|
def test_stable_diffusion_depth2img_multiple_init_images(self):
|
||||||
|
@ -330,6 +332,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||||
expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
|
expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
|
||||||
else:
|
else:
|
||||||
expected_slice = np.array([0.6681, 0.5023, 0.6611, 0.7605, 0.5724, 0.7959, 0.7240, 0.5871, 0.5383])
|
expected_slice = np.array([0.6681, 0.5023, 0.6611, 0.7605, 0.5724, 0.7959, 0.7240, 0.5871, 0.5383])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
def test_stable_diffusion_depth2img_num_images_per_prompt(self):
|
def test_stable_diffusion_depth2img_num_images_per_prompt(self):
|
||||||
|
@ -384,6 +387,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||||
expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
|
expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
|
||||||
else:
|
else:
|
||||||
expected_slice = np.array([0.6853, 0.3740, 0.4856, 0.7130, 0.7402, 0.5535, 0.4828, 0.6182, 0.5053])
|
expected_slice = np.array([0.6853, 0.3740, 0.4856, 0.7130, 0.7402, 0.5535, 0.4828, 0.6182, 0.5053])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
@ -395,7 +399,7 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png"
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png"
|
||||||
|
@ -419,12 +423,13 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device)
|
inputs = self.get_inputs()
|
||||||
image = pipe(**inputs).images
|
image = pipe(**inputs).images
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 480, 640, 3)
|
assert image.shape == (1, 480, 640, 3)
|
||||||
expected_slice = np.array([0.75446, 0.74692, 0.75951, 0.81611, 0.80593, 0.79992, 0.90529, 0.87921, 0.86903])
|
expected_slice = np.array([0.9057, 0.9365, 0.9258, 0.8937, 0.8555, 0.8541, 0.8260, 0.7747, 0.7421])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_depth2img_pipeline_k_lms(self):
|
def test_stable_diffusion_depth2img_pipeline_k_lms(self):
|
||||||
|
@ -436,12 +441,13 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device)
|
inputs = self.get_inputs()
|
||||||
image = pipe(**inputs).images
|
image = pipe(**inputs).images
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 480, 640, 3)
|
assert image.shape == (1, 480, 640, 3)
|
||||||
expected_slice = np.array([0.63957, 0.64879, 0.65668, 0.64385, 0.67078, 0.63588, 0.66577, 0.62180, 0.66286])
|
expected_slice = np.array([0.6363, 0.6274, 0.6309, 0.6370, 0.6226, 0.6286, 0.6213, 0.6453, 0.6306])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_depth2img_pipeline_ddim(self):
|
def test_stable_diffusion_depth2img_pipeline_ddim(self):
|
||||||
|
@ -453,12 +459,13 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device)
|
inputs = self.get_inputs()
|
||||||
image = pipe(**inputs).images
|
image = pipe(**inputs).images
|
||||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||||
|
|
||||||
assert image.shape == (1, 480, 640, 3)
|
assert image.shape == (1, 480, 640, 3)
|
||||||
expected_slice = np.array([0.62840, 0.64191, 0.62953, 0.63653, 0.64205, 0.61574, 0.62252, 0.65827, 0.64809])
|
expected_slice = np.array([0.6424, 0.6524, 0.6249, 0.6041, 0.6634, 0.6420, 0.6522, 0.6555, 0.6436])
|
||||||
|
|
||||||
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
assert np.abs(expected_slice - image_slice).max() < 1e-4
|
||||||
|
|
||||||
def test_stable_diffusion_depth2img_intermediate_state(self):
|
def test_stable_diffusion_depth2img_intermediate_state(self):
|
||||||
|
@ -472,14 +479,20 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 60, 80)
|
assert latents.shape == (1, 4, 60, 80)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-1.148, -0.2079, -0.622, -2.477, -2.348, 0.3828, -2.055, -1.569, -1.526])
|
expected_slice = np.array(
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
[-0.7168, -1.5137, -0.1418, -2.9219, -2.7266, -2.4414, -2.1035, -3.0078, -1.7051]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 60, 80)
|
assert latents.shape == (1, 4, 60, 80)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([-1.145, -0.2063, -0.6216, -2.469, -2.344, 0.3794, -2.05, -1.57, -1.521])
|
expected_slice = np.array(
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
[-0.7109, -1.5068, -0.1403, -2.9160, -2.7207, -2.4414, -2.1035, -3.0059, -1.7090]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
|
|
||||||
callback_fn.has_been_called = False
|
callback_fn.has_been_called = False
|
||||||
|
|
||||||
|
@ -490,7 +503,7 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device, dtype=torch.float16)
|
inputs = self.get_inputs(dtype=torch.float16)
|
||||||
pipe(**inputs, callback=callback_fn, callback_steps=1)
|
pipe(**inputs, callback=callback_fn, callback_steps=1)
|
||||||
assert callback_fn.has_been_called
|
assert callback_fn.has_been_called
|
||||||
assert number_of_steps == 2
|
assert number_of_steps == 2
|
||||||
|
@ -508,7 +521,7 @@ class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
|
||||||
pipe.enable_attention_slicing(1)
|
pipe.enable_attention_slicing(1)
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device, dtype=torch.float16)
|
inputs = self.get_inputs(dtype=torch.float16)
|
||||||
_ = pipe(**inputs)
|
_ = pipe(**inputs)
|
||||||
|
|
||||||
mem_bytes = torch.cuda.max_memory_allocated()
|
mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
|
@ -524,7 +537,7 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_inputs(self, device, dtype=torch.float32, seed=0):
|
def get_inputs(self, device="cpu", dtype=torch.float32, seed=0):
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
generator = torch.Generator(device=device).manual_seed(seed)
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png"
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png"
|
||||||
|
@ -545,7 +558,7 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device)
|
inputs = self.get_inputs()
|
||||||
image = pipe(**inputs).images[0]
|
image = pipe(**inputs).images[0]
|
||||||
|
|
||||||
expected_image = load_numpy(
|
expected_image = load_numpy(
|
||||||
|
@ -561,7 +574,7 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device)
|
inputs = self.get_inputs()
|
||||||
image = pipe(**inputs).images[0]
|
image = pipe(**inputs).images[0]
|
||||||
|
|
||||||
expected_image = load_numpy(
|
expected_image = load_numpy(
|
||||||
|
@ -577,7 +590,7 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device)
|
inputs = self.get_inputs()
|
||||||
image = pipe(**inputs).images[0]
|
image = pipe(**inputs).images[0]
|
||||||
|
|
||||||
expected_image = load_numpy(
|
expected_image = load_numpy(
|
||||||
|
@ -593,7 +606,7 @@ class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
inputs = self.get_inputs(torch_device)
|
inputs = self.get_inputs()
|
||||||
inputs["num_inference_steps"] = 30
|
inputs["num_inference_steps"] = 30
|
||||||
image = pipe(**inputs).images[0]
|
image = pipe(**inputs).images[0]
|
||||||
|
|
||||||
|
|
|
@ -158,7 +158,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=init_image,
|
image=init_image,
|
||||||
|
@ -196,7 +196,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=init_image,
|
image=init_image,
|
||||||
|
@ -237,7 +237,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
_ = pipe(
|
_ = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=init_image,
|
image=init_image,
|
||||||
|
|
|
@ -241,7 +241,7 @@ class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = sd_pipe(
|
image = sd_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
image=low_res_image,
|
image=low_res_image,
|
||||||
|
@ -281,7 +281,7 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "a cat sitting on a park bench"
|
prompt = "a cat sitting on a park bench"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=image,
|
image=image,
|
||||||
|
@ -314,7 +314,7 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "a cat sitting on a park bench"
|
prompt = "a cat sitting on a park bench"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(
|
output = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=image,
|
image=image,
|
||||||
|
@ -348,7 +348,7 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "a cat sitting on a park bench"
|
prompt = "a cat sitting on a park bench"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
_ = pipe(
|
_ = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=image,
|
image=image,
|
||||||
|
|
|
@ -194,6 +194,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.4616, 0.5184, 0.4887, 0.5111, 0.4839, 0.48, 0.5119, 0.5263, 0.4776])
|
expected_slice = np.array([0.4616, 0.5184, 0.4887, 0.5111, 0.4839, 0.48, 0.5119, 0.5263, 0.4776])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
@ -233,7 +234,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
|
@ -255,14 +256,15 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
|
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
|
||||||
|
|
||||||
image = output.images
|
image = output.images
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 768, 768, 3)
|
assert image.shape == (1, 768, 768, 3)
|
||||||
expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866])
|
expected_slice = np.array([0.1868, 0.1922, 0.1527, 0.1921, 0.1908, 0.1624, 0.1779, 0.1652, 0.1734])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_v_pred_upcast_attention(self):
|
def test_stable_diffusion_v_pred_upcast_attention(self):
|
||||||
|
@ -274,15 +276,16 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
|
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
|
||||||
|
|
||||||
image = output.images
|
image = output.images
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 768, 768, 3)
|
assert image.shape == (1, 768, 768, 3)
|
||||||
expected_slice = np.array([0.0461, 0.0483, 0.0566, 0.0512, 0.0446, 0.0751, 0.0664, 0.0551, 0.0488])
|
expected_slice = np.array([0.4209, 0.4087, 0.4097, 0.4209, 0.3860, 0.4329, 0.4280, 0.4324, 0.4187])
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
|
|
||||||
def test_stable_diffusion_v_pred_euler(self):
|
def test_stable_diffusion_v_pred_euler(self):
|
||||||
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
|
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
|
||||||
|
@ -292,7 +295,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
|
|
||||||
output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")
|
output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")
|
||||||
image = output.images
|
image = output.images
|
||||||
|
@ -300,7 +303,8 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 768, 768, 3)
|
assert image.shape == (1, 768, 768, 3)
|
||||||
expected_slice = np.array([0.0351, 0.0376, 0.0505, 0.0424, 0.0551, 0.0656, 0.0471, 0.0276, 0.0596])
|
expected_slice = np.array([0.1781, 0.1695, 0.1661, 0.1705, 0.1588, 0.1699, 0.2005, 0.1589, 0.1677])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_v_pred_dpm(self):
|
def test_stable_diffusion_v_pred_dpm(self):
|
||||||
|
@ -316,14 +320,15 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "a photograph of an astronaut riding a horse"
|
prompt = "a photograph of an astronaut riding a horse"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = sd_pipe(
|
image = sd_pipe(
|
||||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy"
|
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy"
|
||||||
).images
|
).images
|
||||||
|
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
assert image.shape == (1, 768, 768, 3)
|
assert image.shape == (1, 768, 768, 3)
|
||||||
expected_slice = np.array([0.2049, 0.2115, 0.2323, 0.2416, 0.256, 0.2484, 0.2517, 0.2358, 0.236])
|
expected_slice = np.array([0.3303, 0.3184, 0.3291, 0.3300, 0.3256, 0.3113, 0.2965, 0.3134, 0.3192])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_attention_slicing_v_pred(self):
|
def test_stable_diffusion_attention_slicing_v_pred(self):
|
||||||
|
@ -337,12 +342,11 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
# make attention efficient
|
# make attention efficient
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
with torch.autocast(torch_device):
|
output_chunked = pipe(
|
||||||
output_chunked = pipe(
|
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
)
|
||||||
)
|
image_chunked = output_chunked.images
|
||||||
image_chunked = output_chunked.images
|
|
||||||
|
|
||||||
mem_bytes = torch.cuda.max_memory_allocated()
|
mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
@ -351,12 +355,9 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
# disable slicing
|
# disable slicing
|
||||||
pipe.disable_attention_slicing()
|
pipe.disable_attention_slicing()
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
with torch.autocast(torch_device):
|
output = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")
|
||||||
output = pipe(
|
image = output.images
|
||||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
|
||||||
)
|
|
||||||
image = output.images
|
|
||||||
|
|
||||||
# make sure that more than 5.5 GB is allocated
|
# make sure that more than 5.5 GB is allocated
|
||||||
mem_bytes = torch.cuda.max_memory_allocated()
|
mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
|
@ -376,12 +377,12 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "astronaut riding a horse"
|
prompt = "astronaut riding a horse"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
||||||
image = output.images[0]
|
image = output.images[0]
|
||||||
|
|
||||||
assert image.shape == (768, 768, 3)
|
assert image.shape == (768, 768, 3)
|
||||||
assert np.abs(expected_image - image).max() < 5e-3
|
assert np.abs(expected_image - image).max() < 7.5e-2
|
||||||
|
|
||||||
def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self):
|
def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self):
|
||||||
expected_image = load_numpy(
|
expected_image = load_numpy(
|
||||||
|
@ -395,12 +396,12 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "astronaut riding a horse"
|
prompt = "astronaut riding a horse"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
||||||
image = output.images[0]
|
image = output.images[0]
|
||||||
|
|
||||||
assert image.shape == (768, 768, 3)
|
assert image.shape == (768, 768, 3)
|
||||||
assert np.abs(expected_image - image).max() < 5e-1
|
assert np.abs(expected_image - image).max() < 7.5e-1
|
||||||
|
|
||||||
def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
|
def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
|
||||||
number_of_steps = 0
|
number_of_steps = 0
|
||||||
|
@ -413,18 +414,16 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 96, 96)
|
assert latents.shape == (1, 4, 96, 96)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array(
|
expected_slice = np.array([0.7749, 0.0325, 0.5088, 0.1619, 0.3372, 0.3667, -0.5186, 0.6860, 1.4326])
|
||||||
[-0.2543, -1.2755, 0.4261, -0.9555, -1.173, -0.5892, 2.4159, 0.1554, -1.2098]
|
|
||||||
)
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
|
||||||
elif step == 19:
|
elif step == 19:
|
||||||
latents = latents.detach().cpu().numpy()
|
latents = latents.detach().cpu().numpy()
|
||||||
assert latents.shape == (1, 4, 96, 96)
|
assert latents.shape == (1, 4, 96, 96)
|
||||||
latents_slice = latents[0, -3:, -3:, -1]
|
latents_slice = latents[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array(
|
expected_slice = np.array([1.3887, 1.0273, 1.7266, 0.0726, 0.6611, 0.1598, -1.0547, 0.1522, 0.0227])
|
||||||
[-0.9572, -0.967, -0.6152, 0.0894, -0.699, -0.2344, 1.5465, -0.0357, -0.1141]
|
|
||||||
)
|
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
|
||||||
|
|
||||||
test_callback_fn.has_been_called = False
|
test_callback_fn.has_been_called = False
|
||||||
|
|
||||||
|
@ -435,16 +434,15 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "Andromeda galaxy in a bottle"
|
prompt = "Andromeda galaxy in a bottle"
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
with torch.autocast(torch_device):
|
pipe(
|
||||||
pipe(
|
prompt=prompt,
|
||||||
prompt=prompt,
|
num_inference_steps=20,
|
||||||
num_inference_steps=20,
|
guidance_scale=7.5,
|
||||||
guidance_scale=7.5,
|
generator=generator,
|
||||||
generator=generator,
|
callback=test_callback_fn,
|
||||||
callback=test_callback_fn,
|
callback_steps=1,
|
||||||
callback_steps=1,
|
)
|
||||||
)
|
|
||||||
assert test_callback_fn.has_been_called
|
assert test_callback_fn.has_been_called
|
||||||
assert number_of_steps == 20
|
assert number_of_steps == 20
|
||||||
|
|
||||||
|
@ -475,7 +473,7 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipeline.enable_attention_slicing(1)
|
pipeline.enable_attention_slicing(1)
|
||||||
pipeline.enable_sequential_cpu_offload()
|
pipeline.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
_ = pipeline(prompt, generator=generator, num_inference_steps=5)
|
_ = pipeline(prompt, generator=generator, num_inference_steps=5)
|
||||||
|
|
||||||
mem_bytes = torch.cuda.max_memory_allocated()
|
mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
|
|
|
@ -23,7 +23,7 @@ import torch
|
||||||
|
|
||||||
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
|
||||||
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
|
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
|
||||||
from diffusers.utils import floats_tensor, slow, torch_device
|
from diffusers.utils import floats_tensor, nightly, torch_device
|
||||||
from diffusers.utils.testing_utils import require_torch_gpu
|
from diffusers.utils.testing_utils import require_torch_gpu
|
||||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
@ -201,6 +201,7 @@ class SafeDiffusionPipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945])
|
expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
@ -253,13 +254,12 @@ class SafeDiffusionPipelineFastTests(unittest.TestCase):
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images
|
||||||
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
|
||||||
|
|
||||||
assert image.shape == (1, 64, 64, 3)
|
assert image.shape == (1, 64, 64, 3)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@nightly
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
@ -284,7 +284,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
guidance_scale = 7
|
guidance_scale = 7
|
||||||
|
|
||||||
# without safety guidance (sld_guidance_scale = 0)
|
# without safety guidance (sld_guidance_scale = 0)
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
output = sd_pipe(
|
output = sd_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -301,10 +301,11 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]
|
expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
# without safety guidance (strong configuration)
|
# without safety guidance (strong configuration)
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
output = sd_pipe(
|
output = sd_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -325,6 +326,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
|
expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_nudity_safe_stable_diffusion(self):
|
def test_nudity_safe_stable_diffusion(self):
|
||||||
|
@ -337,7 +339,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
seed = 2734971755
|
seed = 2734971755
|
||||||
guidance_scale = 7
|
guidance_scale = 7
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
output = sd_pipe(
|
output = sd_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -354,9 +356,10 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]
|
expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
output = sd_pipe(
|
output = sd_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -377,6 +380,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]
|
expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_nudity_safetychecker_safe_stable_diffusion(self):
|
def test_nudity_safetychecker_safe_stable_diffusion(self):
|
||||||
|
@ -391,7 +395,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
seed = 1044355234
|
seed = 1044355234
|
||||||
guidance_scale = 12
|
guidance_scale = 12
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
output = sd_pipe(
|
output = sd_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -408,9 +412,10 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
output = sd_pipe(
|
output = sd_pipe(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -430,4 +435,5 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
|
expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -460,11 +460,9 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
_ = pipe(
|
_ = pipe(
|
||||||
"horse",
|
"horse",
|
||||||
num_images_per_prompt=1,
|
num_images_per_prompt=1,
|
||||||
generator=generator,
|
|
||||||
prior_num_inference_steps=2,
|
prior_num_inference_steps=2,
|
||||||
decoder_num_inference_steps=2,
|
decoder_num_inference_steps=2,
|
||||||
super_res_num_inference_steps=2,
|
super_res_num_inference_steps=2,
|
||||||
|
|
|
@ -51,7 +51,7 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt="first prompt",
|
prompt="first prompt",
|
||||||
image=second_prompt,
|
image=second_prompt,
|
||||||
|
@ -92,7 +92,7 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||||
second_prompt = load_image(
|
second_prompt = load_image(
|
||||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||||
)
|
)
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=first_prompt,
|
prompt=first_prompt,
|
||||||
image=second_prompt,
|
image=second_prompt,
|
||||||
|
@ -106,5 +106,6 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.014, 0.0112, 0.0136, 0.0145, 0.0107, 0.0113, 0.0272, 0.0215, 0.0216])
|
expected_slice = np.array([0.0787, 0.0849, 0.0826, 0.0812, 0.0807, 0.0795, 0.0818, 0.0798, 0.0779])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -40,7 +40,7 @@ class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase
|
||||||
image_prompt = load_image(
|
image_prompt = load_image(
|
||||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||||
)
|
)
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
image=image_prompt,
|
image=image_prompt,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -52,5 +52,6 @@ class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.1205, 0.1914, 0.2289, 0.0883, 0.1595, 0.1683, 0.0703, 0.1493, 0.1298])
|
expected_slice = np.array([0.0441, 0.0469, 0.0507, 0.0575, 0.0632, 0.0650, 0.0865, 0.0909, 0.0945])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -49,7 +49,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe.dual_guided(
|
image = pipe.dual_guided(
|
||||||
prompt="first prompt",
|
prompt="first prompt",
|
||||||
image=prompt_image,
|
image=prompt_image,
|
||||||
|
@ -88,7 +88,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||||
init_image = load_image(
|
init_image = load_image(
|
||||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||||
)
|
)
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe.dual_guided(
|
image = pipe.dual_guided(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=init_image,
|
image=init_image,
|
||||||
|
@ -102,11 +102,12 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.0081, 0.0032, 0.0002, 0.0056, 0.0027, 0.0000, 0.0051, 0.0020, 0.0007])
|
expected_slice = np.array([0.1448, 0.1619, 0.1741, 0.1086, 0.1147, 0.1128, 0.1199, 0.1165, 0.1001])
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger "
|
prompt = "A painting of a squirrel eating a burger "
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe.text_to_image(
|
image = pipe.text_to_image(
|
||||||
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
|
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
|
||||||
).images
|
).images
|
||||||
|
@ -114,13 +115,15 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222])
|
expected_slice = np.array([0.3367, 0.3169, 0.2656, 0.3870, 0.4790, 0.3796, 0.4009, 0.4878, 0.4778])
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
||||||
image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images
|
image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.3403, 0.1809, 0.0938, 0.3855, 0.2393, 0.1243, 0.4028, 0.3110, 0.1799])
|
expected_slice = np.array([0.3076, 0.3123, 0.3284, 0.3782, 0.3770, 0.3894, 0.4297, 0.4331, 0.4456])
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
|
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||||
|
|
|
@ -48,7 +48,7 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger "
|
prompt = "A painting of a squirrel eating a burger "
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy"
|
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy"
|
||||||
).images
|
).images
|
||||||
|
@ -72,7 +72,7 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger "
|
prompt = "A painting of a squirrel eating a burger "
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
|
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
|
||||||
).images
|
).images
|
||||||
|
@ -80,5 +80,6 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||||
image_slice = image[0, 253:256, 253:256, -1]
|
image_slice = image[0, 253:256, 253:256, -1]
|
||||||
|
|
||||||
assert image.shape == (1, 512, 512, 3)
|
assert image.shape == (1, 512, 512, 3)
|
||||||
expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222])
|
expected_slice = np.array([0.3493, 0.3757, 0.4093, 0.4495, 0.4233, 0.4102, 0.4507, 0.4756, 0.4787])
|
||||||
|
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
|
@ -212,6 +212,8 @@ class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipeline = pipeline.to(torch_device)
|
pipeline = pipeline.to(torch_device)
|
||||||
pipeline.set_progress_bar_config(disable=None)
|
pipeline.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
# requires GPU generator for gumbel softmax
|
||||||
|
# don't use GPU generator in tests though
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
output = pipeline(
|
output = pipeline(
|
||||||
"teddy bear playing in the pool",
|
"teddy bear playing in the pool",
|
||||||
|
|
|
@ -86,19 +86,11 @@ class DownloadTests(unittest.TestCase):
|
||||||
|
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe_2 = pipe_2.to(torch_device)
|
pipe_2 = pipe_2.to(torch_device)
|
||||||
if torch_device == "mps":
|
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||||
|
@ -125,20 +117,12 @@ class DownloadTests(unittest.TestCase):
|
||||||
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||||
)
|
)
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
||||||
pipe_2 = pipe_2.to(torch_device)
|
pipe_2 = pipe_2.to(torch_device)
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
assert np.max(np.abs(out - out_2)) < 1e-3
|
assert np.max(np.abs(out - out_2)) < 1e-3
|
||||||
|
@ -149,11 +133,7 @@ class DownloadTests(unittest.TestCase):
|
||||||
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||||
)
|
)
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
@ -161,11 +141,7 @@ class DownloadTests(unittest.TestCase):
|
||||||
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
|
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
|
||||||
pipe_2 = pipe_2.to(torch_device)
|
pipe_2 = pipe_2.to(torch_device)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
|
@ -175,11 +151,8 @@ class DownloadTests(unittest.TestCase):
|
||||||
prompt = "hello"
|
prompt = "hello"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
if torch_device == "mps":
|
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
generator = torch.manual_seed(0)
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
@ -187,11 +160,7 @@ class DownloadTests(unittest.TestCase):
|
||||||
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||||
pipe_2 = pipe_2.to(torch_device)
|
pipe_2 = pipe_2.to(torch_device)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
|
||||||
|
|
||||||
|
@ -401,12 +370,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
scheduler = scheduler_fn()
|
scheduler = scheduler_fn()
|
||||||
pipeline = pipeline_fn(unet, scheduler).to(torch_device)
|
pipeline = pipeline_fn(unet, scheduler).to(torch_device)
|
||||||
|
|
||||||
# Device type MPS is not supported for torch.Generator() api.
|
generator = torch.manual_seed(0)
|
||||||
if torch_device == "mps":
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
out_image = pipeline(
|
out_image = pipeline(
|
||||||
generator=generator,
|
generator=generator,
|
||||||
num_inference_steps=2,
|
num_inference_steps=2,
|
||||||
|
@ -442,12 +406,7 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
|
|
||||||
# Device type MPS is not supported for torch.Generator() api.
|
generator = torch.manual_seed(0)
|
||||||
if torch_device == "mps":
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
image_inpaint = inpaint(
|
image_inpaint = inpaint(
|
||||||
[prompt],
|
[prompt],
|
||||||
generator=generator,
|
generator=generator,
|
||||||
|
@ -798,7 +757,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||||
|
|
||||||
generator = generator.manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||||
|
|
||||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||||
|
@ -819,7 +778,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||||
|
|
||||||
generator = generator.manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
|
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||||
|
|
||||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||||
|
@ -842,7 +801,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
image = ddpm_from_hub_custom_model(generator=generator, num_inference_steps=5, output_type="numpy").images
|
image = ddpm_from_hub_custom_model(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||||
|
|
||||||
generator = generator.manual_seed(0)
|
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||||
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
|
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||||
|
|
||||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||||
|
@ -855,18 +814,17 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
images = pipe(output_type="numpy").images
|
||||||
images = pipe(generator=generator, output_type="numpy").images
|
|
||||||
assert images.shape == (1, 32, 32, 3)
|
assert images.shape == (1, 32, 32, 3)
|
||||||
assert isinstance(images, np.ndarray)
|
assert isinstance(images, np.ndarray)
|
||||||
|
|
||||||
images = pipe(generator=generator, output_type="pil", num_inference_steps=4).images
|
images = pipe(output_type="pil", num_inference_steps=4).images
|
||||||
assert isinstance(images, list)
|
assert isinstance(images, list)
|
||||||
assert len(images) == 1
|
assert len(images) == 1
|
||||||
assert isinstance(images[0], PIL.Image.Image)
|
assert isinstance(images[0], PIL.Image.Image)
|
||||||
|
|
||||||
# use PIL by default
|
# use PIL by default
|
||||||
images = pipe(generator=generator, num_inference_steps=4).images
|
images = pipe(num_inference_steps=4).images
|
||||||
assert isinstance(images, list)
|
assert isinstance(images, list)
|
||||||
assert isinstance(images[0], PIL.Image.Image)
|
assert isinstance(images[0], PIL.Image.Image)
|
||||||
|
|
||||||
|
|
|
@ -288,11 +288,11 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||||
|
|
||||||
# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
|
# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||||
|
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||||
|
|
||||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||||
|
@ -330,11 +330,11 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||||
kwargs["num_inference_steps"] = num_inference_steps
|
kwargs["num_inference_steps"] = num_inference_steps
|
||||||
|
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||||
|
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
|
||||||
|
|
||||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||||
|
@ -372,11 +372,11 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||||
kwargs["num_inference_steps"] = num_inference_steps
|
kwargs["num_inference_steps"] = num_inference_steps
|
||||||
|
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample
|
output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample
|
||||||
|
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample
|
new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample
|
||||||
|
|
||||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||||
|
@ -510,7 +510,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||||
|
|
||||||
# Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
|
# Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)
|
outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)
|
||||||
|
|
||||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||||
|
@ -520,7 +520,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||||
|
|
||||||
# Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
|
# Set the seed before state as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)
|
outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)
|
||||||
|
|
||||||
recursive_check(outputs_tuple, outputs_dict)
|
recursive_check(outputs_tuple, outputs_dict)
|
||||||
|
@ -664,12 +664,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
|
output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
kwargs["generator"] = torch.Generator().manual_seed(0)
|
kwargs["generator"] = torch.manual_seed(0)
|
||||||
output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
|
output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
|
||||||
|
|
||||||
assert (output - output_eps).abs().sum() < 1e-5
|
assert (output - output_eps).abs().sum() < 1e-5
|
||||||
|
@ -1822,11 +1822,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps)
|
scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
|
@ -1853,11 +1849,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps)
|
scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
|
@ -1884,11 +1876,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
|
@ -1947,11 +1935,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps)
|
scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
|
@ -1968,13 +1952,8 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
result_sum = torch.sum(torch.abs(sample))
|
result_sum = torch.sum(torch.abs(sample))
|
||||||
result_mean = torch.mean(torch.abs(sample))
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
||||||
if torch_device in ["cpu", "mps"]:
|
assert abs(result_sum.item() - 152.3192) < 1e-2
|
||||||
assert abs(result_sum.item() - 152.3192) < 1e-2
|
assert abs(result_mean.item() - 0.1983) < 1e-3
|
||||||
assert abs(result_mean.item() - 0.1983) < 1e-3
|
|
||||||
else:
|
|
||||||
# CUDA
|
|
||||||
assert abs(result_sum.item() - 144.8084) < 1e-2
|
|
||||||
assert abs(result_mean.item() - 0.18855) < 1e-3
|
|
||||||
|
|
||||||
def test_full_loop_with_v_prediction(self):
|
def test_full_loop_with_v_prediction(self):
|
||||||
scheduler_class = self.scheduler_classes[0]
|
scheduler_class = self.scheduler_classes[0]
|
||||||
|
@ -1983,11 +1962,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps)
|
scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
|
@ -2004,13 +1979,8 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
result_sum = torch.sum(torch.abs(sample))
|
result_sum = torch.sum(torch.abs(sample))
|
||||||
result_mean = torch.mean(torch.abs(sample))
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
||||||
if torch_device in ["cpu", "mps"]:
|
assert abs(result_sum.item() - 108.4439) < 1e-2
|
||||||
assert abs(result_sum.item() - 108.4439) < 1e-2
|
assert abs(result_mean.item() - 0.1412) < 1e-3
|
||||||
assert abs(result_mean.item() - 0.1412) < 1e-3
|
|
||||||
else:
|
|
||||||
# CUDA
|
|
||||||
assert abs(result_sum.item() - 102.5807) < 1e-2
|
|
||||||
assert abs(result_mean.item() - 0.1335) < 1e-3
|
|
||||||
|
|
||||||
def test_full_loop_device(self):
|
def test_full_loop_device(self):
|
||||||
scheduler_class = self.scheduler_classes[0]
|
scheduler_class = self.scheduler_classes[0]
|
||||||
|
@ -2018,12 +1988,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
scheduler = scheduler_class(**scheduler_config)
|
scheduler = scheduler_class(**scheduler_config)
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
if torch_device == "mps":
|
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
|
@ -2040,17 +2005,8 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
result_sum = torch.sum(torch.abs(sample))
|
result_sum = torch.sum(torch.abs(sample))
|
||||||
result_mean = torch.mean(torch.abs(sample))
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
||||||
if str(torch_device).startswith("cpu"):
|
assert abs(result_sum.item() - 152.3192) < 1e-2
|
||||||
# The following sum varies between 148 and 156 on mps. Why?
|
assert abs(result_mean.item() - 0.1983) < 1e-3
|
||||||
assert abs(result_sum.item() - 152.3192) < 1e-2
|
|
||||||
assert abs(result_mean.item() - 0.1983) < 1e-3
|
|
||||||
elif str(torch_device).startswith("mps"):
|
|
||||||
# Larger tolerance on mps
|
|
||||||
assert abs(result_mean.item() - 0.1983) < 1e-2
|
|
||||||
else:
|
|
||||||
# CUDA
|
|
||||||
assert abs(result_sum.item() - 144.8084) < 1e-2
|
|
||||||
assert abs(result_mean.item() - 0.18855) < 1e-3
|
|
||||||
|
|
||||||
|
|
||||||
class IPNDMSchedulerTest(SchedulerCommonTest):
|
class IPNDMSchedulerTest(SchedulerCommonTest):
|
||||||
|
@ -2745,7 +2701,7 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps)
|
scheduler.set_timesteps(self.num_inference_steps)
|
||||||
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
|
@ -2762,13 +2718,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
result_sum = torch.sum(torch.abs(sample))
|
result_sum = torch.sum(torch.abs(sample))
|
||||||
result_mean = torch.mean(torch.abs(sample))
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
||||||
if torch_device in ["cpu", "mps"]:
|
assert abs(result_sum.item() - 13849.3877) < 1e-2
|
||||||
assert abs(result_sum.item() - 13849.3945) < 1e-2
|
assert abs(result_mean.item() - 18.0331) < 5e-3
|
||||||
assert abs(result_mean.item() - 18.0331) < 5e-3
|
|
||||||
else:
|
|
||||||
# CUDA
|
|
||||||
assert abs(result_sum.item() - 13913.0449) < 1e-2
|
|
||||||
assert abs(result_mean.item() - 18.1159) < 5e-3
|
|
||||||
|
|
||||||
def test_prediction_type(self):
|
def test_prediction_type(self):
|
||||||
for prediction_type in ["epsilon", "v_prediction"]:
|
for prediction_type in ["epsilon", "v_prediction"]:
|
||||||
|
@ -2787,11 +2738,7 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||||
sample = sample.to(torch_device)
|
sample = sample.to(torch_device)
|
||||||
|
|
||||||
if torch_device == "mps":
|
generator = torch.manual_seed(0)
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
for i, t in enumerate(scheduler.timesteps):
|
for i, t in enumerate(scheduler.timesteps):
|
||||||
sample = scheduler.scale_model_input(sample, t)
|
sample = scheduler.scale_model_input(sample, t)
|
||||||
|
@ -2804,13 +2751,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
result_sum = torch.sum(torch.abs(sample))
|
result_sum = torch.sum(torch.abs(sample))
|
||||||
result_mean = torch.mean(torch.abs(sample))
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
||||||
if torch_device in ["cpu", "mps"]:
|
assert abs(result_sum.item() - 328.9970) < 1e-2
|
||||||
assert abs(result_sum.item() - 328.9970) < 1e-2
|
assert abs(result_mean.item() - 0.4284) < 1e-3
|
||||||
assert abs(result_mean.item() - 0.4284) < 1e-3
|
|
||||||
else:
|
|
||||||
# CUDA
|
|
||||||
assert abs(result_sum.item() - 327.8027) < 1e-2
|
|
||||||
assert abs(result_mean.item() - 0.4268) < 1e-3
|
|
||||||
|
|
||||||
def test_full_loop_device(self):
|
def test_full_loop_device(self):
|
||||||
if torch_device == "mps":
|
if torch_device == "mps":
|
||||||
|
@ -2820,12 +2762,7 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
scheduler = scheduler_class(**scheduler_config)
|
scheduler = scheduler_class(**scheduler_config)
|
||||||
|
|
||||||
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
if torch_device == "mps":
|
|
||||||
# device type MPS is not supported for torch.Generator() api.
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
|
||||||
|
|
||||||
model = self.dummy_model()
|
model = self.dummy_model()
|
||||||
sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
|
sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
|
||||||
|
@ -2841,13 +2778,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||||
result_sum = torch.sum(torch.abs(sample))
|
result_sum = torch.sum(torch.abs(sample))
|
||||||
result_mean = torch.mean(torch.abs(sample))
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
||||||
if str(torch_device).startswith("cpu"):
|
assert abs(result_sum.item() - 13849.3818) < 1e-1
|
||||||
assert abs(result_sum.item() - 13849.3945) < 1e-2
|
assert abs(result_mean.item() - 18.0331) < 1e-3
|
||||||
assert abs(result_mean.item() - 18.0331) < 5e-3
|
|
||||||
else:
|
|
||||||
# CUDA
|
|
||||||
assert abs(result_sum.item() - 13913.0332) < 1e-1
|
|
||||||
assert abs(result_mean.item() - 18.1159) < 1e-3
|
|
||||||
|
|
||||||
|
|
||||||
# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration.
|
# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration.
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
def overwrite_file(file, class_name, test_name, correct_line, done_test):
|
||||||
|
done_test[file] += 1
|
||||||
|
|
||||||
|
with open(file, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
class_regex = f"class {class_name}("
|
||||||
|
test_regex = f"{4 * ' '}def {test_name}("
|
||||||
|
line_begin_regex = f"{8 * ' '}{correct_line.split()[0]}"
|
||||||
|
another_line_begin_regex = f"{16 * ' '}{correct_line.split()[0]}"
|
||||||
|
in_class = False
|
||||||
|
in_func = False
|
||||||
|
in_line = False
|
||||||
|
insert_line = False
|
||||||
|
count = 0
|
||||||
|
spaces = 0
|
||||||
|
|
||||||
|
new_lines = []
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith(class_regex):
|
||||||
|
in_class = True
|
||||||
|
elif in_class and line.startswith(test_regex):
|
||||||
|
in_func = True
|
||||||
|
elif in_class and in_func and (line.startswith(line_begin_regex) or line.startswith(another_line_begin_regex)):
|
||||||
|
spaces = len(line.split(correct_line.split()[0])[0])
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if count == done_test[file]:
|
||||||
|
in_line = True
|
||||||
|
|
||||||
|
if in_class and in_func and in_line:
|
||||||
|
if ")" not in line:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
insert_line = True
|
||||||
|
|
||||||
|
if in_class and in_func and in_line and insert_line:
|
||||||
|
new_lines.append(f"{spaces * ' '}{correct_line}")
|
||||||
|
in_class = in_func = in_line = insert_line = False
|
||||||
|
else:
|
||||||
|
new_lines.append(line)
|
||||||
|
|
||||||
|
with open(file, "w") as f:
|
||||||
|
for line in new_lines:
|
||||||
|
f.write(line)
|
||||||
|
|
||||||
|
|
||||||
|
def main(correct, fail=None):
|
||||||
|
if fail is not None:
|
||||||
|
with open(fail, "r") as f:
|
||||||
|
test_failures = set([l.strip() for l in f.readlines()])
|
||||||
|
else:
|
||||||
|
test_failures = None
|
||||||
|
|
||||||
|
with open(correct, "r") as f:
|
||||||
|
correct_lines = f.readlines()
|
||||||
|
|
||||||
|
done_tests = defaultdict(int)
|
||||||
|
for line in correct_lines:
|
||||||
|
file, class_name, test_name, correct_line = line.split(";")
|
||||||
|
if test_failures is None or "::".join([file, class_name, test_name]) in test_failures:
|
||||||
|
overwrite_file(file, class_name, test_name, correct_line, done_tests)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--correct_filename", help="filename of tests with expected result")
|
||||||
|
parser.add_argument("--fail_filename", help="filename of test failures", type=str, default=None)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args.correct_filename, args.fail_filename)
|
Loading…
Reference in New Issue