default fast model loading 🔥 (#1115)

* make accelerate hard dep

* default fast init

* move params to cpu when device map is None

* handle device_map=None

* handle torch < 1.9

* remove device_map="auto"

* style

* add accelerate in torch extra

* remove accelerate from extras["test"]

* raise an error if torch is available but not accelerate

* update installation docs

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* improve defautl loading speed even further, allow disabling fats loading

* address review comments

* adapt the tests

* fix test_stable_diffusion_fast_load

* fix test_read_init

* temp fix for dummy checks

* Trigger Build

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
Suraj Patil 2022-11-03 17:25:57 +01:00 committed by GitHub
parent ef2ea33c3b
commit 7482178162
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 564 additions and 109 deletions

View File

@ -27,10 +27,12 @@ More precisely, 🤗 Diffusers offers:
## Installation
### For PyTorch
**With `pip`**
```bash
pip install --upgrade diffusers
pip install --upgrade diffusers[torch]
```
**With `conda`**
@ -39,6 +41,14 @@ pip install --upgrade diffusers
conda install -c conda-forge diffusers
```
### For Flax
**With `pip`**
```bash
pip install --upgrade diffusers[flax]
```
**Apple Silicon (M1/M2) support**
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
@ -354,7 +364,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers transformers
# !pip install diffusers["torch"] transformers
from diffusers import DiffusionPipeline
device = "cuda"
@ -373,7 +383,7 @@ image.save("squirrel.png")
```
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers
# !pip install diffusers["torch"]
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
model_id = "google/ddpm-celebahq-256"

View File

@ -12,9 +12,12 @@ specific language governing permissions and limitations under the License.
# Installation
Install Diffusers for with PyTorch. Support for other libraries will come in the future
Install 🤗 Diffusers for whichever deep learning library youre working with.
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
## Install with pip
@ -36,12 +39,30 @@ source .env/bin/activate
Now you're ready to install 🤗 Diffusers with the following command:
**For PyTorch**
```bash
pip install diffusers
pip install diffusers["torch"]
```
**For Flax**
```bash
pip install diffusers["flax"]
```
## Install from source
Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed.
For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).
To install `accelerate`
```bash
pip install accelerate
```
Install 🤗 Diffusers from source with the following command:
```bash
@ -67,7 +88,18 @@ Clone the repository and install 🤗 Diffusers with the following commands:
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install -e .
```
**For PyTorch**
```
pip install -e ".[torch]"
```
**For Flax**
```
pip install -e ".[flax]"
```
These commands will link the folder you cloned the repository to and your Python library paths.

View File

