default fast model loading 🔥 (#1115)
* make accelerate hard dep * default fast init * move params to cpu when device map is None * handle device_map=None * handle torch < 1.9 * remove device_map="auto" * style * add accelerate in torch extra * remove accelerate from extras["test"] * raise an error if torch is available but not accelerate * update installation docs * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * improve defautl loading speed even further, allow disabling fats loading * address review comments * adapt the tests * fix test_stable_diffusion_fast_load * fix test_read_init * temp fix for dummy checks * Trigger Build * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
parent
ef2ea33c3b
commit
7482178162
16
README.md
16
README.md
|
@ -27,10 +27,12 @@ More precisely, 🤗 Diffusers offers:
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
### For PyTorch
|
||||||
|
|
||||||
**With `pip`**
|
**With `pip`**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install --upgrade diffusers
|
pip install --upgrade diffusers[torch]
|
||||||
```
|
```
|
||||||
|
|
||||||
**With `conda`**
|
**With `conda`**
|
||||||
|
@ -39,6 +41,14 @@ pip install --upgrade diffusers
|
||||||
conda install -c conda-forge diffusers
|
conda install -c conda-forge diffusers
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### For Flax
|
||||||
|
|
||||||
|
**With `pip`**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --upgrade diffusers[flax]
|
||||||
|
```
|
||||||
|
|
||||||
**Apple Silicon (M1/M2) support**
|
**Apple Silicon (M1/M2) support**
|
||||||
|
|
||||||
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
|
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
|
||||||
|
@ -354,7 +364,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
|
||||||
If you want to run the code yourself 💻, you can try out:
|
If you want to run the code yourself 💻, you can try out:
|
||||||
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
||||||
```python
|
```python
|
||||||
# !pip install diffusers transformers
|
# !pip install diffusers["torch"] transformers
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
@ -373,7 +383,7 @@ image.save("squirrel.png")
|
||||||
```
|
```
|
||||||
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
|
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
|
||||||
```python
|
```python
|
||||||
# !pip install diffusers
|
# !pip install diffusers["torch"]
|
||||||
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
|
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
|
||||||
|
|
||||||
model_id = "google/ddpm-celebahq-256"
|
model_id = "google/ddpm-celebahq-256"
|
||||||
|
|
|
@ -12,9 +12,12 @@ specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
# Installation
|
# Installation
|
||||||
|
|
||||||
Install Diffusers for with PyTorch. Support for other libraries will come in the future
|
Install 🤗 Diffusers for whichever deep learning library you’re working with.
|
||||||
|
|
||||||
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
|
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using:
|
||||||
|
|
||||||
|
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||||
|
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||||
|
|
||||||
## Install with pip
|
## Install with pip
|
||||||
|
|
||||||
|
@ -36,12 +39,30 @@ source .env/bin/activate
|
||||||
|
|
||||||
Now you're ready to install 🤗 Diffusers with the following command:
|
Now you're ready to install 🤗 Diffusers with the following command:
|
||||||
|
|
||||||
|
**For PyTorch**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install diffusers
|
pip install diffusers["torch"]
|
||||||
|
```
|
||||||
|
|
||||||
|
**For Flax**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install diffusers["flax"]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Install from source
|
## Install from source
|
||||||
|
|
||||||
|
Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed.
|
||||||
|
|
||||||
|
For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).
|
||||||
|
|
||||||
|
To install `accelerate`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install accelerate
|
||||||
|
```
|
||||||
|
|
||||||
Install 🤗 Diffusers from source with the following command:
|
Install 🤗 Diffusers from source with the following command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -67,7 +88,18 @@ Clone the repository and install 🤗 Diffusers with the following commands:
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/huggingface/diffusers.git
|
git clone https://github.com/huggingface/diffusers.git
|
||||||
cd diffusers
|
cd diffusers
|
||||||
pip install -e .
|
```
|
||||||
|
|
||||||
|
**For PyTorch**
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -e ".[torch]"
|
||||||
|
```
|
||||||
|
|
||||||
|
**For Flax**
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -e ".[flax]"
|
||||||
```
|
```
|
||||||
|
|
||||||
These commands will link the folder you cloned the repository to and your Python library paths.
|
These commands will link the folder you cloned the repository to and your Python library paths.
|
||||||
|
|
3
setup.py
3
setup.py
|
@ -178,7 +178,6 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
||||||
extras["docs"] = deps_list("hf-doc-builder")
|
extras["docs"] = deps_list("hf-doc-builder")
|
||||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||||
extras["test"] = deps_list(
|
extras["test"] = deps_list(
|
||||||
"accelerate",
|
|
||||||
"datasets",
|
"datasets",
|
||||||
"parameterized",
|
"parameterized",
|
||||||
"pytest",
|
"pytest",
|
||||||
|
@ -188,7 +187,7 @@ extras["test"] = deps_list(
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"transformers"
|
"transformers"
|
||||||
)
|
)
|
||||||
extras["torch"] = deps_list("torch")
|
extras["torch"] = deps_list("torch", "accelerate")
|
||||||
|
|
||||||
if os.name == "nt": # windows
|
if os.name == "nt": # windows
|
||||||
extras["flax"] = [] # jax is not supported on windows
|
extras["flax"] = [] # jax is not supported on windows
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
is_accelerate_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_inflect_available,
|
is_inflect_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
|
@ -16,6 +17,13 @@ from .onnx_utils import OnnxRuntimeModel
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
|
||||||
|
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
|
||||||
|
if is_torch_available() and not is_accelerate_available():
|
||||||
|
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
|
||||||
|
raise ImportError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_utils import ModelMixin
|
from .modeling_utils import ModelMixin
|
||||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||||
|
|
|
@ -21,7 +21,9 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
|
|
||||||
from diffusers.utils import is_accelerate_available
|
import accelerate
|
||||||
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
from accelerate.utils.versions import is_torch_version
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
@ -268,6 +270,19 @@ class ModelMixin(torch.nn.Module):
|
||||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||||
Please refer to the mirror site for more information.
|
Please refer to the mirror site for more information.
|
||||||
|
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||||
|
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||||
|
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||||
|
same device.
|
||||||
|
|
||||||
|
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||||
|
more information about each option see [designing a device
|
||||||
|
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||||
|
fast_load (`bool`, *optional*, defaults to `True`):
|
||||||
|
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||||
|
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||||
|
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||||
|
this argument will be ignored and the model will be loaded normally.
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
|
@ -296,6 +311,16 @@ class ModelMixin(torch.nn.Module):
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
subfolder = kwargs.pop("subfolder", None)
|
subfolder = kwargs.pop("subfolder", None)
|
||||||
device_map = kwargs.pop("device_map", None)
|
device_map = kwargs.pop("device_map", None)
|
||||||
|
fast_load = kwargs.pop("fast_load", True)
|
||||||
|
|
||||||
|
# Check if we can handle device_map and dispatching the weights
|
||||||
|
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||||
|
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
|
||||||
|
|
||||||
|
# Fast init is only possible if torch version is >= 1.9.0
|
||||||
|
_INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
|
||||||
|
if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
|
||||||
|
logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
|
||||||
|
|
||||||
user_agent = {
|
user_agent = {
|
||||||
"diffusers": __version__,
|
"diffusers": __version__,
|
||||||
|
@ -378,12 +403,8 @@ class ModelMixin(torch.nn.Module):
|
||||||
|
|
||||||
# restore default dtype
|
# restore default dtype
|
||||||
|
|
||||||
if device_map == "auto":
|
if _INIT_EMPTY_WEIGHTS:
|
||||||
if is_accelerate_available():
|
# Instantiate model with empty weights
|
||||||
import accelerate
|
|
||||||
else:
|
|
||||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
|
||||||
|
|
||||||
with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
model, unused_kwargs = cls.from_config(
|
model, unused_kwargs = cls.from_config(
|
||||||
config_path,
|
config_path,
|
||||||
|
@ -400,7 +421,17 @@ class ModelMixin(torch.nn.Module):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
|
||||||
|
if device_map is None:
|
||||||
|
param_device = "cpu"
|
||||||
|
state_dict = load_state_dict(model_file)
|
||||||
|
# move the parms from meta device to cpu
|
||||||
|
for param_name, param in state_dict.items():
|
||||||
|
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||||
|
else: # else let accelerate handle loading and dispatching.
|
||||||
|
# Load weights and dispatch according to the device_map
|
||||||
|
# by deafult the device_map is None and the weights are loaded on the CPU
|
||||||
|
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||||
|
|
||||||
loading_info = {
|
loading_info = {
|
||||||
"missing_keys": [],
|
"missing_keys": [],
|
||||||
|
|
|
@ -380,6 +380,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
provider = kwargs.pop("provider", None)
|
provider = kwargs.pop("provider", None)
|
||||||
sess_options = kwargs.pop("sess_options", None)
|
sess_options = kwargs.pop("sess_options", None)
|
||||||
device_map = kwargs.pop("device_map", None)
|
device_map = kwargs.pop("device_map", None)
|
||||||
|
fast_load = kwargs.pop("fast_load", True)
|
||||||
|
|
||||||
# 1. Download the checkpoints and configs
|
# 1. Download the checkpoints and configs
|
||||||
# use snapshot download here to get it working from from_pretrained
|
# use snapshot download here to get it working from from_pretrained
|
||||||
|
@ -572,6 +573,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_diffusers_model:
|
||||||
|
loading_kwargs["fast_load"] = fast_load
|
||||||
|
|
||||||
|
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||||
|
# To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
|
||||||
|
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||||
|
if is_transformers_model and device_map is None:
|
||||||
|
loading_kwargs["low_cpu_mem_usage"] = fast_load
|
||||||
|
|
||||||
if is_diffusers_model or is_transformers_model:
|
if is_diffusers_model or is_transformers_model:
|
||||||
loading_kwargs["device_map"] = device_map
|
loading_kwargs["device_map"] = device_map
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,392 @@
|
||||||
|
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from ..utils import DummyObject, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMixin(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderKL(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class UNet1DModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class UNet2DConditionModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class UNet2DModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class VQModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_constant_schedule(*args, **kwargs):
|
||||||
|
requires_backends(get_constant_schedule, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_constant_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_cosine_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_linear_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(*args, **kwargs):
|
||||||
|
requires_backends(get_scheduler, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class DanceDiffusionPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDPMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class KarrasVePipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class LDMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class PNDMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDPMScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class EulerDiscreteScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class IPNDMScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class KarrasVeScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class PNDMScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreSdeVeScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
|
||||||
|
class EMAModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "accelerate"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "accelerate"])
|
|
@ -28,7 +28,7 @@ class UnetModel1DTests(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_unet_1d_maestro(self):
|
def test_unet_1d_maestro(self):
|
||||||
model_id = "harmonai/maestro-150k"
|
model_id = "harmonai/maestro-150k"
|
||||||
model = UNet1DModel.from_pretrained(model_id, subfolder="unet", device_map="auto")
|
model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
sample_size = 65536
|
sample_size = 65536
|
||||||
|
|
|
@ -125,9 +125,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||||
def test_from_pretrained_accelerate(self):
|
def test_from_pretrained_accelerate(self):
|
||||||
model, _ = UNet2DModel.from_pretrained(
|
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
|
||||||
)
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
image = model(**self.dummy_input).sample
|
image = model(**self.dummy_input).sample
|
||||||
|
|
||||||
|
@ -135,9 +133,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||||
def test_from_pretrained_accelerate_wont_change_results(self):
|
def test_from_pretrained_accelerate_wont_change_results(self):
|
||||||
model_accelerate, _ = UNet2DModel.from_pretrained(
|
# by defautl model loading will use accelerate as `fast_load=True`
|
||||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||||
)
|
|
||||||
model_accelerate.to(torch_device)
|
model_accelerate.to(torch_device)
|
||||||
model_accelerate.eval()
|
model_accelerate.eval()
|
||||||
|
|
||||||
|
@ -159,7 +156,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
model_normal_load, _ = UNet2DModel.from_pretrained(
|
model_normal_load, _ = UNet2DModel.from_pretrained(
|
||||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
|
||||||
)
|
)
|
||||||
model_normal_load.to(torch_device)
|
model_normal_load.to(torch_device)
|
||||||
model_normal_load.eval()
|
model_normal_load.eval()
|
||||||
|
@ -173,9 +170,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
tracemalloc.start()
|
tracemalloc.start()
|
||||||
model_accelerate, _ = UNet2DModel.from_pretrained(
|
# by defautl model loading will use accelerate as `fast_load=True`
|
||||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||||
)
|
|
||||||
model_accelerate.to(torch_device)
|
model_accelerate.to(torch_device)
|
||||||
model_accelerate.eval()
|
model_accelerate.eval()
|
||||||
_, peak_accelerate = tracemalloc.get_traced_memory()
|
_, peak_accelerate = tracemalloc.get_traced_memory()
|
||||||
|
@ -184,7 +180,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
model_normal_load, _ = UNet2DModel.from_pretrained(
|
||||||
|
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
|
||||||
|
)
|
||||||
model_normal_load.to(torch_device)
|
model_normal_load.to(torch_device)
|
||||||
model_normal_load.eval()
|
model_normal_load.eval()
|
||||||
_, peak_normal = tracemalloc.get_traced_memory()
|
_, peak_normal = tracemalloc.get_traced_memory()
|
||||||
|
@ -348,9 +346,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_from_pretrained_hub(self):
|
def test_from_pretrained_hub(self):
|
||||||
model, loading_info = UNet2DModel.from_pretrained(
|
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
|
||||||
"google/ncsnpp-celebahq-256", output_loading_info=True, device_map="auto"
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||||
|
|
||||||
|
@ -364,7 +360,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_output_pretrained_ve_mid(self):
|
def test_output_pretrained_ve_mid(self):
|
||||||
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto")
|
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
@ -439,7 +435,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||||
|
|
||||||
model = UNet2DConditionModel.from_pretrained(
|
model = UNet2DConditionModel.from_pretrained(
|
||||||
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision, device_map="auto"
|
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
|
||||||
)
|
)
|
||||||
model.to(torch_device).eval()
|
model.to(torch_device).eval()
|
||||||
|
|
||||||
|
|
|
@ -155,7 +155,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||||
|
|
||||||
model = AutoencoderKL.from_pretrained(
|
model = AutoencoderKL.from_pretrained(
|
||||||
model_id, subfolder="vae", torch_dtype=torch_dtype, revision=revision, device_map="auto"
|
model_id,
|
||||||
|
subfolder="vae",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
revision=revision,
|
||||||
)
|
)
|
||||||
model.to(torch_device).eval()
|
model.to(torch_device).eval()
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_dance_diffusion(self):
|
def test_dance_diffusion(self):
|
||||||
device = torch_device
|
device = torch_device
|
||||||
|
|
||||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
|
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -103,9 +103,7 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_dance_diffusion_fp16(self):
|
def test_dance_diffusion_fp16(self):
|
||||||
device = torch_device
|
device = torch_device
|
||||||
|
|
||||||
pipe = DanceDiffusionPipeline.from_pretrained(
|
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
|
||||||
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
|
|
||||||
)
|
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_ema_bedroom(self):
|
def test_inference_ema_bedroom(self):
|
||||||
model_id = "google/ddpm-ema-bedroom-256"
|
model_id = "google/ddpm-ema-bedroom-256"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
unet = UNet2DModel.from_pretrained(model_id)
|
||||||
scheduler = DDIMScheduler.from_config(model_id)
|
scheduler = DDIMScheduler.from_config(model_id)
|
||||||
|
|
||||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_cifar10(self):
|
def test_inference_cifar10(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
unet = UNet2DModel.from_pretrained(model_id)
|
||||||
scheduler = DDIMScheduler()
|
scheduler = DDIMScheduler()
|
||||||
|
|
||||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
|
@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_cifar10(self):
|
def test_inference_cifar10(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
unet = UNet2DModel.from_pretrained(model_id)
|
||||||
scheduler = DDPMScheduler.from_config(model_id)
|
scheduler = DDPMScheduler.from_config(model_id)
|
||||||
|
|
||||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
|
@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference(self):
|
def test_inference(self):
|
||||||
model_id = "google/ncsnpp-celebahq-256"
|
model_id = "google/ncsnpp-celebahq-256"
|
||||||
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
model = UNet2DModel.from_pretrained(model_id)
|
||||||
scheduler = KarrasVeScheduler()
|
scheduler = KarrasVeScheduler()
|
||||||
|
|
||||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||||
|
|
|
@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
@require_torch
|
@require_torch
|
||||||
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_text2img(self):
|
def test_inference_text2img(self):
|
||||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||||
ldm.to(torch_device)
|
ldm.to(torch_device)
|
||||||
ldm.set_progress_bar_config(disable=None)
|
ldm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_inference_text2img_fast(self):
|
def test_inference_text2img_fast(self):
|
||||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||||
ldm.to(torch_device)
|
ldm.to(torch_device)
|
||||||
ldm.set_progress_bar_config(disable=None)
|
ldm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_cifar10(self):
|
def test_inference_cifar10(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
unet = UNet2DModel.from_pretrained(model_id)
|
||||||
scheduler = PNDMScheduler()
|
scheduler = PNDMScheduler()
|
||||||
|
|
||||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
|
@ -72,7 +72,7 @@ class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference(self):
|
def test_inference(self):
|
||||||
model_id = "google/ncsnpp-church-256"
|
model_id = "google/ncsnpp-church-256"
|
||||||
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
model = UNet2DModel.from_pretrained(model_id)
|
||||||
|
|
||||||
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||||
|
|
||||||
|
|
|
@ -631,7 +631,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_stable_diffusion(self):
|
def test_stable_diffusion(self):
|
||||||
# make sure here that pndm scheduler skips prk
|
# make sure here that pndm scheduler skips prk
|
||||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
|
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
|
||||||
sd_pipe = sd_pipe.to(torch_device)
|
sd_pipe = sd_pipe.to(torch_device)
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -653,9 +653,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_stable_diffusion_fast_ddim(self):
|
def test_stable_diffusion_fast_ddim(self):
|
||||||
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
|
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
|
||||||
|
|
||||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
|
||||||
"CompVis/stable-diffusion-v1-1", scheduler=scheduler, device_map="auto"
|
|
||||||
)
|
|
||||||
sd_pipe = sd_pipe.to(torch_device)
|
sd_pipe = sd_pipe.to(torch_device)
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -674,7 +672,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_lms_stable_diffusion_pipeline(self):
|
def test_lms_stable_diffusion_pipeline(self):
|
||||||
model_id = "CompVis/stable-diffusion-v1-1"
|
model_id = "CompVis/stable-diffusion-v1-1"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, device_map="auto").to(torch_device)
|
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
||||||
pipe.scheduler = scheduler
|
pipe.scheduler = scheduler
|
||||||
|
@ -693,9 +691,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_stable_diffusion_memory_chunking(self):
|
def test_stable_diffusion_memory_chunking(self):
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
model_id = "CompVis/stable-diffusion-v1-4"
|
model_id = "CompVis/stable-diffusion-v1-4"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||||
model_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
|
||||||
)
|
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -732,9 +728,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_stable_diffusion_text2img_pipeline_fp16(self):
|
def test_stable_diffusion_text2img_pipeline_fp16(self):
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
model_id = "CompVis/stable-diffusion-v1-4"
|
model_id = "CompVis/stable-diffusion-v1-4"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||||
model_id, revision="fp16", device_map="auto", torch_dtype=torch.float16
|
|
||||||
)
|
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -767,11 +761,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
|
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
|
||||||
|
|
||||||
model_id = "CompVis/stable-diffusion-v1-4"
|
model_id = "CompVis/stable-diffusion-v1-4"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None)
|
||||||
model_id,
|
|
||||||
safety_checker=None,
|
|
||||||
device_map="auto",
|
|
||||||
)
|
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
@ -812,7 +802,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
test_callback_fn.has_been_called = False
|
test_callback_fn.has_been_called = False
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -833,23 +823,23 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
assert test_callback_fn.has_been_called
|
assert test_callback_fn.has_been_called
|
||||||
assert number_of_steps == 51
|
assert number_of_steps == 51
|
||||||
|
|
||||||
def test_stable_diffusion_accelerate_auto_device(self):
|
def test_stable_diffusion_fast_load(self):
|
||||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
|
pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
|
||||||
pipeline_id, revision="fp16", torch_dtype=torch.float16
|
pipeline_id, revision="fp16", torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
pipeline_normal_load.to(torch_device)
|
pipeline_fast_load.to(torch_device)
|
||||||
normal_load_time = time.time() - start_time
|
fast_load_time = time.time() - start_time
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
_ = StableDiffusionPipeline.from_pretrained(
|
_ = StableDiffusionPipeline.from_pretrained(
|
||||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
|
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
|
||||||
)
|
)
|
||||||
meta_device_load_time = time.time() - start_time
|
normal_load_time = time.time() - start_time
|
||||||
|
|
||||||
assert 2 * meta_device_load_time < normal_load_time
|
assert 2 * fast_load_time < normal_load_time
|
||||||
|
|
||||||
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
|
||||||
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
|
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
|
||||||
|
|
|
@ -488,7 +488,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
device_map="auto",
|
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -529,7 +528,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
model_id,
|
model_id,
|
||||||
scheduler=lms,
|
scheduler=lms,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
device_map="auto",
|
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -581,7 +579,9 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
init_image = init_image.resize((768, 512))
|
init_image = init_image.resize((768, 512))
|
||||||
|
|
||||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
"CompVis/stable-diffusion-v1-4",
|
||||||
|
revision="fp16",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
|
@ -284,11 +284,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = "runwayml/stable-diffusion-inpainting"
|
model_id = "runwayml/stable-diffusion-inpainting"
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
|
||||||
model_id,
|
|
||||||
safety_checker=None,
|
|
||||||
device_map="auto",
|
|
||||||
)
|
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
@ -328,7 +324,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
revision="fp16",
|
revision="fp16",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
device_map="auto",
|
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -365,9 +360,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
model_id = "runwayml/stable-diffusion-inpainting"
|
model_id = "runwayml/stable-diffusion-inpainting"
|
||||||
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
|
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
|
||||||
model_id, safety_checker=None, scheduler=pndm, device_map="auto"
|
|
||||||
)
|
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
|
|
@ -364,11 +364,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = "CompVis/stable-diffusion-v1-4"
|
model_id = "CompVis/stable-diffusion-v1-4"
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
|
||||||
model_id,
|
|
||||||
safety_checker=None,
|
|
||||||
device_map="auto",
|
|
||||||
)
|
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
@ -411,7 +407,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||||
model_id,
|
model_id,
|
||||||
scheduler=lms,
|
scheduler=lms,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
device_map="auto",
|
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -468,7 +463,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
|
@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
|
||||||
def test_read_init(self):
|
def test_read_init(self):
|
||||||
objects = read_init()
|
objects = read_init()
|
||||||
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
|
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
|
||||||
self.assertIn("torch", objects)
|
self.assertIn("torch_and_accelerate", objects)
|
||||||
self.assertIn("torch_and_transformers", objects)
|
self.assertIn("torch_and_transformers", objects)
|
||||||
self.assertIn("flax_and_transformers", objects)
|
self.assertIn("flax_and_transformers", objects)
|
||||||
self.assertIn("torch_and_transformers_and_onnx", objects)
|
self.assertIn("torch_and_transformers_and_onnx", objects)
|
||||||
|
|
||||||
# Likewise, we can't assert on the exact content of a key
|
# Likewise, we can't assert on the exact content of a key
|
||||||
self.assertIn("UNet2DModel", objects["torch"])
|
self.assertIn("UNet2DModel", objects["torch_and_accelerate"])
|
||||||
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
|
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
|
||||||
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
|
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
|
||||||
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
|
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
|
||||||
|
|
|
@ -128,7 +128,7 @@ class CustomPipelineTests(unittest.TestCase):
|
||||||
def test_load_pipeline_from_git(self):
|
def test_load_pipeline_from_git(self):
|
||||||
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||||
|
|
||||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
|
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
||||||
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
|
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
|
||||||
|
|
||||||
pipeline = DiffusionPipeline.from_pretrained(
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
@ -138,7 +138,6 @@ class CustomPipelineTests(unittest.TestCase):
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
revision="fp16",
|
revision="fp16",
|
||||||
device_map="auto",
|
|
||||||
)
|
)
|
||||||
pipeline.enable_attention_slicing()
|
pipeline.enable_attention_slicing()
|
||||||
pipeline = pipeline.to(torch_device)
|
pipeline = pipeline.to(torch_device)
|
||||||
|
@ -333,9 +332,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_smart_download(self):
|
def test_smart_download(self):
|
||||||
model_id = "hf-internal-testing/unet-pipeline-dummy"
|
model_id = "hf-internal-testing/unet-pipeline-dummy"
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
_ = DiffusionPipeline.from_pretrained(
|
_ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
|
||||||
model_id, cache_dir=tmpdirname, force_download=True, device_map="auto"
|
|
||||||
)
|
|
||||||
local_repo_name = "--".join(["models"] + model_id.split("/"))
|
local_repo_name = "--".join(["models"] + model_id.split("/"))
|
||||||
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
|
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
|
||||||
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
|
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
|
||||||
|
@ -359,7 +356,10 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
with CaptureLogger(logger) as cap_logger:
|
with CaptureLogger(logger) as cap_logger:
|
||||||
DiffusionPipeline.from_pretrained(
|
DiffusionPipeline.from_pretrained(
|
||||||
model_id, not_used=True, cache_dir=tmpdirname, force_download=True, device_map="auto"
|
model_id,
|
||||||
|
not_used=True,
|
||||||
|
cache_dir=tmpdirname,
|
||||||
|
force_download=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
||||||
|
@ -383,7 +383,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
ddpm.save_pretrained(tmpdirname)
|
ddpm.save_pretrained(tmpdirname)
|
||||||
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname, device_map="auto")
|
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
|
||||||
new_ddpm.to(torch_device)
|
new_ddpm.to(torch_device)
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
|
@ -399,11 +399,11 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||||
|
|
||||||
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
|
||||||
ddpm = ddpm.to(torch_device)
|
ddpm = ddpm.to(torch_device)
|
||||||
ddpm.set_progress_bar_config(disable=None)
|
ddpm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
|
||||||
ddpm_from_hub = ddpm_from_hub.to(torch_device)
|
ddpm_from_hub = ddpm_from_hub.to(torch_device)
|
||||||
ddpm_from_hub.set_progress_bar_config(disable=None)
|
ddpm_from_hub.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -421,14 +421,12 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||||
|
|
||||||
# pass unet into DiffusionPipeline
|
# pass unet into DiffusionPipeline
|
||||||
unet = UNet2DModel.from_pretrained(model_path, device_map="auto")
|
unet = UNet2DModel.from_pretrained(model_path)
|
||||||
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
|
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
|
||||||
model_path, unet=unet, scheduler=scheduler, device_map="auto"
|
|
||||||
)
|
|
||||||
ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
|
ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
|
||||||
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
|
||||||
ddpm_from_hub = ddpm_from_hub.to(torch_device)
|
ddpm_from_hub = ddpm_from_hub.to(torch_device)
|
||||||
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -443,7 +441,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_output_format(self):
|
def test_output_format(self):
|
||||||
model_path = "google/ddpm-cifar10-32"
|
model_path = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
pipe = DDIMPipeline.from_pretrained(model_path, device_map="auto")
|
pipe = DDIMPipeline.from_pretrained(model_path)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -467,7 +465,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_ddpm_ddim_equality(self, seed):
|
def test_ddpm_ddim_equality(self, seed):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
unet = UNet2DModel.from_pretrained(model_id)
|
||||||
ddpm_scheduler = DDPMScheduler()
|
ddpm_scheduler = DDPMScheduler()
|
||||||
ddim_scheduler = DDIMScheduler()
|
ddim_scheduler = DDIMScheduler()
|
||||||
|
|
||||||
|
@ -498,7 +496,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_ddpm_ddim_equality_batched(self, seed):
|
def test_ddpm_ddim_equality_batched(self, seed):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
unet = UNet2DModel.from_pretrained(model_id)
|
||||||
ddpm_scheduler = DDPMScheduler()
|
ddpm_scheduler = DDPMScheduler()
|
||||||
ddim_scheduler = DDIMScheduler()
|
ddim_scheduler = DDIMScheduler()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue