Nightly integration tests (#1664)

* [WIP] Nightly integration tests

* initial SD tests

* update SD slow tests

* style

* repaint

* ImageVariations

* style

* finish imgvar

* img2img tests

* debug

* inpaint 1.5

* inpaint legacy

* torch isn't happy about deterministic ops

* allclose -> max diff for shorter logs

* add SD2

* debug

* Update tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/pipelines/stable_diffusion/test_stable_diffusion.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix refs

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* fix refs

* remove debug

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

Anton Lozhkov 2022-12-16 18:51:11 +01:00 committed by GitHub
parent acd317810b
commit 086c7f9ea8
18 changed files with 1333 additions and 941 deletions
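
In short: GPU integration tests are split into a @slow tier (short 3-step runs asserting on a 3x3 image slice) and an @nightly tier (full runs compared against reference arrays, gated by the RUN_NIGHTLY flag). A minimal sketch of the nightly pattern, assuming the decorators introduced below; the class name and reference URL are illustrative, not taken from the diff:

import unittest

import numpy as np
import torch

from diffusers import StableDiffusionPipeline
from diffusers.utils import load_numpy
from diffusers.utils.testing_utils import nightly, require_torch_gpu, torch_device


@nightly            # skipped unless RUN_NIGHTLY is set to a truthy value
@require_torch_gpu  # skipped when no CUDA device is available
class ExamplePipelineNightlyTests(unittest.TestCase):
    def test_full_run_matches_reference(self):
        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
        # latents are pinned with NumPy so every run starts from identical noise
        latents = np.random.RandomState(0).standard_normal((1, 4, 64, 64))
        latents = torch.from_numpy(latents).to(device=torch_device, dtype=torch.float32)
        image = sd_pipe(
            "a photograph of an astronaut riding a horse",
            latents=latents,
            generator=torch.Generator(device=torch_device).manual_seed(0),
            num_inference_steps=50,
            guidance_scale=7.5,
            output_type="numpy",
        ).images[0]
        expected_image = load_numpy("https://example.com/reference.npy")  # illustrative URL
        # max diff instead of allclose keeps failure logs short (see commit message)
        assert np.abs(expected_image - image).max() < 1e-3

Locally, RUN_NIGHTLY=yes python -m pytest tests/ picks these cases up; without the flag they are reported as skipped.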


@@ -1,4 +1,4 @@
name: Nightly integration tests
name: Nightly tests on main
on:
schedule:
@@ -9,12 +9,108 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 1000
PYTEST_TIMEOUT: 600
RUN_SLOW: yes
RUN_NIGHTLY: yes
jobs:
run_slow_tests_apple_m1:
name: Slow PyTorch MPS tests on MacOS
run_nightly_tests:
strategy:
fail-fast: false
matrix:
config:
- name: Nightly PyTorch CUDA tests on Ubuntu
framework: pytorch
runner: docker-gpu
image: diffusers/diffusers-pytorch-cuda
report: torch_cuda
- name: Nightly Flax TPU tests on Ubuntu
framework: flax
runner: docker-tpu
image: diffusers/diffusers-flax-tpu
report: flax_tpu
- name: Nightly ONNXRuntime CUDA tests on Ubuntu
framework: onnxruntime
runner: docker-gpu
image: diffusers/diffusers-onnxruntime-cuda
report: onnx_cuda
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
if: ${{ matrix.config.runner == 'docker-gpu' }}
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
python utils/print_env.py
- name: Run nightly PyTorch CUDA tests
if: ${{ matrix.config.framework == 'pytorch' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run nightly Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run nightly ONNXRuntime CUDA tests
if: ${{ matrix.config.framework == 'onnxruntime' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: ${{ matrix.config.report }}_test_reports
path: reports
run_nightly_tests_apple_m1:
name: Nightly PyTorch MPS tests on MacOS
runs-on: [ self-hosted, apple-m1 ]
steps:
@@ -46,7 +142,7 @@ jobs:
run: |
${CONDA_RUN} python utils/print_env.py
- name: Run slow PyTorch tests on M1 (MPS)
- name: Run nightly PyTorch tests on M1 (MPS)
shell: arch -arch arm64 bash {0}
env:
HF_HOME: /System/Volumes/Data/mnt/cache


@@ -1,4 +1,4 @@
name: Run fast tests
name: Fast tests for PRs
on:
pull_request:


@@ -1,4 +1,4 @@
name: Run all tests
name: Slow tests on main
on:
push:
@@ -10,7 +10,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 1000
PYTEST_TIMEOUT: 600
RUN_SLOW: yes
jobs:


@@ -55,6 +55,7 @@ if is_torch_available():
load_hf_numpy,
load_image,
load_numpy,
nightly,
parse_flag_from_env,
require_torch_gpu,
slow,


@@ -83,6 +83,7 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -111,6 +112,16 @@ def slow(test_case):
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def nightly(test_case):
"""
Decorator marking a test that runs nightly in the diffusers CI.
Nightly tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
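A rough sketch of what parse_flag_from_env plus this decorator amount to at collection time (the test name is hypothetical, and this is an approximation — the real helper accepts several truthy spellings):

import os
import unittest

# approximation of parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_nightly = os.environ.get("RUN_NIGHTLY", "").lower() in ("1", "true", "yes")


@unittest.skipUnless(_run_nightly, "test is nightly")
class ExampleNightlyOnly(unittest.TestCase):
    def test_heavy_case(self):
        self.assertTrue(True)  # placeholder body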
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.


@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import load_image, load_numpy, nightly, require_torch_gpu, torch_device
from ...test_pipelines_common import PipelineTesterMixin
@@ -83,9 +84,14 @@ class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow
@nightly
@require_torch_gpu
class RepaintPipelineIntegrationTests(unittest.TestCase):
class RepaintPipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_celebahq(self):
original_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
@@ -104,6 +110,8 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
scheduler = RePaintScheduler.from_pretrained(model_id)
repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)
repaint.set_progress_bar_config(disable=None)
repaint.enable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = repaint(


@@ -27,7 +27,7 @@ from diffusers import (
OnnxStableDiffusionPipeline,
PNDMScheduler,
)
from diffusers.utils.testing_utils import is_onnx_available, require_onnxruntime, require_torch_gpu, slow
from diffusers.utils.testing_utils import is_onnx_available, nightly, require_onnxruntime, require_torch_gpu
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -128,7 +128,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@nightly
@require_onnxruntime
@require_torch_gpu
class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):


@@ -27,7 +27,13 @@ from diffusers import (
PNDMScheduler,
)
from diffusers.utils import floats_tensor
from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow
from diffusers.utils.testing_utils import (
is_onnx_available,
load_image,
nightly,
require_onnxruntime,
require_torch_gpu,
)
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -134,7 +140,7 @@ class OnnxStableDiffusionImg2ImgPipelineFastTests(OnnxPipelineTesterMixin, unitt
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
@slow
@nightly
@require_onnxruntime
@require_torch_gpu
class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):


@@ -18,7 +18,13 @@ import unittest
import numpy as np
from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionInpaintPipeline
from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow
from diffusers.utils.testing_utils import (
is_onnx_available,
load_image,
nightly,
require_onnxruntime,
require_torch_gpu,
)
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -32,7 +38,7 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
pass
@slow
@nightly
@require_onnxruntime
@require_torch_gpu
class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):


@@ -22,9 +22,9 @@ from diffusers.utils.testing_utils import (
is_onnx_available,
load_image,
load_numpy,
nightly,
require_onnxruntime,
require_torch_gpu,
slow,
)
@@ -32,7 +32,7 @@ if is_onnx_available():
import onnxruntime as ort
@slow
@nightly
@require_onnxruntime
@require_torch_gpu
class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase):


@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import time
@@ -24,6 +25,7 @@ import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
@@ -32,7 +34,7 @@ from diffusers import (
UNet2DConditionModel,
logging,
)
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -435,127 +437,137 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@slow
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
class StableDiffusionPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "a photograph of an astronaut riding a horse",
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
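The latents here are sampled with NumPy on the CPU rather than with the device generator, presumably because np.random.RandomState is reproducible across machines and torch versions, so every run denoises from byte-identical starting noise. A quick self-contained check of that property:

import numpy as np
import torch

# two independently seeded RandomStates yield identical latents on any machine,
# a guarantee torch.Generator on CUDA does not make across devices or driver versions
a = np.random.RandomState(0).standard_normal((1, 4, 64, 64))
b = np.random.RandomState(0).standard_normal((1, 4, 64, 64))
assert (a == b).all()
latents = torch.from_numpy(a).to(dtype=torch.float32)  # then moved to the test device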
def test_stable_diffusion_1_1_pndm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = sd_pipe(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
expected_slice = np.array([0.43625, 0.43554, 0.36670, 0.40660, 0.39703, 0.38658, 0.43936, 0.43557, 0.40592])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_fast_ddim(self):
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
def test_stable_diffusion_1_4_pndm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
image = output.images
image_slice = image[0, -3:, -3:, -1]
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9326, 0.923, 0.951, 0.9365, 0.9214, 0.951, 0.9365, 0.9414, 0.918])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
expected_slice = np.array([0.57400, 0.47841, 0.31625, 0.63583, 0.58306, 0.55056, 0.50825, 0.56306, 0.55748])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe.scheduler = scheduler
def test_stable_diffusion_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "a photograph of an astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
).images
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_memory_chunking(self):
def test_stable_diffusion_lms(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.10542, 0.09620, 0.07332, 0.09015, 0.09382, 0.07597, 0.08496, 0.07806, 0.06455])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_dpm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.03503, 0.03494, 0.01087, 0.03128, 0.02552, 0.00803, 0.00742, 0.00372, 0.00000])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_attention_slicing(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "a photograph of an astronaut riding a horse"
# make attention efficient
# enable attention slicing
pipe.enable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output_chunked = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image_chunked = output_chunked.images
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image_sliced = pipe(**inputs).images
mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# make sure that less than 3.75 GB is allocated
assert mem_bytes < 3.75 * 10**9
# disable chunking
# disable slicing
pipe.disable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image = output.images
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image = pipe(**inputs).images
# make sure that more than 3.75 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 3.75 * 10**9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
assert np.abs(image_sliced - image).max() < 1e-3
def test_stable_diffusion_vae_slicing(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "a photograph of an astronaut riding a horse"
# enable vae slicing
pipe.enable_vae_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output_chunked = pipe(
[prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image_chunked = output_chunked.images
inputs = self.get_inputs(torch_device, dtype=torch.float16)
inputs["prompt"] = [inputs["prompt"]] * 4
inputs["latents"] = torch.cat([inputs["latents"]] * 4)
image_sliced = pipe(**inputs).images
mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
@@ -564,92 +576,58 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
# disable vae slicing
pipe.disable_vae_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image = output.images
inputs = self.get_inputs(torch_device, dtype=torch.float16)
inputs["prompt"] = [inputs["prompt"]] * 4
inputs["latents"] = torch.cat([inputs["latents"]] * 4)
image = pipe(**inputs).images
# make sure that more than 4 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 4e9
# There is a small discrepancy at the image borders vs. a fully batched version.
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3
assert np.abs(image_sliced - image).max() < 4e-3
def test_stable_diffusion_text2img_pipeline_fp16(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
def test_stable_diffusion_fp16_vs_autocast(self):
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "a photograph of an astronaut riding a horse"
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image_fp16 = pipe(**inputs).images
generator = torch.Generator(device=torch_device).manual_seed(0)
output_chunked = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image_chunked = output_chunked.images
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
output = pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
)
image = output.images
inputs = self.get_inputs(torch_device)
image_autocast = pipe(**inputs).images
# Make sure results are close enough
diff = np.abs(image_chunked.flatten() - image.flatten())
diff = np.abs(image_fp16.flatten() - image_autocast.flatten())
# They ARE different since ops are not run always at the same precision
# however, they should be extremely close.
assert diff.mean() < 2e-2
def test_stable_diffusion_text2img_pipeline_default(self):
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text2img/astronaut_riding_a_horse.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 5e-3
def test_stable_diffusion_text2img_intermediate_state(self):
def test_stable_diffusion_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
test_callback_fn.has_been_called = True
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
if step == 1:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
)
expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
elif step == 50:
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601]
)
expected_slice = np.array([-0.1885, -0.3022, -1.012, -0.514, -0.477, 0.6143, -0.9336, 0.6553, 1.453])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
test_callback_fn.has_been_called = False
callback_fn.has_been_called = False
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
@@ -658,20 +636,10 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Andromeda galaxy in a bottle"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
prompt=prompt,
num_inference_steps=50,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 50
inputs = self.get_inputs(torch_device, dtype=torch.float16)
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == inputs["num_inference_steps"]
def test_stable_diffusion_low_cpu_mem_usage(self):
pipeline_id = "CompVis/stable-diffusion-v1-4"
@@ -685,7 +653,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
start_time = time.time()
_ = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
pipeline_id, revision="fp16", torch_dtype=torch.float16, low_cpu_mem_usage=False
)
normal_load_time = time.time() - start_time
@@ -696,17 +664,129 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipeline_id = "CompVis/stable-diffusion-v1-4"
prompt = "Andromeda galaxy in a bottle"
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
pipeline = pipeline.to(torch_device)
pipeline.enable_attention_slicing(1)
pipeline.enable_sequential_cpu_offload()
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipeline(prompt, generator=generator, num_inference_steps=5)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.8 GB is allocated
assert mem_bytes < 2.8 * 10**9
@nightly
@require_torch_gpu
class StableDiffusionPipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "a photograph of an astronaut riding a horse",
"latents": latents,
"generator": generator,
"num_inference_steps": 50,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
def test_stable_diffusion_1_4_pndm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_text2img/stable_diffusion_1_4_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_1_5_pndm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_text2img/stable_diffusion_1_5_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_text2img/stable_diffusion_1_4_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_lms(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_text2img/stable_diffusion_1_4_lms.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_euler(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_text2img/stable_diffusion_1_4_euler.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_dpm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_text2img/stable_diffusion_1_4_dpm_multi.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3


@@ -22,12 +22,12 @@ import torch
from diffusers import (
AutoencoderKL,
LMSDiscreteScheduler,
DPMSolverMultistepScheduler,
PNDMScheduler,
StableDiffusionImageVariationPipeline,
UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModelWithProjection
@@ -177,123 +177,157 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
@slow
@require_torch_gpu
class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_diffusion_img_variation_pipeline_default(self):
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.jpg"
)
init_image = init_image.resize((512, 512))
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.npy"
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_imgvar/input_image_vermeer.png"
)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"image": init_image,
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
model_id = "fusing/sd-image-variations-diffusers"
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
model_id,
safety_checker=None,
def test_stable_diffusion_img_variation_pipeline_default(self):
sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
"lambdalabs/sd-image-variations-diffusers", safety_checker=None
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
init_image,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (512, 512, 3)
# img2img is flaky across GPUs even in fp32, so using MAE here
assert np.abs(expected_image - image).max() < 1e-3
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.84491, 0.90789, 0.75708, 0.78734, 0.83485, 0.70099, 0.66938, 0.68727, 0.61379])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_img_variation_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
test_callback_fn.has_been_called = True
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
if step == 1:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953])
expected_slice = np.array([-0.1572, 0.2837, -0.798, -0.1201, -1.304, 0.7754, -2.12, 0.0443, 1.627])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
elif step == 37:
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([2.285, 2.703, 1.969, 0.696, -1.323, 0.9253, -0.5464, -1.521, -2.537])
expected_slice = np.array([0.6143, 1.734, 1.158, -2.145, -1.926, 0.748, -0.7246, 0.994, 1.539])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
test_callback_fn.has_been_called = False
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((512, 512))
callback_fn.has_been_called = False
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
"fusing/sd-image-variations-diffusers",
safety_checker=None,
torch_dtype=torch.float16,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
init_image,
num_inference_steps=50,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 50
inputs = self.get_inputs(torch_device, dtype=torch.float16)
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == inputs["num_inference_steps"]
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((512, 512))
model_id = "fusing/sd-image-variations-diffusers"
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
model_id, scheduler=lms, safety_checker=None, torch_dtype=torch.float16
model_id, safety_checker=None, torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
init_image,
guidance_scale=7.5,
generator=generator,
output_type="np",
num_inference_steps=5,
)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.6 GB is allocated
assert mem_bytes < 2.6 * 10**9
@nightly
@require_torch_gpu
class StableDiffusionImageVariationPipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_imgvar/input_image_vermeer.png"
)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"image": init_image,
"latents": latents,
"generator": generator,
"num_inference_steps": 50,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
def test_img_variation_pndm(self):
sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained("fusing/sd-image-variations-diffusers")
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_imgvar/lambdalabs_variations_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_img_variation_dpm(self):
sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained("fusing/sd-image-variations-diffusers")
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_imgvar/lambdalabs_variations_dpm_multi.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3


@@ -23,12 +23,13 @@ import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -212,211 +213,213 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
@slow
@require_torch_gpu
class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_diffusion_img2img_pipeline_default(self):
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.npy"
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
safety_checker=None,
)
def test_stable_diffusion_img2img_default(self):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A fantasy landscape, trending on artstation"
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (1, 512, 768, 3)
expected_slice = np.array([0.27150, 0.14849, 0.15605, 0.26740, 0.16954, 0.18204, 0.31470, 0.26311, 0.24525])
assert np.abs(expected_slice - image_slice).max() < 1e-3
assert image.shape == (512, 768, 3)
# img2img is flaky across GPUs even in fp32, so using MAE here
assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_img2img_pipeline_k_lms(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_k_lms.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=lms,
safety_checker=None,
)
def test_stable_diffusion_img2img_k_lms(self):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A fantasy landscape, trending on artstation"
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (1, 512, 768, 3)
expected_slice = np.array([0.04890, 0.04862, 0.06422, 0.04655, 0.05108, 0.05307, 0.05926, 0.08759, 0.06852])
assert np.abs(expected_slice - image_slice).max() < 1e-3
assert image.shape == (512, 768, 3)
assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_img2img_pipeline_ddim(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_ddim.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"
ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=ddim,
safety_checker=None,
)
def test_stable_diffusion_img2img_ddim(self):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A fantasy landscape, trending on artstation"
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 768, 3)
assert np.abs(expected_image - image).max() < 1e-3
assert image.shape == (1, 512, 768, 3)
expected_slice = np.array([0.06069, 0.05703, 0.08054, 0.05797, 0.06286, 0.06234, 0.08438, 0.11151, 0.08068])
assert np.abs(expected_slice - image_slice).max() < 1e-3
def test_stable_diffusion_img2img_intermediate_state(self):
number_of_steps = 0
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
test_callback_fn.has_been_called = True
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 0:
if step == 1:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 96)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
expected_slice = np.array([0.7705, 0.1045, 0.5, 3.393, 3.723, 4.273, 2.467, 3.486, 1.758])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 37:
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 96)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
expected_slice = np.array([0.765, 0.1047, 0.4973, 3.375, 3.709, 4.258, 2.451, 3.46, 1.755])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
test_callback_fn.has_been_called = False
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
callback_fn.has_been_called = False
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
"CompVis/stable-diffusion-v1-4", safety_checker=None, revision="fp16", torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
prompt=prompt,
image=init_image,
strength=0.75,
num_inference_steps=50,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 37
inputs = self.get_inputs(torch_device, dtype=torch.float16)
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == 2
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
"CompVis/stable-diffusion-v1-4", safety_checker=None, revision="fp16", torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
prompt=prompt,
image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
num_inference_steps=5,
)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
@nightly
@require_torch_gpu
class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"image": init_image,
"generator": generator,
"num_inference_steps": 50,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
def test_img2img_pndm(self):
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/stable_diffusion_1_5_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_img2img_ddim(self):
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/stable_diffusion_1_5_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_img2img_lms(self):
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/stable_diffusion_1_5_lms.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_img2img_dpm(self):
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 30
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/stable_diffusion_1_5_dpm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3


@@ -22,13 +22,14 @@ import torch
from diffusers import (
AutoencoderKL,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -163,210 +164,217 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_diffusion_inpaint_pipeline(self):
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
"/yellow_cat_sitting_on_a_park_bench.npy"
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
)
inputs = {
"prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
"image": init_image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
model_id = "runwayml/stable-diffusion-inpainting"
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_inpaint_pipeline_fp16(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
"/yellow_cat_sitting_on_a_park_bench_fp16.npy"
)
model_id = "runwayml/stable-diffusion-inpainting"
def test_stable_diffusion_inpaint_ddim(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
revision="fp16",
torch_dtype=torch.float16,
safety_checker=None,
"runwayml/stable-diffusion-inpainting", safety_checker=None
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.05978, 0.10983, 0.10514, 0.07922, 0.08483, 0.08587, 0.05302, 0.03218, 0.01636])
assert np.abs(expected_slice - image_slice).max() < 1e-4
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 5e-1
def test_stable_diffusion_inpaint_fp16(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, safety_checker=None
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.06152, 0.11060, 0.10449, 0.07959, 0.08643, 0.08496, 0.05420, 0.03247, 0.01831])
assert np.abs(expected_slice - image_slice).max() < 1e-2
def test_stable_diffusion_inpaint_pndm(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None
)
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.06892, 0.06994, 0.07905, 0.05366, 0.04709, 0.04890, 0.04107, 0.05083, 0.04180])
assert np.abs(expected_slice - image_slice).max() < 1e-4
def test_stable_diffusion_inpaint_k_lms(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None
)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.23513, 0.22413, 0.29442, 0.24243, 0.26214, 0.30329, 0.26431, 0.25025, 0.25197])
assert np.abs(expected_slice - image_slice).max() < 1e-4
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None, revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
generator=generator,
num_inference_steps=5,
output_type="np",
)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
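# Editor's sketch (illustrative, assuming only torch): the offloading and
# slicing tests share one pattern for bounding peak VRAM -- reset the CUDA
# peak counters, run the pipeline once, then read back the high-water mark.
def peak_cuda_memory_bytes(run):
    # `run` is any zero-argument callable that executes the pipeline.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run()
    return torch.cuda.max_memory_allocated()
# usage: assert peak_cuda_memory_bytes(lambda: pipe(**inputs)) < 2.2 * 10**9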
@nightly
@require_torch_gpu
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
)
inputs = {
"prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
"image": init_image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 50,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
def test_inpaint_ddim(self):
sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/stable_diffusion_inpaint_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_inpaint_pndm(self):
sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
sd_pipe.scheduler = PNDMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/stable_diffusion_inpaint_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_inpaint_lms(self):
sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/stable_diffusion_inpaint_lms.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_inpaint_dpm(self):
sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 30
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/stable_diffusion_inpaint_dpm_multi.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
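# Editor's sketch (illustrative; `run_inpaint_with_scheduler` is a hypothetical
# helper, not part of this commit): the four nightly tests above differ only in
# which scheduler class is rebuilt from the current scheduler's config, so they
# all reduce to this shape:
def run_inpaint_with_scheduler(scheduler_cls, get_inputs, **overrides):
    pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
    # from_config carries the timestep configuration over from the default scheduler.
    pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    inputs = {**get_inputs(torch_device), **overrides}
    return pipe(**inputs).images[0]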
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
def test_pil_inputs(self):
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)


@@ -22,15 +22,16 @@ import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
UNet2DConditionModel,
UNet2DModel,
VQModel,
)
from diffusers.utils import floats_tensor, load_image, nightly, slow, torch_device
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -340,146 +341,192 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
@slow
@require_torch_gpu
class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
)
inputs = {
"prompt": "A red cat sitting on a park bench",
"image": init_image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A red cat sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_inpaint_legacy_pipeline_k_lms(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint"
"/red_cat_sitting_on_a_park_bench_k_lms.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
scheduler=lms,
safety_checker=None,
def test_stable_diffusion_inpaint_legacy_pndm(self):
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A red cat sitting on a park bench"
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.27200, 0.29103, 0.34405, 0.21418, 0.26317, 0.34281, 0.18033, 0.24911, 0.32028])
assert np.abs(expected_slice - image_slice).max() < 1e-4
def test_stable_diffusion_inpaint_legacy_k_lms(self):
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None
)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.29014, 0.28882, 0.32835, 0.26502, 0.28182, 0.31162, 0.29297, 0.29534, 0.28214])
assert np.abs(expected_slice - image_slice).max() < 1e-4
def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 1:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([-0.103, 1.415, -0.02197, -0.5107, -0.5903, 0.1953, 0.75, 0.3477, -1.356])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.4802, 1.154, 0.628, 0.2319, 0.2593, -0.1455, 0.7075, -0.1617, -0.5615])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
callback_fn.has_been_called = False
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A red cat sitting on a park bench"
inputs = self.get_inputs(torch_device, dtype=torch.float16)
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == 2
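# Editor's sketch (illustrative, not part of the original tests): the
# intermediate-state checks rely on a closure the pipeline invokes once per
# denoising step; a reusable version of that counting pattern:
def make_step_counter():
    state = {"steps": 0}

    def callback(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        # Called after each scheduler step when callback_steps=1;
        # `latents` is the current noisy sample.
        state["steps"] += 1

    return callback, state
# usage: callback, state = make_step_counter(); pipe(**inputs, callback=callback, callback_steps=1)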
@nightly
@require_torch_gpu
class StableDiffusionInpaintLegacyPipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
)
inputs = {
"prompt": "A red cat sitting on a park bench",
"image": init_image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 50,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
def test_inpaint_pndm(self):
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_inpaint_ddim(self):
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_inpaint_lms(self):
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_lms.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_inpaint_dpm(self):
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 30
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_dpm_multi.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
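# Editor's sketch (illustrative; assumes requests and PIL are available and is
# not necessarily how diffusers implements them): the `load_image` and
# `load_numpy` helpers used throughout resolve a URL to a PIL image or a NumPy
# reference array.
import io

import numpy as np
import requests
from PIL import Image


def load_image_sketch(url):
    # Download and decode the reference image, normalizing to RGB.
    return Image.open(io.BytesIO(requests.get(url, timeout=30).content)).convert("RGB")


def load_numpy_sketch(url):
    # Download a .npy reference array produced by an earlier blessed run.
    return np.load(io.BytesIO(requests.get(url, timeout=30).content))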


@@ -22,6 +22,7 @@ import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
@@ -30,7 +31,7 @@ from diffusers import (
UNet2DConditionModel,
logging,
)
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -239,170 +240,116 @@ class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@slow
@require_torch_gpu
class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "a photograph of an astronaut riding a horse",
"latents": latents,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
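# Editor's note (illustrative sketch, relying on the file's torch/numpy
# imports): seeding NumPy for the initial latents and only then seeding a
# torch.Generator decouples the starting noise from device-specific torch RNG
# streams, so every scheduler variant below starts from identical noise.
def make_fixed_latents(seed, device, dtype=torch.float32):
    latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
    return torch.from_numpy(latents).to(device=device, dtype=dtype)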
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")
def test_stable_diffusion_default_ddim(self):
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_pndm(self):
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")
image = output.images
image_slice = image[0, 253:256, 253:256, -1]
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_k_lms(self):
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "a photograph of an astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = sd_pipe(
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy"
).images
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.10440, 0.13115, 0.11100, 0.10141, 0.11440, 0.07215, 0.11332, 0.09693, 0.10006])
assert np.abs(image_slice - expected_slice).max() < 1e-4
def test_stable_diffusion_attention_slicing(self):
torch.cuda.reset_peak_memory_stats()
model_id = "stabilityai/stable-diffusion-2-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-base", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "a photograph of an astronaut riding a horse"
# make attention efficient
# enable attention slicing
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image_sliced = pipe(**inputs).images
mem_bytes = torch.cuda.max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# make sure that less than 3.3 GB is allocated
assert mem_bytes < 3.3 * 10**9
# disable slicing
pipe.disable_attention_slicing()
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image = pipe(**inputs).images
# make sure that more than 3.3 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 3.3 * 10**9
assert np.abs(image_sliced - image).max() < 1e-3
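# Editor's sketch (illustrative): attention slicing trades peak VRAM for extra
# kernel launches, so sliced and unsliced outputs should agree to within
# numerical noise. `make_inputs` must return a fresh dict per call -- a shared
# torch.Generator would otherwise advance between the two runs.
def compare_sliced_vs_full(pipe, make_inputs):
    pipe.enable_attention_slicing()
    sliced = pipe(**make_inputs()).images
    pipe.disable_attention_slicing()
    full = pipe(**make_inputs()).images
    return np.abs(sliced - full).max()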
def test_stable_diffusion_text2img_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 1:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([-0.3857, -0.4507, -1.167, 0.074, -1.108, 0.7183, -1.822, 0.1915, 1.283])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.268, -0.2095, -0.7744, -0.541, -0.79, 0.3926, -0.7754, 0.465, 1.291])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
callback_fn.has_been_called = False
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-base", revision="fp16", torch_dtype=torch.float16
@@ -411,37 +358,139 @@ class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "Andromeda galaxy in a bottle"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast(torch_device):
pipe(
prompt=prompt,
num_inference_steps=20,
guidance_scale=7.5,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 20
inputs = self.get_inputs(torch_device, dtype=torch.float16)
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == inputs["num_inference_steps"]
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipeline_id = "stabilityai/stable-diffusion-2-base"
prompt = "Andromeda galaxy in a bottle"
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-base", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.8 GB is allocated
assert mem_bytes < 2.8 * 10**9
@nightly
@require_torch_gpu
class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
"prompt": "a photograph of an astronaut riding a horse",
"latents": latents,
"generator": generator,
"num_inference_steps": 50,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
def test_stable_diffusion_2_0_default_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base").to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_2_text2img/stable_diffusion_2_0_base_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_2_1_default_pndm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_2_text2img/stable_diffusion_2_1_base_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device)
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_2_text2img/stable_diffusion_2_1_base_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_lms(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device)
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_2_text2img/stable_diffusion_2_1_base_lms.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_euler(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device)
sd_pipe.scheduler = EulerDiscreteScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_2_text2img/stable_diffusion_2_1_base_euler.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_stable_diffusion_dpm(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device)
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 25
image = sd_pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_2_text2img/stable_diffusion_2_1_base_dpm_multi.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
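# Editor's sketch (an assumption about the marker's shape, not its actual
# implementation): the `nightly` decorator imported from diffusers.utils gates
# these classes on the RUN_NIGHTLY environment variable exported by the
# scheduled workflow, analogous to `slow` and RUN_SLOW. A minimal stand-in:
import os
import unittest


def nightly_sketch(test_case):
    # Skip unless the nightly CI job exported RUN_NIGHTLY=yes (or similar).
    run_nightly = os.getenv("RUN_NIGHTLY", "").lower() in ("1", "yes", "true")
    return unittest.skipUnless(run_nightly, "test is nightly")(test_case)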


@@ -24,12 +24,13 @@ import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionDepth2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image
@@ -49,7 +50,7 @@ torch.backends.cuda.matmul.allow_tf32 = False
@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionDepth2ImgPipeline
test_save_load_optional_components = False
@@ -275,12 +276,12 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
def test_stable_diffusion_depth2img_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = StableDiffusionDepth2ImgPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
@@ -293,13 +294,13 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
def test_stable_diffusion_depth2img_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = StableDiffusionDepth2ImgPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "french fries"
output = pipe(**inputs, negative_prompt=negative_prompt)
image = output.images
image_slice = image[0, -3:, -3:, -1]
@@ -313,14 +314,14 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
def test_stable_diffusion_depth2img_multiple_init_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = StableDiffusionDepth2ImgPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * 2
inputs["image"] = 2 * [inputs["image"]]
image = pipe(**inputs).images
image_slice = image[-1, -3:, -3:, -1]
assert image.shape == (2, 32, 32, 3)
@@ -334,13 +335,13 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
def test_stable_diffusion_depth2img_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = StableDiffusionDepth2ImgPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = pipe(**inputs).images
assert images.shape == (1, 32, 32, 3)
@@ -348,14 +349,14 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = pipe(**inputs).images
assert images.shape == (batch_size, 32, 32, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 32, 32, 3)
@@ -363,20 +364,20 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
def test_stable_diffusion_depth2img_pil(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = StableDiffusionDepth2ImgPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
if torch_device == "mps":
@@ -388,186 +389,217 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
@slow
@require_torch_gpu
class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png"
)
inputs = {
"prompt": "two tigers",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
model_id = "stabilityai/stable-diffusion-2-depth"
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(model_id)
def test_stable_diffusion_depth2img_pipeline_default(self):
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth", safety_checker=None
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "two tigers"
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 480, 640, 3)
expected_slice = np.array([0.75446, 0.74692, 0.75951, 0.81611, 0.80593, 0.79992, 0.90529, 0.87921, 0.86903])
assert np.abs(expected_slice - image_slice).max() < 1e-4
def test_stable_diffusion_depth2img_pipeline_k_lms(self):
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth", safety_checker=None
)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 480, 640, 3)
expected_slice = np.array([0.63957, 0.64879, 0.65668, 0.64385, 0.67078, 0.63588, 0.66577, 0.62180, 0.66286])
assert np.abs(expected_slice - image_slice).max() < 1e-4
def test_stable_diffusion_depth2img_pipeline_ddim(self):
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth", safety_checker=None
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 480, 640, 3)
expected_slice = np.array([0.62840, 0.64191, 0.62953, 0.63653, 0.64205, 0.61574, 0.62252, 0.65827, 0.64809])
assert np.abs(expected_slice - image_slice).max() < 1e-4
def test_stable_diffusion_depth2img_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 1:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 60, 80)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([-1.148, -0.2079, -0.622, -2.477, -2.348, 0.3828, -2.055, -1.569, -1.526])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 60, 80)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([-1.145, -0.2063, -0.6216, -2.469, -2.344, 0.3794, -2.05, -1.57, -1.521])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
callback_fn.has_been_called = False
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth", safety_checker=None, revision="fp16", torch_dtype=torch.float16
)
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-depth")
pipe.to(torch_device)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "two tigers"
generator = torch.Generator(device=torch_device).manual_seed(0)
pipe(
prompt=prompt,
image=init_image,
strength=0.75,
num_inference_steps=50,
generator=generator,
callback=test_callback_fn,
callback_steps=1,
)
assert test_callback_fn.has_been_called
assert number_of_steps == 37
inputs = self.get_inputs(torch_device, dtype=torch.float16)
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == 2
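# Editor's note (illustrative sketch, assuming the standard img2img-style
# timestep arithmetic): with strength=0.75 only the last
# int(num_inference_steps * strength) denoising steps actually run, which is
# why 3 requested steps produce exactly the 2 callback invocations asserted above.
def expected_denoising_steps(num_inference_steps, strength):
    return min(int(num_inference_steps * strength), num_inference_steps)
# expected_denoising_steps(3, 0.75) == 2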
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth", safety_checker=None, revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = pipe(
prompt=prompt,
image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
num_inference_steps=2,
)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.9 GB is allocated
assert mem_bytes < 2.9 * 10**9
@nightly
@require_torch_gpu
class StableDiffusionDepth2ImgPipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png"
)
inputs = {
"prompt": "two tigers",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "numpy",
}
return inputs
def test_depth2img_pndm(self):
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-depth")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_depth2img/stable_diffusion_2_0_pndm.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_depth2img_ddim(self):
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-depth")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_depth2img/stable_diffusion_2_0_ddim.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_depth2img_lms(self):
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-depth")
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_depth2img/stable_diffusion_2_0_lms.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
def test_depth2img_dpm(self):
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-depth")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 30
image = pipe(**inputs).images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_depth2img/stable_diffusion_2_0_dpm_multi.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
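# Editor's sketch (illustrative): the nightly tests compare full reference
# arrays with a max-abs-diff bound instead of np.allclose, which reduces a
# failure log to a single number. As a tiny helper:
def max_pixel_diff(image, expected):
    # Both arguments are HxWxC float images in [0, 1].
    return float(np.abs(expected - image).max())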


@@ -47,7 +47,7 @@ from diffusers import (
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized
from PIL import Image
@@ -674,6 +674,7 @@ class PipelineFastTests(unittest.TestCase):
@slow
@require_torch_gpu
class PipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
@@ -816,6 +817,16 @@ class PipelineSlowTests(unittest.TestCase):
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
@nightly
@require_torch_gpu
class PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_ddpm_ddim_equality_batched(self):
seed = 0
model_id = "google/ddpm-cifar10-32"