@ -178,7 +178,6 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"accelerate",
"datasets",
"parameterized",
"pytest",
@ -188,7 +187,7 @@ extras["test"] = deps_list(
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch")
extras["torch"] = deps_list("torch", "accelerate")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows

View File

@ -1,4 +1,5 @@
from .utils import (
is_accelerate_available,
is_flax_available,
is_inflect_available,
is_onnx_available,
@ -16,6 +17,13 @@ from .onnx_utils import OnnxRuntimeModel
from .utils import logging
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
if is_torch_available() and not is_accelerate_available():
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
raise ImportError(error_msg)
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel

View File

@ -21,7 +21,9 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device
from diffusers.utils import is_accelerate_available
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
@ -268,6 +270,19 @@ class ModelMixin(torch.nn.Module):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
fast_load (`bool`, *optional*, defaults to `True`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
this argument will be ignored and the model will be loaded normally.
<Tip>
@ -296,6 +311,16 @@ class ModelMixin(torch.nn.Module):
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)
# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
# Fast init is only possible if torch version is >= 1.9.0
_INIT_EMPTY_WEIGHTS = fast_load or device_map is not None
if _INIT_EMPTY_WEIGHTS and not is_torch_version(">=", "1.9.0"):
logger.warn("Loading with `fast_load` requires torch >= 1.9.0. Falling back to normal loading.")
user_agent = {
"diffusers": __version__,
@ -378,12 +403,8 @@ class ModelMixin(torch.nn.Module):
# restore default dtype
if device_map == "auto":
if is_accelerate_available():
import accelerate
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
if _INIT_EMPTY_WEIGHTS:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config(
config_path,
@ -400,7 +421,17 @@ class ModelMixin(torch.nn.Module):
**kwargs,
)
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file)
# move the parms from meta device to cpu
for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, param_device, value=param)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by deafult the device_map is None and the weights are loaded on the CPU
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
loading_info = {
"missing_keys": [],

View File

@ -380,6 +380,7 @@ class DiffusionPipeline(ConfigMixin):
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
fast_load = kwargs.pop("fast_load", True)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@ -572,6 +573,15 @@ class DiffusionPipeline(ConfigMixin):
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
)
if is_diffusers_model:
loading_kwargs["fast_load"] = fast_load
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=fast_load` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_transformers_model and device_map is None:
loading_kwargs["low_cpu_mem_usage"] = fast_load
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map

View File

@ -0,0 +1,392 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class ModelMixin(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class UNet2DModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class VQModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch", "accelerate"])
def get_constant_schedule_with_warmup(*args, **kwargs):
requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"])
def get_cosine_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"])
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"])
def get_linear_schedule_with_warmup(*args, **kwargs):
requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"])
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"])
def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch", "accelerate"])
class DiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DanceDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDPMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class KarrasVePipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class LDMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EulerDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class IPNDMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch", "accelerate"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "accelerate"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "accelerate"])

View File

@ -28,7 +28,7 @@ class UnetModel1DTests(unittest.TestCase):
@slow
def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k"
model = UNet1DModel.from_pretrained(model_id, subfolder="unet", device_map="auto")
model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
model.to(torch_device)
sample_size = 65536

View File

@ -125,9 +125,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model.to(torch_device)
image = model(**self.dummy_input).sample
@ -135,9 +133,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate_wont_change_results(self):
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
# by defautl model loading will use accelerate as `fast_load=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_accelerate.to(torch_device)
model_accelerate.eval()
@ -159,7 +156,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
)
model_normal_load.to(torch_device)
model_normal_load.eval()
@ -173,9 +170,8 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
gc.collect()
tracemalloc.start()
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
# by defautl model loading will use accelerate as `fast_load=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_accelerate.to(torch_device)
model_accelerate.eval()
_, peak_accelerate = tracemalloc.get_traced_memory()
@ -184,7 +180,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
torch.cuda.empty_cache()
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, fast_init=False
)
model_normal_load.to(torch_device)
model_normal_load.eval()
_, peak_normal = tracemalloc.get_traced_memory()
@ -348,9 +346,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained(
"google/ncsnpp-celebahq-256", output_loading_info=True, device_map="auto"
)
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@ -364,7 +360,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@slow
def test_output_pretrained_ve_mid(self):
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto")
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
model.to(torch_device)
torch.manual_seed(0)
@ -439,7 +435,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
torch_dtype = torch.float16 if fp16 else torch.float32
model = UNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision, device_map="auto"
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
)
model.to(torch_device).eval()

View File

@ -155,7 +155,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
torch_dtype = torch.float16 if fp16 else torch.float32
model = AutoencoderKL.from_pretrained(
model_id, subfolder="vae", torch_dtype=torch_dtype, revision=revision, device_map="auto"
model_id,
subfolder="vae",
torch_dtype=torch_dtype,
revision=revision,
)
model.to(torch_device).eval()

View File

@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
def test_dance_diffusion(self):
device = torch_device
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
@ -103,9 +103,7 @@ class PipelineIntegrationTests(unittest.TestCase):
def test_dance_diffusion_fp16(self):
device = torch_device
pipe = DanceDiffusionPipeline.from_pretrained(
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
)
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)

View File

@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def test_inference_ema_bedroom(self):
model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler.from_config(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler()
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

View File

@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

View File

@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class KarrasVePipelineIntegrationTests(unittest.TestCase):
def test_inference(self):
model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
model = UNet2DModel.from_pretrained(model_id)
scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

View File

@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@require_torch
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
def test_inference_text2img(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_text2img_fast(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)

View File

@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
unet = UNet2DModel.from_pretrained(model_id)
scheduler = PNDMScheduler()
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)

View File

@ -72,7 +72,7 @@ class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
def test_inference(self):
model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
model = UNet2DModel.from_pretrained(model_id)
scheduler = ScoreSdeVeScheduler.from_config(model_id)

View File

@ -631,7 +631,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
@ -653,9 +653,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_fast_ddim(self):
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-1", scheduler=scheduler, device_map="auto"
)
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
@ -674,7 +672,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, device_map="auto").to(torch_device)
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe.scheduler = scheduler
@ -693,9 +691,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_memory_chunking(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
model_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
)
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -732,9 +728,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
def test_stable_diffusion_text2img_pipeline_fp16(self):
torch.cuda.reset_peak_memory_stats()
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
model_id, revision="fp16", device_map="auto", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -767,11 +761,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
safety_checker=None,
device_map="auto",
)
pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
@ -812,7 +802,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
test_callback_fn.has_been_called = False
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -833,23 +823,23 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert test_callback_fn.has_been_called
assert number_of_steps == 51
def test_stable_diffusion_accelerate_auto_device(self):
def test_stable_diffusion_fast_load(self):
pipeline_id = "CompVis/stable-diffusion-v1-4"
start_time = time.time()
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
pipeline_fast_load = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16
)
pipeline_normal_load.to(torch_device)
normal_load_time = time.time() - start_time
pipeline_fast_load.to(torch_device)
fast_load_time = time.time() - start_time
start_time = time.time()
_ = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, fast_load=False
)
meta_device_load_time = time.time() - start_time
normal_load_time = time.time() - start_time
assert 2 * meta_device_load_time < normal_load_time
assert 2 * fast_load_time < normal_load_time
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):

