default fast model loading 🔥 (#1115)

* make accelerate hard dep

* default fast init

* move params to cpu when device map is None

* handle device_map=None

* handle torch < 1.9

* remove device_map="auto"

* style

* add accelerate in torch extra

* remove accelerate from extras["test"]

* raise an error if torch is available but not accelerate

* update installation docs

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* improve default loading speed even further, allow disabling fast loading

* address review comments

* adapt the tests

* fix test_stable_diffusion_fast_load

* fix test_read_init

* temp fix for dummy checks

* Trigger Build

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Suraj Patil 2022-11-03 17:25:57 +01:00 committed by GitHub
parent ef2ea33c3b
commit 7482178162
23 changed files with 564 additions and 109 deletions

View File

@@ -27,10 +27,12 @@ More precisely, 🤗 Diffusers offers:
## Installation ## Installation
### For PyTorch
**With `pip`** **With `pip`**
```bash ```bash
pip install --upgrade diffusers pip install --upgrade diffusers[torch]
``` ```
**With `conda`** **With `conda`**
@@ -39,6 +41,14 @@ pip install --upgrade diffusers
conda install -c conda-forge diffusers conda install -c conda-forge diffusers
``` ```
### For Flax
**With `pip`**
```bash
pip install --upgrade diffusers[flax]
```
**Apple Silicon (M1/M2) support** **Apple Silicon (M1/M2) support**
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps). Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
@@ -354,7 +364,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
If you want to run the code yourself 💻, you can try out: If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256) - [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python ```python
# !pip install diffusers transformers # !pip install diffusers["torch"] transformers
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
device = "cuda" device = "cuda"
@@ -373,7 +383,7 @@ image.save("squirrel.png")
``` ```
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256) - [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python ```python
# !pip install diffusers # !pip install diffusers["torch"]
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
model_id = "google/ddpm-celebahq-256" model_id = "google/ddpm-celebahq-256"

View File

@@ -12,9 +12,12 @@ specific language governing permissions and limitations under the License.
# Installation # Installation
Install Diffusers for with PyTorch. Support for other libraries will come in the future Install 🤗 Diffusers for whichever deep learning library you're working with.
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+. 🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
## Install with pip ## Install with pip
@@ -36,12 +39,30 @@ source .env/bin/activate
Now you're ready to install 🤗 Diffusers with the following command: Now you're ready to install 🤗 Diffusers with the following command:
**For PyTorch**
```bash ```bash
pip install diffusers pip install diffusers["torch"]
```
**For Flax**
```bash
pip install diffusers["flax"]
``` ```
## Install from source ## Install from source
Before installing `diffusers` from source, make sure you have `torch` and `accelerate` installed.
For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).
To install `accelerate`:
```bash
pip install accelerate
```
Install 🤗 Diffusers from source with the following command: Install 🤗 Diffusers from source with the following command:
```bash ```bash
@@ -67,7 +88,18 @@ Clone the repository and install 🤗 Diffusers with the following commands:
```bash ```bash
git clone https://github.com/huggingface/diffusers.git git clone https://github.com/huggingface/diffusers.git
cd diffusers cd diffusers
pip install -e . ```
**For PyTorch**
```
pip install -e ".[torch]"
```
**For Flax**
```
pip install -e ".[flax]"
``` ```
These commands will link the folder you cloned the repository to and your Python library paths. These commands will link the folder you cloned the repository to and your Python library paths.

View File

@@ -178,7 +178,6 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder") extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list( extras["test"] = deps_list(
"accelerate",
"datasets", "datasets",
"parameterized", "parameterized",
"pytest", "pytest",
@@ -188,7 +187,7 @@ extras["test"] = deps_list(
"torchvision", "torchvision",
"transformers" "transformers"
) )
extras["torch"] = deps_list("torch") extras["torch"] = deps_list("torch", "accelerate")
if os.name == "nt": # windows if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows extras["flax"] = [] # jax is not supported on windows

View File

@@ -1,4 +1,5 @@
from .utils import ( from .utils import (
is_accelerate_available,
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_onnx_available, is_onnx_available,
@@ -16,6 +17,13 @@ from .onnx_utils import OnnxRuntimeModel
from .utils import logging from .utils import logging
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
if is_torch_available() and not is_accelerate_available():
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
raise ImportError(error_msg)
if is_torch_available(): if is_torch_available():
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
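The hard requirement above only needs a cheap import probe to enforce. For reference, a minimal sketch of how an availability check like `is_accelerate_available` is commonly implemented (an assumption — the real utility may also validate the installed version):

```python
import importlib.util

def is_accelerate_available() -> bool:
    # True when the `accelerate` package can be found in the current environment.
    return importlib.util.find_spec("accelerate") is not None
```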

View File

@@ -21,7 +21,9 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor, device from torch import Tensor, device
from diffusers.utils import is_accelerate_available import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
@@ -268,6 +270,19 @@ class ModelMixin(torch.nn.Module):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. Please refer to the mirror site for more information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
fast_load (`bool`, *optional*, defaults to `True`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
this argument will be ignored and the model will be loaded normally.
<Tip> <Tip>
@@ -296,6 +311,16 @@ class ModelMixin(torch.nn.Module):
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)
# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
# Fast init is only possible if torch version is >= 1.9.0
_INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
user_agent = { user_agent = {
"diffusers": __version__, "diffusers": __version__,
@@ -378,12 +403,8 @@ class ModelMixin(torch.nn.Module):
# restore default dtype # restore default dtype
if device_map == "auto": if _INIT_EMPTY_WEIGHTS:
if is_accelerate_available(): # Instantiate model with empty weights
import accelerate
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config( model, unused_kwargs = cls.from_config(
config_path, config_path,
@@ -400,7 +421,17 @@ class ModelMixin(torch.nn.Module):
**kwargs, **kwargs,
) )
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) # if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file)
# move the params from meta device to cpu
for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, param_device, value=param)
else: # let accelerate handle loading and dispatching
# Load weights and dispatch according to the device_map
# (by default the device_map is None and the weights are loaded on the CPU)
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
loading_info = { loading_info = {
"missing_keys": [], "missing_keys": [],

View File

@@ -380,6 +380,7 @@ class DiffusionPipeline(ConfigMixin):
provider = kwargs.pop("provider", None) provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None) sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
@@ -572,6 +573,15 @@ class DiffusionPipeline(ConfigMixin):
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
) )
if is_diffusers_model:
loading_kwargs["fast_load"] = fast_load
# When loading a transformers model, if device_map is None the weights will be fully initialized, unlike in diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag, which is `True` by default.
# This makes sure that the weights won't be initialized, which significantly speeds up loading.
if is_transformers_model and device_map is None:
loading_kwargs["low_cpu_mem_usage"] = fast_load
if is_diffusers_model or is_transformers_model: if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map loading_kwargs["device_map"] = device_map

View File

@@ -0,0 +1,392 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class ModelMixin(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet2DModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class VQModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch", "accelerate"])
def get_constant_schedule_with_warmup(*args, **kwargs):
requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"])
def get_cosine_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"])
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"])
def get_linear_schedule_with_warmup(*args, **kwargs):
requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"])
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"])
def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch", "accelerate"])
class DiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DanceDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDPMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class KarrasVePipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class LDMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EulerDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class IPNDMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
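The autogenerated file above leans entirely on two small utilities imported from `..utils`. A rough sketch of the general pattern (an assumption — the real helpers consult registered availability flags rather than probing imports directly):

```python
import importlib.util

def requires_backends(obj, backends):
    # Name the offending object and every backend that cannot be imported.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the backends: {', '.join(missing)}")

class DummyObject(type):
    # Metaclass for the placeholders: any attribute access on the class itself
    # raises the backend error; instantiation is covered by the generated
    # `__init__` methods, which call `requires_backends` directly.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)
```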

View File

@@ -28,7 +28,7 @@ class UnetModel1DTests(unittest.TestCase):
@slow @slow
def test_unet_1d_maestro(self): def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k" model_id = "harmonai/maestro-150k"
model = UNet1DModel.from_pretrained(model_id, subfolder="unet", device_map="auto") model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
model.to(torch_device) model.to(torch_device)
sample_size = 65536 sample_size = 65536

View File

@@ -125,9 +125,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate(self): def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained( model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model.to(torch_device) model.to(torch_device)
image = model(**self.dummy_input).sample image = model(**self.dummy_input).sample
@@ -135,9 +133,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate_wont_change_results(self): def test_from_pretrained_accelerate_wont_change_results(self):
model_accelerate, _ = UNet2DModel.from_pretrained( # by default model loading will use accelerate as `fast_load=True`
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto" model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
)
model_accelerate.to(torch_device) model_accelerate.to(torch_device)
model_accelerate.eval() model_accelerate.eval()
@@ -159,7 +156,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
gc.collect() gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained( model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto" "fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
) )
model_normal_load.to(torch_device) model_normal_load.to(torch_device)
model_normal_load.eval() model_normal_load.eval()
@@ -173,9 +170,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
gc.collect() gc.collect()
tracemalloc.start() tracemalloc.start()
model_accelerate, _ = UNet2DModel.from_pretrained( # by default model loading will use accelerate as `fast_load=True`
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto" model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
)
model_accelerate.to(torch_device) model_accelerate.to(torch_device)
model_accelerate.eval() model_accelerate.eval()
_, peak_accelerate = tracemalloc.get_traced_memory() _, peak_accelerate = tracemalloc.get_traced_memory()
@@ -184,7 +180,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True) model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
)
model_normal_load.to(torch_device) model_normal_load.to(torch_device)
model_normal_load.eval() model_normal_load.eval()
_, peak_normal = tracemalloc.get_traced_memory() _, peak_normal = tracemalloc.get_traced_memory()
@@ -348,9 +346,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@slow @slow
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained( model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
"google/ncsnpp-celebahq-256", output_loading_info=True, device_map="auto"
)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -364,7 +360,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@slow @slow
def test_output_pretrained_ve_mid(self): def test_output_pretrained_ve_mid(self):
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto") model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
model.to(torch_device) model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
@@ -439,7 +435,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
torch_dtype = torch.float16 if fp16 else torch.float32 torch_dtype = torch.float16 if fp16 else torch.float32
model = UNet2DConditionModel.from_pretrained( model = UNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision, device_map="auto" model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
) )
model.to(torch_device).eval() model.to(torch_device).eval()

View File

@@ -155,7 +155,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
torch_dtype = torch.float16 if fp16 else torch.float32 torch_dtype = torch.float16 if fp16 else torch.float32
model = AutoencoderKL.from_pretrained( model = AutoencoderKL.from_pretrained(
model_id, subfolder="vae", torch_dtype=torch_dtype, revision=revision, device_map="auto" model_id,
subfolder="vae",
torch_dtype=torch_dtype,
revision=revision,
) )
model.to(torch_device).eval() model.to(torch_device).eval()

View File

@@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
def test_dance_diffusion(self): def test_dance_diffusion(self):
device = torch_device device = torch_device
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto") pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
pipe = pipe.to(device) pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -103,9 +103,7 @@ class PipelineIntegrationTests(unittest.TestCase):
def test_dance_diffusion_fp16(self): def test_dance_diffusion_fp16(self):
device = torch_device device = torch_device
pipe = DanceDiffusionPipeline.from_pretrained( pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
)
pipe = pipe.to(device) pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)

View File

@@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def test_inference_ema_bedroom(self): def test_inference_ema_bedroom(self):
model_id = "google/ddpm-ema-bedroom-256" model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto") unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler.from_config(model_id) scheduler = DDIMScheduler.from_config(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self): def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto") unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler() scheduler = DDIMScheduler()
ddim = DDIMPipeline(unet=unet, scheduler=scheduler) ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

View File

@@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self): def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto") unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id) scheduler = DDPMScheduler.from_config(model_id)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

View File

@@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class KarrasVePipelineIntegrationTests(unittest.TestCase): class KarrasVePipelineIntegrationTests(unittest.TestCase):
def test_inference(self): def test_inference(self):
model_id = "google/ncsnpp-celebahq-256" model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id, device_map="auto") model = UNet2DModel.from_pretrained(model_id)
scheduler = KarrasVeScheduler() scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=model, scheduler=scheduler) pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

View File

@@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@require_torch @require_torch
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
def test_inference_text2img(self): def test_inference_text2img(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto") ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device) ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None) ldm.set_progress_bar_config(disable=None)
@@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_text2img_fast(self): def test_inference_text2img_fast(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto") ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device) ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None) ldm.set_progress_bar_config(disable=None)

View File

@@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self): def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto") unet = UNet2DModel.from_pretrained(model_id)
scheduler = PNDMScheduler() scheduler = PNDMScheduler()
pndm = PNDMPipeline(unet=unet, scheduler=scheduler) pndm = PNDMPipeline(unet=unet, scheduler=scheduler)

View File

@@ -72,7 +72,7 @@ class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase): class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
def test_inference(self): def test_inference(self):
model_id = "google/ncsnpp-church-256" model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id, device_map="auto") model = UNet2DModel.from_pretrained(model_id)
scheduler = ScoreSdeVeScheduler.from_config(model_id) scheduler = ScoreSdeVeScheduler.from_config(model_id)

View File

@@ -631,7 +631,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion(self): def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk # make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto") sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
sd_pipe = sd_pipe.to(torch_device) sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
@@ -653,9 +653,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_fast_ddim(self): def test_stable_diffusion_fast_ddim(self):
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler") scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained( sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
"CompVis/stable-diffusion-v1-1", scheduler=scheduler, device_map="auto"
)
sd_pipe = sd_pipe.to(torch_device) sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
@@ -674,7 +672,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_lms_stable_diffusion_pipeline(self): def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1" model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, device_map="auto").to(torch_device) pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe.scheduler = scheduler pipe.scheduler = scheduler
@@ -693,9 +691,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_memory_chunking(self): def test_stable_diffusion_memory_chunking(self):
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
model_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -732,9 +728,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_text2img_pipeline_fp16(self): def test_stable_diffusion_text2img_pipeline_fp16(self):
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
model_id, revision="fp16", device_map="auto", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -767,11 +761,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
expected_image = np.array(expected_image, dtype=np.float32) / 255.0 expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None)
model_id,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing() pipe.enable_attention_slicing()
@@ -812,7 +802,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
test_callback_fn.has_been_called = False test_callback_fn.has_been_called = False
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto" "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
) )
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -833,23 +823,23 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert test_callback_fn.has_been_called assert test_callback_fn.has_been_called
assert number_of_steps == 51 assert number_of_steps == 51
def test_stable_diffusion_accelerate_auto_device(self): def test_stable_diffusion_fast_load(self):
pipeline_id = "CompVis/stable-diffusion-v1-4" pipeline_id = "CompVis/stable-diffusion-v1-4"
start_time = time.time() start_time = time.time()
pipeline_normal_load = StableDiffusionPipeline.from_pretrained( pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16 pipeline_id, revision="fp16", torch_dtype=torch.float16
) )
pipeline_normal_load.to(torch_device) pipeline_fast_load.to(torch_device)
normal_load_time = time.time() - start_time fast_load_time = time.time() - start_time
start_time = time.time() start_time = time.time()
_ = StableDiffusionPipeline.from_pretrained( _ = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto" pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
) )
meta_device_load_time = time.time() - start_time normal_load_time = time.time() - start_time
assert 2 * meta_device_load_time < normal_load_time assert 2 * fast_load_time < normal_load_time
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self): def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):

View File

@@ -488,7 +488,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, model_id,
safety_checker=None, safety_checker=None,
device_map="auto",
) )
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -529,7 +528,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
model_id, model_id,
scheduler=lms, scheduler=lms,
safety_checker=None, safety_checker=None,
device_map="auto",
) )
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -581,7 +579,9 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512)) init_image = init_image.resize((768, 512))
pipe = StableDiffusionImg2ImgPipeline.from_pretrained( pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto" "CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
) )
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)

View File

@@ -284,11 +284,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
) )
model_id = "runwayml/stable-diffusion-inpainting" model_id = "runwayml/stable-diffusion-inpainting"
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
model_id,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing() pipe.enable_attention_slicing()
@@ -328,7 +324,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
revision="fp16", revision="fp16",
torch_dtype=torch.float16, torch_dtype=torch.float16,
safety_checker=None, safety_checker=None,
device_map="auto",
) )
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -365,9 +360,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
model_id = "runwayml/stable-diffusion-inpainting" model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
model_id, safety_checker=None, scheduler=pndm, device_map="auto"
)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing() pipe.enable_attention_slicing()

View File

@@ -364,11 +364,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
) )
model_id = "CompVis/stable-diffusion-v1-4" model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
model_id,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing() pipe.enable_attention_slicing()
@@ -411,7 +407,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
model_id, model_id,
scheduler=lms, scheduler=lms,
safety_checker=None, safety_checker=None,
device_map="auto",
) )
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -468,7 +463,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
) )
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto" "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
) )
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)

View File

@@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
def test_read_init(self): def test_read_init(self):
objects = read_init() objects = read_init()
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects # We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
self.assertIn("torch", objects) self.assertIn("torch_and_accelerate", objects)
self.assertIn("torch_and_transformers", objects) self.assertIn("torch_and_transformers", objects)
self.assertIn("flax_and_transformers", objects) self.assertIn("flax_and_transformers", objects)
self.assertIn("torch_and_transformers_and_onnx", objects) self.assertIn("torch_and_transformers_and_onnx", objects)
# Likewise, we can't assert on the exact content of a key # Likewise, we can't assert on the exact content of a key
self.assertIn("UNet2DModel", objects["torch"]) self.assertIn("UNet2DModel", objects["torch_and_accelerate"])
self.assertIn("FlaxUNet2DConditionModel", objects["flax"]) self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"]) self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"]) self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])

View File

@@ -128,7 +128,7 @@ class CustomPipelineTests(unittest.TestCase):
def test_load_pipeline_from_git(self): def test_load_pipeline_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto") feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16) clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
@@ -138,7 +138,6 @@ class CustomPipelineTests(unittest.TestCase):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
torch_dtype=torch.float16, torch_dtype=torch.float16,
revision="fp16", revision="fp16",
device_map="auto",
) )
pipeline.enable_attention_slicing() pipeline.enable_attention_slicing()
pipeline = pipeline.to(torch_device) pipeline = pipeline.to(torch_device)
@@ -333,9 +332,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_smart_download(self): def test_smart_download(self):
model_id = "hf-internal-testing/unet-pipeline-dummy" model_id = "hf-internal-testing/unet-pipeline-dummy"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
_ = DiffusionPipeline.from_pretrained( _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
model_id, cache_dir=tmpdirname, force_download=True, device_map="auto"
)
local_repo_name = "--".join(["models"] + model_id.split("/")) local_repo_name = "--".join(["models"] + model_id.split("/"))
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots") snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0]) snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
@@ -359,7 +356,10 @@ class PipelineSlowTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
DiffusionPipeline.from_pretrained( DiffusionPipeline.from_pretrained(
model_id, not_used=True, cache_dir=tmpdirname, force_download=True, device_map="auto" model_id,
not_used=True,
cache_dir=tmpdirname,
force_download=True,
) )
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n" assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
@@ -383,7 +383,7 @@ class PipelineSlowTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname) ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname, device_map="auto") new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device) new_ddpm.to(torch_device)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
@@ -399,11 +399,11 @@ class PipelineSlowTests(unittest.TestCase):
scheduler = DDPMScheduler(num_train_timesteps=10) scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto") ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm = ddpm.to(torch_device) ddpm = ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None) ddpm.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto") ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub = ddpm_from_hub.to(torch_device) ddpm_from_hub = ddpm_from_hub.to(torch_device)
ddpm_from_hub.set_progress_bar_config(disable=None) ddpm_from_hub.set_progress_bar_config(disable=None)
@@ -421,14 +421,12 @@ class PipelineSlowTests(unittest.TestCase):
scheduler = DDPMScheduler(num_train_timesteps=10) scheduler = DDPMScheduler(num_train_timesteps=10)
# pass unet into DiffusionPipeline # pass unet into DiffusionPipeline
unet = UNet2DModel.from_pretrained(model_path, device_map="auto") unet = UNet2DModel.from_pretrained(model_path)
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained( ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
model_path, unet=unet, scheduler=scheduler, device_map="auto"
)
ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device) ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto") ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub = ddpm_from_hub.to(torch_device) ddpm_from_hub = ddpm_from_hub.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
@@ -443,7 +441,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_output_format(self): def test_output_format(self):
model_path = "google/ddpm-cifar10-32" model_path = "google/ddpm-cifar10-32"
pipe = DDIMPipeline.from_pretrained(model_path, device_map="auto") pipe = DDIMPipeline.from_pretrained(model_path)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
@@ -467,7 +465,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_ddpm_ddim_equality(self, seed): def test_ddpm_ddim_equality(self, seed):
model_id = "google/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto") unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler() ddpm_scheduler = DDPMScheduler()
ddim_scheduler = DDIMScheduler() ddim_scheduler = DDIMScheduler()
@@ -498,7 +496,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_ddpm_ddim_equality_batched(self, seed): def test_ddpm_ddim_equality_batched(self, seed):
model_id = "google/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto") unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler() ddpm_scheduler = DDPMScheduler()
ddim_scheduler = DDIMScheduler() ddim_scheduler = DDIMScheduler()