default fast model loading 🔥 (#1115)
* make accelerate hard dep * default fast init * move params to cpu when device map is None * handle device_map=None * handle torch < 1.9 * remove device_map="auto" * style * add accelerate in torch extra * remove accelerate from extras["test"] * raise an error if torch is available but not accelerate * update installation docs * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * improve defautl loading speed even further, allow disabling fats loading * address review comments * adapt the tests * fix test_stable_diffusion_fast_load * fix test_read_init * temp fix for dummy checks * Trigger Build * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
parent
ef2ea33c3b
commit
7482178162
16
README.md
16
README.md
|
@ -27,10 +27,12 @@ More precisely, 🤗 Diffusers offers:
|
|||
|
||||
## Installation
|
||||
|
||||
### For PyTorch
|
||||
|
||||
**With `pip`**
|
||||
|
||||
```bash
|
||||
pip install --upgrade diffusers
|
||||
pip install --upgrade diffusers[torch]
|
||||
```
|
||||
|
||||
**With `conda`**
|
||||
|
@ -39,6 +41,14 @@ pip install --upgrade diffusers
|
|||
conda install -c conda-forge diffusers
|
||||
```
|
||||
|
||||
### For Flax
|
||||
|
||||
**With `pip`**
|
||||
|
||||
```bash
|
||||
pip install --upgrade diffusers[flax]
|
||||
```
|
||||
|
||||
**Apple Silicon (M1/M2) support**
|
||||
|
||||
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
|
||||
|
@ -354,7 +364,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
|
|||
If you want to run the code yourself 💻, you can try out:
|
||||
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
||||
```python
|
||||
# !pip install diffusers transformers
|
||||
# !pip install diffusers["torch"] transformers
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
device = "cuda"
|
||||
|
@ -373,7 +383,7 @@ image.save("squirrel.png")
|
|||
```
|
||||
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
|
||||
```python
|
||||
# !pip install diffusers
|
||||
# !pip install diffusers["torch"]
|
||||
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
|
||||
|
||||
model_id = "google/ddpm-celebahq-256"
|
||||
|
|
|
@ -12,9 +12,12 @@ specific language governing permissions and limitations under the License.
|
|||
|
||||
# Installation
|
||||
|
||||
Install Diffusers for with PyTorch. Support for other libraries will come in the future
|
||||
Install 🤗 Diffusers for whichever deep learning library you’re working with.
|
||||
|
||||
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
|
||||
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||
|
||||
## Install with pip
|
||||
|
||||
|
@ -36,12 +39,30 @@ source .env/bin/activate
|
|||
|
||||
Now you're ready to install 🤗 Diffusers with the following command:
|
||||
|
||||
**For PyTorch**
|
||||
|
||||
```bash
|
||||
pip install diffusers
|
||||
pip install diffusers["torch"]
|
||||
```
|
||||
|
||||
**For Flax**
|
||||
|
||||
```bash
|
||||
pip install diffusers["flax"]
|
||||
```
|
||||
|
||||
## Install from source
|
||||
|
||||
Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed.
|
||||
|
||||
For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).
|
||||
|
||||
To install `accelerate`
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
Install 🤗 Diffusers from source with the following command:
|
||||
|
||||
```bash
|
||||
|
@ -67,7 +88,18 @@ Clone the repository and install 🤗 Diffusers with the following commands:
|
|||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
**For PyTorch**
|
||||
|
||||
```
|
||||
pip install -e ".[torch]"
|
||||
```
|
||||
|
||||
**For Flax**
|
||||
|
||||
```
|
||||
pip install -e ".[flax]"
|
||||
```
|
||||
|
||||
These commands will link the folder you cloned the repository to and your Python library paths.
|
||||
|
|
3
setup.py
3
setup.py
|
@ -178,7 +178,6 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
|||
extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||
extras["test"] = deps_list(
|
||||
"accelerate",
|
||||
"datasets",
|
||||
"parameterized",
|
||||
"pytest",
|
||||
|
@ -188,7 +187,7 @@ extras["test"] = deps_list(
|
|||
"torchvision",
|
||||
"transformers"
|
||||
)
|
||||
extras["torch"] = deps_list("torch")
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .utils import (
|
||||
is_accelerate_available,
|
||||
is_flax_available,
|
||||
is_inflect_available,
|
||||
is_onnx_available,
|
||||
|
@ -16,6 +17,13 @@ from .onnx_utils import OnnxRuntimeModel
|
|||
from .utils import logging
|
||||
|
||||
|
||||
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
|
||||
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
|
||||
if is_torch_available() and not is_accelerate_available():
|
||||
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
|
||||
raise ImportError(error_msg)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
|
|
|
@ -21,7 +21,9 @@ from typing import Callable, List, Optional, Tuple, Union
|
|||
import torch
|
||||
from torch import Tensor, device
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
import accelerate
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
@ -268,6 +270,19 @@ class ModelMixin(torch.nn.Module):
|
|||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
fast_load (`bool`, *optional*, defaults to `True`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
this argument will be ignored and the model will be loaded normally.
|
||||
|
||||
<Tip>
|
||||
|
||||
|
@ -296,6 +311,16 @@ class ModelMixin(torch.nn.Module):
|
|||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
fast_load = kwargs.pop("fast_load", True)
|
||||
|
||||
# Check if we can handle device_map and dispatching the weights
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
|
||||
|
||||
# Fast init is only possible if torch version is >= 1.9.0
|
||||
_INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
|
||||
if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
|
||||
logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
|
@ -378,12 +403,8 @@ class ModelMixin(torch.nn.Module):
|
|||
|
||||
# restore default dtype
|
||||
|
||||
if device_map == "auto":
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
if _INIT_EMPTY_WEIGHTS:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
model, unused_kwargs = cls.from_config(
|
||||
config_path,
|
||||
|
@ -400,7 +421,17 @@ class ModelMixin(torch.nn.Module):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file)
|
||||
# move the parms from meta device to cpu
|
||||
for param_name, param in state_dict.items():
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by deafult the device_map is None and the weights are loaded on the CPU
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": [],
|
||||
|
|
|
@ -380,6 +380,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
provider = kwargs.pop("provider", None)
|
||||
sess_options = kwargs.pop("sess_options", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
fast_load = kwargs.pop("fast_load", True)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
|
@ -572,6 +573,15 @@ class DiffusionPipeline(ConfigMixin):
|
|||
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
if is_diffusers_model:
|
||||
loading_kwargs["fast_load"] = fast_load
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_transformers_model and device_map is None:
|
||||
loading_kwargs["low_cpu_mem_usage"] = fast_load
|
||||
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
|
||||
|
|
|
@ -0,0 +1,392 @@
|
|||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
# flake8: noqa
|
||||
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class ModelMixin(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class AutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class UNet1DModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class UNet2DConditionModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class UNet2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class VQModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_constant_schedule(*args, **kwargs):
|
||||
requires_backends(get_constant_schedule, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_constant_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_cosine_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_linear_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
|
||||
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"])
|
||||
|
||||
|
||||
def get_scheduler(*args, **kwargs):
|
||||
requires_backends(get_scheduler, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DanceDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDIMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDPMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class KarrasVePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class LDMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class PNDMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDIMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class DDPMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class EulerDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class IPNDMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class KarrasVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class PNDMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class SchedulerMixin(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class ScoreSdeVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
|
||||
class EMAModel(metaclass=DummyObject):
|
||||
_backends = ["torch", "accelerate"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "accelerate"])
|
|
@ -28,7 +28,7 @@ class UnetModel1DTests(unittest.TestCase):
|
|||
@slow
|
||||
def test_unet_1d_maestro(self):
|
||||
model_id = "harmonai/maestro-150k"
|
||||
model = UNet1DModel.from_pretrained(model_id, subfolder="unet", device_map="auto")
|
||||
model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
|
||||
model.to(torch_device)
|
||||
|
||||
sample_size = 65536
|
||||
|
|
|
@ -125,9 +125,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_from_pretrained_accelerate(self):
|
||||
model, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input).sample
|
||||
|
||||
|
@ -135,9 +133,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_from_pretrained_accelerate_wont_change_results(self):
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
# by defautl model loading will use accelerate as `fast_load=True`
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model_accelerate.to(torch_device)
|
||||
model_accelerate.eval()
|
||||
|
||||
|
@ -159,7 +156,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
gc.collect()
|
||||
|
||||
model_normal_load, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
|
||||
)
|
||||
model_normal_load.to(torch_device)
|
||||
model_normal_load.eval()
|
||||
|
@ -173,9 +170,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
gc.collect()
|
||||
|
||||
tracemalloc.start()
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
# by defautl model loading will use accelerate as `fast_load=True`
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model_accelerate.to(torch_device)
|
||||
model_accelerate.eval()
|
||||
_, peak_accelerate = tracemalloc.get_traced_memory()
|
||||
|
@ -184,7 +180,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model_normal_load, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
|
||||
)
|
||||
model_normal_load.to(torch_device)
|
||||
model_normal_load.eval()
|
||||
_, peak_normal = tracemalloc.get_traced_memory()
|
||||
|
@ -348,9 +346,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained(
|
||||
"google/ncsnpp-celebahq-256", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
|
@ -364,7 +360,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_output_pretrained_ve_mid(self):
|
||||
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto")
|
||||
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
|
||||
model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
@ -439,7 +435,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
|||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = UNet2DConditionModel.from_pretrained(
|
||||
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision, device_map="auto"
|
||||
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
|
||||
|
|
|
@ -155,7 +155,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = AutoencoderKL.from_pretrained(
|
||||
model_id, subfolder="vae", torch_dtype=torch_dtype, revision=revision, device_map="auto"
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch_dtype,
|
||||
revision=revision,
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
|
|||
def test_dance_diffusion(self):
|
||||
device = torch_device
|
||||
|
||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
|
||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -103,9 +103,7 @@ class PipelineIntegrationTests(unittest.TestCase):
|
|||
def test_dance_diffusion_fp16(self):
|
||||
device = torch_device
|
||||
|
||||
pipe = DanceDiffusionPipeline.from_pretrained(
|
||||
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
|||
def test_inference_ema_bedroom(self):
|
||||
model_id = "google/ddpm-ema-bedroom-256"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
|||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
|
|
@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
|
|||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDPMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
|
|
@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
model_id = "google/ncsnpp-celebahq-256"
|
||||
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||
|
|
|
@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
@require_torch
|
||||
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_text2img(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
|||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_text2img_fast(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
|
|||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = PNDMScheduler()
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
|
|
@ -72,7 +72,7 @@ class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
model_id = "google/ncsnpp-church-256"
|
||||
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
|
||||
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||
|
||||
|
|
|
@ -631,7 +631,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
|
||||
def test_stable_diffusion(self):
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -653,9 +653,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
def test_stable_diffusion_fast_ddim(self):
|
||||
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
|
||||
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-1", scheduler=scheduler, device_map="auto"
|
||||
)
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -674,7 +672,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
|
||||
def test_lms_stable_diffusion_pipeline(self):
|
||||
model_id = "CompVis/stable-diffusion-v1-1"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, device_map="auto").to(torch_device)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
||||
pipe.scheduler = scheduler
|
||||
|
@ -693,9 +691,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
def test_stable_diffusion_memory_chunking(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -732,9 +728,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
def test_stable_diffusion_text2img_pipeline_fp16(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_id, revision="fp16", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -767,11 +761,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
safety_checker=None,
|
||||
device_map="auto",
|
||||
)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
@ -812,7 +802,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
@ -833,23 +823,23 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
|||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 51
|
||||
|
||||
def test_stable_diffusion_accelerate_auto_device(self):
|
||||
def test_stable_diffusion_fast_load(self):
|
||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
start_time = time.time()
|
||||
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline_normal_load.to(torch_device)
|
||||
normal_load_time = time.time() - start_time
|
||||
pipeline_fast_load.to(torch_device)
|
||||
fast_load_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
|
||||
)
|
||||
meta_device_load_time = time.time() - start_time
|
||||
normal_load_time = time.time() - start_time
|
||||
|
||||
assert 2 * meta_device_load_time < normal_load_time
|
||||
assert 2 * fast_load_time < normal_load_time
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
||||
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
|
||||
|
|
|
@ -488,7 +488,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
|||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
model_id,
|
||||
safety_checker=None,
|
||||
device_map="auto",
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
@ -529,7 +528,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
|||
model_id,
|
||||
scheduler=lms,
|
||||
safety_checker=None,
|
||||
device_map="auto",
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
@ -581,7 +579,9 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
|||
init_image = init_image.resize((768, 512))
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
|
|
@ -284,11 +284,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
model_id = "runwayml/stable-diffusion-inpainting"
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
safety_checker=None,
|
||||
device_map="auto",
|
||||
)
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
@ -328,7 +324,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
|||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
safety_checker=None,
|
||||
device_map="auto",
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
@ -365,9 +360,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
|||
|
||||
model_id = "runwayml/stable-diffusion-inpainting"
|
||||
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id, safety_checker=None, scheduler=pndm, device_map="auto"
|
||||
)
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
|
|
@ -364,11 +364,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
safety_checker=None,
|
||||
device_map="auto",
|
||||
)
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
@ -411,7 +407,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
|||
model_id,
|
||||
scheduler=lms,
|
||||
safety_checker=None,
|
||||
device_map="auto",
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
@ -468,7 +463,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
|
|
@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
|
|||
def test_read_init(self):
|
||||
objects = read_init()
|
||||
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
|
||||
self.assertIn("torch", objects)
|
||||
self.assertIn("torch_and_accelerate", objects)
|
||||
self.assertIn("torch_and_transformers", objects)
|
||||
self.assertIn("flax_and_transformers", objects)
|
||||
self.assertIn("torch_and_transformers_and_onnx", objects)
|
||||
|
||||
# Likewise, we can't assert on the exact content of a key
|
||||
self.assertIn("UNet2DModel", objects["torch"])
|
||||
self.assertIn("UNet2DModel", objects["torch_and_accelerate"])
|
||||
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
|
||||
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
|
||||
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
|
||||
|
|
|
@ -128,7 +128,7 @@ class CustomPipelineTests(unittest.TestCase):
|
|||
def test_load_pipeline_from_git(self):
|
||||
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
||||
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
@ -138,7 +138,6 @@ class CustomPipelineTests(unittest.TestCase):
|
|||
feature_extractor=feature_extractor,
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
device_map="auto",
|
||||
)
|
||||
pipeline.enable_attention_slicing()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
|
@ -333,9 +332,7 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
def test_smart_download(self):
|
||||
model_id = "hf-internal-testing/unet-pipeline-dummy"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
model_id, cache_dir=tmpdirname, force_download=True, device_map="auto"
|
||||
)
|
||||
_ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
|
||||
local_repo_name = "--".join(["models"] + model_id.split("/"))
|
||||
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
|
||||
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
|
||||
|
@ -359,7 +356,10 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
DiffusionPipeline.from_pretrained(
|
||||
model_id, not_used=True, cache_dir=tmpdirname, force_download=True, device_map="auto"
|
||||
model_id,
|
||||
not_used=True,
|
||||
cache_dir=tmpdirname,
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
||||
|
@ -383,7 +383,7 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ddpm.save_pretrained(tmpdirname)
|
||||
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname, device_map="auto")
|
||||
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
|
||||
new_ddpm.to(torch_device)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
|
@ -399,11 +399,11 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
|
||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||
|
||||
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
||||
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
|
||||
ddpm = ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
|
||||
ddpm_from_hub = ddpm_from_hub.to(torch_device)
|
||||
ddpm_from_hub.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -421,14 +421,12 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||
|
||||
# pass unet into DiffusionPipeline
|
||||
unet = UNet2DModel.from_pretrained(model_path, device_map="auto")
|
||||
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
|
||||
model_path, unet=unet, scheduler=scheduler, device_map="auto"
|
||||
)
|
||||
unet = UNet2DModel.from_pretrained(model_path)
|
||||
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
|
||||
ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
|
||||
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
||||
|
||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
|
||||
ddpm_from_hub = ddpm_from_hub.to(torch_device)
|
||||
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -443,7 +441,7 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
def test_output_format(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
|
||||
pipe = DDIMPipeline.from_pretrained(model_path, device_map="auto")
|
||||
pipe = DDIMPipeline.from_pretrained(model_path)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
@ -467,7 +465,7 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
def test_ddpm_ddim_equality(self, seed):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
ddpm_scheduler = DDPMScheduler()
|
||||
ddim_scheduler = DDIMScheduler()
|
||||
|
||||
|
@ -498,7 +496,7 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
def test_ddpm_ddim_equality_batched(self, seed):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
ddpm_scheduler = DDPMScheduler()
|
||||
ddim_scheduler = DDIMScheduler()
|
||||
|
||||
|
|
Loading…
Reference in New Issue