View File

@ -488,7 +488,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -529,7 +528,6 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
model_id,
scheduler=lms,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -581,7 +579,9 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512))
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
"CompVis/stable-diffusion-v1-4",
revision="fp16",
torch_dtype=torch.float16,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

View File

@ -284,11 +284,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
model_id = "runwayml/stable-diffusion-inpainting"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=None,
device_map="auto",
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
@ -328,7 +324,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
revision="fp16",
torch_dtype=torch.float16,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -365,9 +360,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id, safety_checker=None, scheduler=pndm, device_map="auto"
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()

View File

@ -364,11 +364,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=None,
device_map="auto",
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
@ -411,7 +407,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
model_id,
scheduler=lms,
safety_checker=None,
device_map="auto",
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -468,7 +463,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

View File

@ -52,13 +52,13 @@ class CheckDummiesTester(unittest.TestCase):
def test_read_init(self):
objects = read_init()
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
self.assertIn("torch", objects)
self.assertIn("torch_and_accelerate", objects)
self.assertIn("torch_and_transformers", objects)
self.assertIn("flax_and_transformers", objects)
self.assertIn("torch_and_transformers_and_onnx", objects)
# Likewise, we can't assert on the exact content of a key
self.assertIn("UNet2DModel", objects["torch"])
self.assertIn("UNet2DModel", objects["torch_and_accelerate"])
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])

View File

@ -128,7 +128,7 @@ class CustomPipelineTests(unittest.TestCase):
def test_load_pipeline_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
pipeline = DiffusionPipeline.from_pretrained(
@ -138,7 +138,6 @@ class CustomPipelineTests(unittest.TestCase):
feature_extractor=feature_extractor,
torch_dtype=torch.float16,
revision="fp16",
device_map="auto",
)
pipeline.enable_attention_slicing()
pipeline = pipeline.to(torch_device)
@ -333,9 +332,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_smart_download(self):
model_id = "hf-internal-testing/unet-pipeline-dummy"
with tempfile.TemporaryDirectory() as tmpdirname:
_ = DiffusionPipeline.from_pretrained(
model_id, cache_dir=tmpdirname, force_download=True, device_map="auto"
)
_ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
local_repo_name = "--".join(["models"] + model_id.split("/"))
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
@ -359,7 +356,10 @@ class PipelineSlowTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
with CaptureLogger(logger) as cap_logger:
DiffusionPipeline.from_pretrained(
model_id, not_used=True, cache_dir=tmpdirname, force_download=True, device_map="auto"
model_id,
not_used=True,
cache_dir=tmpdirname,
force_download=True,
)
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
@ -383,7 +383,7 @@ class PipelineSlowTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname, device_map="auto")
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device)
generator = torch.manual_seed(0)
@ -399,11 +399,11 @@ class PipelineSlowTests(unittest.TestCase):
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm = ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub = ddpm_from_hub.to(torch_device)
ddpm_from_hub.set_progress_bar_config(disable=None)
@ -421,14 +421,12 @@ class PipelineSlowTests(unittest.TestCase):
scheduler = DDPMScheduler(num_train_timesteps=10)
# pass unet into DiffusionPipeline
unet = UNet2DModel.from_pretrained(model_path, device_map="auto")
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
model_path, unet=unet, scheduler=scheduler, device_map="auto"
)
unet = UNet2DModel.from_pretrained(model_path)
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub = ddpm_from_hub.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
@ -443,7 +441,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"
pipe = DDIMPipeline.from_pretrained(model_path, device_map="auto")
pipe = DDIMPipeline.from_pretrained(model_path)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@ -467,7 +465,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_ddpm_ddim_equality(self, seed):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler()
ddim_scheduler = DDIMScheduler()
@ -498,7 +496,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_ddpm_ddim_equality_batched(self, seed):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler()
ddim_scheduler = DDIMScheduler()