Merge branch 'main' of github.com:huggingface/diffusers

anton-l 2022-06-09 12:43:03 +02:00
commit ae73d95e41
8 changed files with 632 additions and 225 deletions

--- a/setup.py
+++ b/setup.py

@@ -1,4 +1,4 @@
-# Copyright 2021 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -52,11 +52,11 @@ To create the package for pypi.
    twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/

 Check that you can install it in a virtualenv by running:
-   pip install -i https://testpypi.python.org/pypi transformers
+   pip install -i https://testpypi.python.org/pypi diffusers

 Check you can run the following commands:
-   python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
-   python -c "from transformers import *"
+   python -c "from diffusers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
+   python -c "from diffusers import *"

 9. Upload the final version to actual pypi:
    twine upload dist/* -r pypi
@@ -77,36 +77,21 @@ from setuptools import find_packages, setup
 # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
 _deps = [
     "Pillow",
-    "accelerate>=0.9.0",
     "black~=22.0,>=22.3",
-    "codecarbon==1.2.0",
-    "dataclasses",
-    "datasets",
-    "GitPython<3.1.19",
-    "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.1.0,<1.0",
-    "importlib_metadata",
+    "filelock",
+    "flake8>=3.8.3",
+    "huggingface-hub",
     "isort>=5.5.4",
-    "numpy>=1.17",
+    "numpy",
     "pytest",
-    "pytest-timeout",
-    "pytest-xdist",
-    "python>=3.7.0",
-    "regex!=2019.12.17",
     "requests",
-    "sagemaker>=2.31.0",
-    "tokenizers>=0.11.1,!=0.11.3,<0.13",
     "torch>=1.4",
-    "torchaudio",
-    "tqdm>=4.27",
-    "unidic>=1.0.2",
-    "unidic_lite>=1.0.7",
-    "uvicorn",
 ]
+    "torchvision",
+]

 # this is a lookup table with items like:
 #
-# tokenizers: "tokenizers==0.9.4"
+# tokenizers: "huggingface-hub==0.8.0"
 # packaging: "packaging"
 #
 # some of the values are versioned whereas others aren't.
@@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
 extras["docs"] = []
 extras["test"] = [
     "pytest",
-    "pytest-xdist",
-    "pytest-subtests",
-    "datasets",
-    "transformers",
 ]
 extras["dev"] = extras["quality"] + extras["test"]

-extras["sagemaker"] = [
-    "sagemaker",  # boto3 is a required package in sagemaker
+install_requires = [
+    deps["filelock"],
+    deps["huggingface-hub"],
+    deps["numpy"],
+    deps["requests"],
+    deps["torch"],
+    deps["torchvision"],
+    deps["Pillow"],
 ]

 setup(
@@ -201,7 +188,7 @@ setup(
     package_dir={"": "src"},
     packages=find_packages("src"),
     python_requires=">=3.6.0",
-    install_requires=["numpy>=1.17", "packaging>=20.0", "pyyaml", "torch>=1.4.0"],
+    install_requires=install_requires,
     extras_require=extras,
     classifiers=[
         "Development Status :: 5 - Production/Stable",

--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py

@@ -24,18 +24,19 @@ import re
 from typing import Any, Dict, Tuple, Union

 from requests import HTTPError
-from transformers.utils import (
+from huggingface_hub import hf_hub_download
+
+from .utils import (
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    DIFFUSERS_CACHE,
     EntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    is_remote_url,
     logging,
 )

 from . import __version__
@@ -56,6 +57,8 @@ class ConfigMixin:
         if self.config_name is None:
             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
         kwargs["_class_name"] = self.__class__.__name__
+        kwargs["_diffusers_version"] = __version__
+
         for key, value in kwargs.items():
             try:
                 setattr(self, key, value)
@@ -90,11 +93,26 @@ class ConfigMixin:
         self.to_json_file(output_config_file)
         logger.info(f"ConfigMixinuration saved in {output_config_file}")

+    @classmethod
+    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+        config_dict = cls.get_config_dict(
+            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+        )
+
+        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+        model = cls(**init_dict)
+
+        if return_unused_kwargs:
+            return model, unused_kwargs
+        else:
+            return model
+
     @classmethod
     def get_config_dict(
         cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        cache_dir = kwargs.pop("cache_dir", None)
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -104,85 +122,83 @@ class ConfigMixin:
         user_agent = {"file_type": "config"}

-        if is_offline_mode() and not local_files_only:
-            logger.info("Offline mode: forcing local_files_only=True")
-            local_files_only = True
-
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)

-        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-            config_file = pretrained_model_name_or_path
-        else:
-            configuration_file = cls.config_name
-
-            if os.path.isdir(pretrained_model_name_or_path):
-                config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
-            else:
-                config_file = hf_bucket_url(
-                    pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
-                )
-
-        try:
-            # Load from URL or cache if already cached
-            resolved_config_file = cached_path(
-                config_file,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                user_agent=user_agent,
-            )
-        except RepositoryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
-                "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
-                "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
-                "`use_auth_token=True`."
-            )
-        except RevisionNotFoundError:
-            raise EnvironmentError(
-                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
-                f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
-                "available revisions."
-            )
-        except EntryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
-            )
-        except HTTPError as err:
-            raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
-            )
-        except ValueError:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
-                f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
-                f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
-                " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
-            )
-        except EnvironmentError:
-            raise EnvironmentError(
-                f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                f"containing a {configuration_file} file"
-            )
+        if cls.config_name is None:
+            raise ValueError(
+                "`self.config_name` is not defined. Note that one should not load a config from "
+                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+            )
+
+        if os.path.isfile(pretrained_model_name_or_path):
+            config_file = pretrained_model_name_or_path
+        elif os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                # Load from a PyTorch checkpoint
+                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+            else:
+                raise EnvironmentError(
+                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+                )
+        else:
+            try:
+                # Load from URL or cache if already cached
+                config_file = hf_hub_download(
+                    pretrained_model_name_or_path,
+                    filename=cls.config_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
+            except RepositoryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
+                    "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
+                    "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
+                    "`use_auth_token=True`."
+                )
+            except RevisionNotFoundError:
+                raise EnvironmentError(
+                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
+                    f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
+                    "available revisions."
+                )
+            except EntryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                )
+            except HTTPError as err:
+                raise EnvironmentError(
+                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                )
+            except ValueError:
+                raise EnvironmentError(
+                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
+                    f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
+                    f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
+                    " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                )
+            except EnvironmentError:
+                raise EnvironmentError(
+                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                    f"containing a {cls.config_name} file"
+                )

         try:
             # Load config dict
-            config_dict = cls._dict_from_json_file(resolved_config_file)
+            config_dict = cls._dict_from_json_file(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
             raise EnvironmentError(
-                f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
+                f"It looks like the config file at '{config_file}' is not a valid JSON file."
             )

-        if resolved_config_file == config_file:
-            logger.info(f"loading configuration file {config_file}")
-        else:
-            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
-
         return config_dict

     @classmethod
@@ -208,19 +224,6 @@ class ConfigMixin:
         return init_dict, unused_kwargs

-    @classmethod
-    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
-        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
-
-        model = cls(**init_dict)
-
-        if return_unused_kwargs:
-            return model, unused_kwargs
-        else:
-            return model
-
     @classmethod
     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
         with open(json_file, "r", encoding="utf-8") as reader:
@@ -233,18 +236,9 @@ class ConfigMixin:
     def __repr__(self):
         return f"{self.__class__.__name__} {self.to_json_string()}"

-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serializes this instance to a Python dictionary.
-
-        Returns:
-            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
-        """
-        output = copy.deepcopy(self.__dict__)
-
-        # Diffusion version when serializing the model
-        output["diffusers_version"] = __version__
-
+    @property
+    def config(self) -> Dict[str, Any]:
+        output = copy.deepcopy(self._dict_to_save)
         return output

     def to_json_string(self) -> str:
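Taken together, these changes make ConfigMixin a self-contained save/load mechanism: `register` (called from a subclass `__init__`) records the kwargs plus `_class_name` and `_diffusers_version`, the new `config` property exposes a copy of them, and the relocated `from_config` rebuilds an instance from a local directory or hub id. A hedged sketch of the round-trip; `SampleScheduler` is an illustrative stand-in, not a real diffusers class:

import tempfile

from diffusers.configuration_utils import ConfigMixin


class SampleScheduler(ConfigMixin):
    config_name = "config.json"  # required; ConfigMixin itself raises without it

    def __init__(self, timesteps=1000, beta_start=0.0001):
        # register() stores the kwargs for later serialization
        self.register(timesteps=timesteps, beta_start=beta_start)


scheduler = SampleScheduler(timesteps=50)
print(scheduler.config["timesteps"])  # 50

with tempfile.TemporaryDirectory() as tmpdir:
    scheduler.save_config(tmpdir)                  # writes <tmpdir>/config.json
    reloaded = SampleScheduler.from_config(tmpdir)  # rebuilds from the saved config
    assert reloaded.config["timesteps"] == 50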

--- a/src/diffusers/dynamic_modules_utils.py
+++ b/src/diffusers/dynamic_modules_utils.py

@@ -22,16 +22,8 @@ import sys
 from pathlib import Path
 from typing import Dict, Optional, Union

-from huggingface_hub import HfFolder, model_info
-
-from transformers.utils import (
-    HF_MODULES_CACHE,
-    TRANSFORMERS_DYNAMIC_MODULE_NAME,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    logging,
-)
+from huggingface_hub import cached_download
+
+from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -219,7 +211,7 @@ def get_cached_module_file(
     try:
         # Load from URL or cache if already cached
-        resolved_module_file = cached_path(
+        resolved_module_file = cached_download(
             module_file_or_url,
             cache_dir=cache_dir,
             force_download=force_download,
@@ -237,7 +229,7 @@ def get_cached_module_file(
     modules_needed = check_imports(resolved_module_file)

     # Now we move the module inside our cached dynamic modules.
-    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
     create_dynamic_module(full_submodule)
     submodule_path = Path(HF_MODULES_CACHE) / full_submodule
     # We always copy local files (we could hash the file to see if there was a change, and give them the name of
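The functional difference in the hunk above: transformers' `cached_path` accepted local paths or URLs, while huggingface_hub's `cached_download` expects a fully resolved URL. A small sketch of the new call pattern (`hf_hub_url` builds such a URL; the repo id below is purely illustrative):

from huggingface_hub import cached_download, hf_hub_url

# Build the resolved URL for a file in a hub repo, then fetch it
# (or reuse the cached copy) and get back the local file path.
url = hf_hub_url("google/ddpm-cifar10-32", filename="config.json")  # illustrative repo id
local_path = cached_download(url, cache_dir=None, force_download=False)
print(local_path)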

--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py

@@ -21,18 +21,15 @@ import torch
 from torch import Tensor, device

 from requests import HTTPError
+from huggingface_hub import hf_hub_download

-# CHANGE to diffusers.utils
-from transformers.utils import (
+from .utils import (
     CONFIG_NAME,
+    DIFFUSERS_CACHE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     EntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    is_remote_url,
     logging,
 )
@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
         </Tip>
         """
-        cache_dir = kwargs.pop("cache_dir", None)
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
-        mirror = kwargs.pop("mirror", None)
         from_auto_class = kwargs.pop("_from_auto", False)

         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}

-        if is_offline_mode() and not local_files_only:
-            logger.info("Offline mode: forcing local_files_only=True")
-            local_files_only = True
-
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
         model, unused_kwargs = cls.from_config(
@@ -353,79 +345,67 @@ class ModelMixin(torch.nn.Module):
         if os.path.isdir(pretrained_model_name_or_path):
             if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
             else:
                 raise EnvironmentError(
                     f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
                 )
-        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-            archive_file = pretrained_model_name_or_path
         else:
-            filename = WEIGHTS_NAME
-
-            archive_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
-            )
-
-        try:
-            # Load from URL or cache if already cached
-            resolved_archive_file = cached_path(
-                archive_file,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                user_agent=user_agent,
-            )
-        except RepositoryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                "login` and pass `use_auth_token=True`."
-            )
-        except RevisionNotFoundError:
-            raise EnvironmentError(
-                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                "this model name. Check the model page at "
-                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-            )
-        except EntryNotFoundError:
-            raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.")
-        except HTTPError as err:
-            raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
-            )
-        except ValueError:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
-                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
-                f" directory containing a file named {WEIGHTS_NAME} or"
-                " \nCheckout your internet connection or see how to run the library in"
-                " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
-            )
-        except EnvironmentError:
-            raise EnvironmentError(
-                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                f"containing a file named {WEIGHTS_NAME}"
-            )
-
-        if resolved_archive_file == archive_file:
-            logger.info(f"loading weights file {archive_file}")
-        else:
-            logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
+            try:
+                # Load from URL or cache if already cached
+                model_file = hf_hub_download(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
+            except RepositoryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                    "login` and pass `use_auth_token=True`."
+                )
+            except RevisionNotFoundError:
+                raise EnvironmentError(
+                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                    "this model name. Check the model page at "
+                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                )
+            except EntryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
+                )
+            except HTTPError as err:
+                raise EnvironmentError(
+                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                )
+            except ValueError:
+                raise EnvironmentError(
+                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                    f" directory containing a file named {WEIGHTS_NAME} or"
+                    " \nCheckout your internet connection or see how to run the library in"
+                    " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
+                )
+            except EnvironmentError:
+                raise EnvironmentError(
+                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                    f"containing a file named {WEIGHTS_NAME}"
+                )

         # restore default dtype
-        state_dict = load_state_dict(resolved_archive_file)
+        state_dict = load_state_dict(model_file)

         model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
             model,
             state_dict,
-            resolved_archive_file,
+            model_file,
             pretrained_model_name_or_path,
             ignore_mismatched_sizes=ignore_mismatched_sizes,
         )
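Stripped of error handling, the new resolution logic above reduces to: a local directory must contain the weights file, and anything else is treated as a hub repo id and fetched with `hf_hub_download` (which also serves cache hits). A sketch under the assumption that `WEIGHTS_NAME` is `"diffusion_pytorch_model.bin"`, the diffusers convention; the repo id is illustrative:

import os

import torch
from huggingface_hub import hf_hub_download

WEIGHTS_NAME = "diffusion_pytorch_model.bin"  # assumed value of the diffusers constant


def resolve_model_file(pretrained_model_name_or_path: str) -> str:
    if os.path.isdir(pretrained_model_name_or_path):
        # Local checkpoint directory: the weights file must already be there.
        return os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
    # Otherwise treat the string as a hub repo id; hf_hub_download
    # returns the cached local path, downloading only if needed.
    return hf_hub_download(pretrained_model_name_or_path, filename=WEIGHTS_NAME)


model_file = resolve_model_file("google/ddpm-cifar10-32")  # illustrative repo id
state_dict = torch.load(model_file, map_location="cpu")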

--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py

@@ -20,8 +20,7 @@ from typing import Optional, Union
 from huggingface_hub import snapshot_download

-# CHANGE to diffusers.utils
-from transformers.utils import logging
+from .utils import logging, DIFFUSERS_CACHE

 from .configuration_utils import ConfigMixin
 from .dynamic_modules_utils import get_class_from_dynamic_module
@@ -76,11 +75,12 @@ class DiffusionPipeline(ConfigMixin):
     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
         self.save_config(save_directory)

-        model_index_dict = self._dict_to_save
+        model_index_dict = self.config
         model_index_dict.pop("_class_name")
+        model_index_dict.pop("_diffusers_version")
         model_index_dict.pop("_module")

-        for name, (library_name, class_name) in self._dict_to_save.items():
+        for name, (library_name, class_name) in model_index_dict.items():
             importable_classes = LOADABLE_CLASSES[library_name]

             # TODO: Suraj
@@ -101,14 +101,36 @@ class DiffusionPipeline(ConfigMixin):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Add docstrings
+        """
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
-            cached_folder = snapshot_download(pretrained_model_name_or_path)
+            cached_folder = snapshot_download(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                output_loading_info=output_loading_info,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+            )
         else:
             cached_folder = pretrained_model_name_or_path

         config_dict = cls.get_config_dict(cached_folder)

-        module = config_dict["_module"]
-        class_name_ = config_dict["_class_name"]
-
         module_candidate = config_dict["_module"]
         module_candidate_name = module_candidate.replace(".py", "")
@@ -126,13 +148,12 @@ class DiffusionPipeline(ConfigMixin):
         init_kwargs = {}

         # get all importable classes to get the load method name for custom models/components
+        # here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers
         all_importable_classes = {}
         for library in LOADABLE_CLASSES:
             all_importable_classes.update(LOADABLE_CLASSES[library])

         for name, (library_name, class_name) in init_dict.items():
-            # if the model is not in diffusers or transformers, we need to load it from the hub
-            # assumes that it's a subclass of ModelMixin
             if library_name == module_candidate_name:
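From the caller's perspective, the `from_pretrained` change above means an entire pipeline repo is mirrored locally in a single `snapshot_download` call before any component is instantiated. A hedged sketch of that first step (repo id illustrative; only kwargs I am confident exist in huggingface_hub's `snapshot_download` are shown, so `output_loading_info`, which the diff also forwards, is omitted):

from huggingface_hub import snapshot_download

# Mirror the whole repo (model_index.json plus one subfolder per
# component) into the local cache and return the folder path.
cached_folder = snapshot_download(
    "google/ddpm-cifar10-32",   # illustrative repo id
    cache_dir=None,             # None -> default shared HF cache
    resume_download=True,
    local_files_only=False,     # set True for fully offline use
)
print(cached_folder)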

--- /dev/null
+++ b/src/diffusers/utils/__init__.py

@@ -0,0 +1,49 @@
#!/usr/bin/env python
# coding=utf-8
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from requests.exceptions import HTTPError
import os
hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "diffusers")
CONFIG_NAME = "config.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""
class EntryNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""

--- /dev/null
+++ b/src/diffusers/utils/logging.py

@@ -0,0 +1,344 @@
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities."""
import logging
import os
import sys
import threading
from logging import CRITICAL # NOQA
from logging import DEBUG # NOQA
from logging import ERROR # NOQA
from logging import FATAL # NOQA
from logging import INFO # NOQA
from logging import NOTSET # NOQA
from logging import WARN # NOQA
from logging import WARNING # NOQA
from typing import Optional
from tqdm import auto as tqdm_lib
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
_default_log_level = logging.WARNING
_tqdm_active = True
def _get_default_logging_level():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
"""
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
global _default_handler
with _lock:
if _default_handler:
# This library has already configured the library root logger.
return
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
_default_handler.flush = sys.stderr.flush
# Apply our default configuration to the library root logger.
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def _reset_library_root_logger() -> None:
global _default_handler
with _lock:
if not _default_handler:
return
library_root_logger = _get_library_root_logger()
library_root_logger.removeHandler(_default_handler)
library_root_logger.setLevel(logging.NOTSET)
_default_handler = None
def get_log_levels_dict():
return log_levels
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Return a logger with the specified name.
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
"""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def get_verbosity() -> int:
"""
Return the current level for the 🤗 Transformers's root logger as an int.
Returns:
`int`: The logging level.
<Tip>
🤗 Transformers has following logging levels:
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- 40: `diffusers.logging.ERROR`
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- 20: `diffusers.logging.INFO`
- 10: `diffusers.logging.DEBUG`
</Tip>"""
_configure_library_root_logger()
return _get_library_root_logger().getEffectiveLevel()
def set_verbosity(verbosity: int) -> None:
"""
Set the verbosity level for the 🤗 Transformers's root logger.
Args:
verbosity (`int`):
Logging level, e.g., one of:
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- `diffusers.logging.ERROR`
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- `diffusers.logging.INFO`
- `diffusers.logging.DEBUG`
"""
_configure_library_root_logger()
_get_library_root_logger().setLevel(verbosity)
def set_verbosity_info():
"""Set the verbosity to the `INFO` level."""
return set_verbosity(INFO)
def set_verbosity_warning():
"""Set the verbosity to the `WARNING` level."""
return set_verbosity(WARNING)
def set_verbosity_debug():
"""Set the verbosity to the `DEBUG` level."""
return set_verbosity(DEBUG)
def set_verbosity_error():
"""Set the verbosity to the `ERROR` level."""
return set_verbosity(ERROR)
def disable_default_handler() -> None:
"""Disable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().removeHandler(_default_handler)
def enable_default_handler() -> None:
"""Enable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert _default_handler is not None
_get_library_root_logger().addHandler(_default_handler)
def add_handler(handler: logging.Handler) -> None:
"""adds a handler to the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
"""removes given handler from the HuggingFace Transformers's root logger."""
_configure_library_root_logger()
assert handler is not None and handler not in _get_library_root_logger().handlers
_get_library_root_logger().removeHandler(handler)
def disable_propagation() -> None:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = False
def enable_propagation() -> None:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""
_configure_library_root_logger()
_get_library_root_logger().propagate = True
def enable_explicit_format() -> None:
"""
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
```
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
```
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
handler.setFormatter(formatter)
def reset_format() -> None:
"""
Resets the formatting for HuggingFace Transformers's loggers.
All handlers currently bound to the root logger are affected by this method.
"""
handlers = _get_library_root_logger().handlers
for handler in handlers:
handler.setFormatter(None)
def warning_advice(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
if no_advisory_warnings:
return
self.warning(*args, **kwargs)
logging.Logger.warning_advice = warning_advice
class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
self._iterator = args[0] if args else None
def __iter__(self):
return iter(self._iterator)
def __getattr__(self, _):
"""Return empty function."""
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
return
return empty_fn
def __enter__(self):
return self
def __exit__(self, type_, value, traceback):
return
class _tqdm_cls:
def __call__(self, *args, **kwargs):
if _tqdm_active:
return tqdm_lib.tqdm(*args, **kwargs)
else:
return EmptyTqdm(*args, **kwargs)
def set_lock(self, *args, **kwargs):
self._lock = None
if _tqdm_active:
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
def get_lock(self):
if _tqdm_active:
return tqdm_lib.tqdm.get_lock()
tqdm = _tqdm_cls()
def is_progress_bar_enabled() -> bool:
"""Return a boolean indicating whether tqdm progress bars are enabled."""
global _tqdm_active
return bool(_tqdm_active)
def enable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True
def disable_progress_bar():
"""Disable tqdm progress bar."""
global _tqdm_active
_tqdm_active = False
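Typical use of the module above from application code (note the verbosity env var is still spelled TRANSFORMERS_VERBOSITY at this point, a leftover of the copied transformers module):

from diffusers.utils import logging

logging.set_verbosity_info()            # INFO and above now reach stderr
logger = logging.get_logger(__name__)
logger.info("configured diffusers logging")

logging.disable_progress_bar()          # silence tqdm bars library-wide
logging.set_verbosity_error()           # back to errors only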

--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py

@@ -24,6 +24,7 @@ import torch
 from diffusers import GaussianDDPMScheduler, UNetModel
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.configuration_utils import ConfigMixin
 from models.vision.ddpm.modeling_ddpm import DDPM
 from models.vision.ddim.modeling_ddim import DDIM
@@ -78,6 +79,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
     return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()


+class ConfigTester(unittest.TestCase):
+    def test_load_not_from_mixin(self):
+        with self.assertRaises(ValueError):
+            ConfigMixin.from_config("dummy_path")
+
+    def test_save_load(self):
+        class SampleObject(ConfigMixin):
+            config_name = "config.json"
+
+            def __init__(
+                self,
+                a=2,
+                b=5,
+                c=(2, 5),
+                d="for diffusion",
+                e=[1, 3],
+            ):
+                self.register(a=a, b=b, c=c, d=d, e=e)
+
+        obj = SampleObject()
+        config = obj.config
+
+        assert config["a"] == 2
+        assert config["b"] == 5
+        assert config["c"] == (2, 5)
+        assert config["d"] == "for diffusion"
+        assert config["e"] == [1, 3]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+            new_obj = SampleObject.from_config(tmpdirname)
+            new_config = new_obj.config
+
+            assert config.pop("c") == (2, 5)  # instantiated as tuple
+            assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
+            assert config == new_config
+
+
 class ModelTesterMixin(unittest.TestCase):
     @property
     def dummy_input(self):
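A one-liner that shows why `test_save_load` above expects the tuple to come back as a list: JSON has no tuple type, so `(2, 5)` serializes to `[2, 5]` and deserializes as a list.

import json

# Round-trip through JSON turns the tuple into a list.
round_tripped = json.loads(json.dumps({"c": (2, 5)}))
print(round_tripped)  # {'c': [2, 5]}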