diff --git a/setup.py b/setup.py
index 17a5dc36..96d1f309 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -52,11 +52,11 @@ To create the package for pypi.
    twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
 
 Check that you can install it in a virtualenv by running:
-   pip install -i https://testpypi.python.org/pypi transformers
+   pip install -i https://testpypi.python.org/pypi diffusers
 
 Check you can run the following commands:
-   python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
-   python -c "from transformers import *"
+   python -c "from diffusers import __version__; print(__version__)"
+   python -c "from diffusers import *"
 
 9. Upload the final version to actual pypi:
    twine upload dist/* -r pypi
@@ -77,36 +77,21 @@ from setuptools import find_packages, setup
 # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
 _deps = [
     "Pillow",
-    "accelerate>=0.9.0",
     "black~=22.0,>=22.3",
-    "codecarbon==1.2.0",
-    "dataclasses",
-    "datasets",
-    "GitPython<3.1.19",
-    "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.1.0,<1.0",
-    "importlib_metadata",
+    "filelock",
+    "flake8>=3.8.3",
+    "huggingface-hub",
     "isort>=5.5.4",
-    "numpy>=1.17",
+    "numpy",
     "pytest",
-    "pytest-timeout",
-    "pytest-xdist",
-    "python>=3.7.0",
-    "regex!=2019.12.17",
     "requests",
-    "sagemaker>=2.31.0",
-    "tokenizers>=0.11.1,!=0.11.3,<0.13",
     "torch>=1.4",
-    "torchaudio",
-    "tqdm>=4.27",
-    "unidic>=1.0.2",
-    "unidic_lite>=1.0.7",
-    "uvicorn",
+    "torchvision",
 ]
 
 # this is a lookup table with items like:
 #
-# tokenizers: "tokenizers==0.9.4"
+# huggingface-hub: "huggingface-hub==0.8.0"
 # packaging: "packaging"
 #
 # some of the values are versioned whereas others aren't.
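(The `deps` lookup table consumed by `install_requires` below is built from `_deps` by boilerplate shared across HuggingFace setup.py files; the split itself is not shown in this hunk. A sketch of the assumed pattern:

    import re

    # Split each entry of _deps into {bare_name: full_spec},
    # e.g. "torch>=1.4" -> {"torch": "torch>=1.4"} and "Pillow" -> {"Pillow": "Pillow"}.
    deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)}
)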
@@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
 extras["docs"] = []
 extras["test"] = [
     "pytest",
-    "pytest-xdist",
-    "pytest-subtests",
-    "datasets",
-    "transformers",
 ]
 extras["dev"] = extras["quality"] + extras["test"]
 
-extras["sagemaker"] = [
-    "sagemaker",  # boto3 is a required package in sagemaker
+install_requires = [
+    deps["filelock"],
+    deps["huggingface-hub"],
+    deps["numpy"],
+    deps["requests"],
+    deps["torch"],
+    deps["torchvision"],
+    deps["Pillow"],
 ]
 
 setup(
@@ -201,7 +188,7 @@ setup(
     package_dir={"": "src"},
     packages=find_packages("src"),
     python_requires=">=3.6.0",
-    install_requires=["numpy>=1.17", "packaging>=20.0", "pyyaml", "torch>=1.4.0"],
+    install_requires=install_requires,
     extras_require=extras,
     classifiers=[
         "Development Status :: 5 - Production/Stable",
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index c3cf8c59..fc7e04cc 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -24,18 +24,19 @@ import re
 from typing import Any, Dict, Tuple, Union
 
 from requests import HTTPError
-from transformers.utils import (
+from huggingface_hub import hf_hub_download
+
+from .utils import (
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    DIFFUSERS_CACHE,
     EntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    is_remote_url,
     logging,
 )
+
 from . import __version__
 
@@ -56,6 +57,8 @@ class ConfigMixin:
         if self.config_name is None:
             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
         kwargs["_class_name"] = self.__class__.__name__
+        kwargs["_diffusers_version"] = __version__
+
         for key, value in kwargs.items():
             try:
                 setattr(self, key, value)
@@ -90,11 +93,26 @@ class ConfigMixin:
         self.to_json_file(output_config_file)
         logger.info(f"ConfigMixinuration saved in {output_config_file}")
 
+    @classmethod
+    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+        config_dict = cls.get_config_dict(
+            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+        )
+
+        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+        model = cls(**init_dict)
+
+        if return_unused_kwargs:
+            return model, unused_kwargs
+        else:
+            return model
+
     @classmethod
     def get_config_dict(
         cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        cache_dir = kwargs.pop("cache_dir", None)
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -104,85 +122,83 @@ class ConfigMixin:
 
         user_agent = {"file_type": "config"}
 
-        if is_offline_mode() and not local_files_only:
-            logger.info("Offline mode: forcing local_files_only=True")
-            local_files_only = True
-
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
 
-        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-            config_file = pretrained_model_name_or_path
-        else:
-            configuration_file = cls.config_name
-            if os.path.isdir(pretrained_model_name_or_path):
-                config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
+        if cls.config_name is None:
+            raise ValueError(
+                "`self.config_name` is not defined. Note that one should not load a config from "
+                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+            )
+
+        if os.path.isfile(pretrained_model_name_or_path):
+            config_file = pretrained_model_name_or_path
+        elif os.path.isdir(pretrained_model_name_or_path):
+            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                # Load the config from a file in the local directory
+                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
             else:
-                config_file = hf_bucket_url(
-                    pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
+                raise EnvironmentError(
+                    f"No file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                 )
+        else:
+            try:
+                # Load from URL or cache if already cached
+                config_file = hf_hub_download(
+                    pretrained_model_name_or_path,
+                    filename=cls.config_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
+
+            except RepositoryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
+                    "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
+                    "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
+                    "`use_auth_token=True`."
+                )
+            except RevisionNotFoundError:
+                raise EnvironmentError(
+                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
+                    f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
+                    "available revisions."
+                )
+            except EntryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+                )
+            except HTTPError as err:
+                raise EnvironmentError(
+                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                )
+            except ValueError:
+                raise EnvironmentError(
+                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
+                    f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
+                    f" containing a {cls.config_name} file.\nCheck your internet connection or see how to run the"
+                    " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                )
+            except EnvironmentError:
+                raise EnvironmentError(
+                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
" + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {cls.config_name} file" ) - try: - # Load from URL or cache if already cached - resolved_config_file = cached_path( - config_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - ) - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on " - "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having " - "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass " - "`use_auth_token=True`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " - f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for " - "available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}." - ) - except HTTPError as err: - raise EnvironmentError( - f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in" - f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory" - f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the" - " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a {configuration_file} file" - ) - try: # Load config dict - config_dict = cls._dict_from_json_file(resolved_config_file) + config_dict = cls._dict_from_json_file(config_file) except (json.JSONDecodeError, UnicodeDecodeError): raise EnvironmentError( - f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." + f"It looks like the config file at '{config_file}' is not a valid JSON file." 
             )
 
-        if resolved_config_file == config_file:
-            logger.info(f"loading configuration file {config_file}")
-        else:
-            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
-
         return config_dict
 
@@ -208,19 +224,6 @@ class ConfigMixin:
 
         return init_dict, unused_kwargs
 
-    @classmethod
-    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
-
-        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
-
-        model = cls(**init_dict)
-
-        if return_unused_kwargs:
-            return model, unused_kwargs
-        else:
-            return model
-
     @classmethod
     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
         with open(json_file, "r", encoding="utf-8") as reader:
@@ -233,18 +236,9 @@ class ConfigMixin:
     def __repr__(self):
         return f"{self.__class__.__name__} {self.to_json_string()}"
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serializes this instance to a Python dictionary.
-
-        Returns:
-            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
-        """
-        output = copy.deepcopy(self.__dict__)
-
-        # Diffusion version when serializing the model
-        output["diffusers_version"] = __version__
-
+    @property
+    def config(self) -> Dict[str, Any]:
+        """Returns a copy of the registered configuration as a plain `Dict[str, Any]`."""
+        output = copy.deepcopy(self._dict_to_save)
         return output
 
     def to_json_string(self) -> str:
diff --git a/src/diffusers/dynamic_modules_utils.py b/src/diffusers/dynamic_modules_utils.py
index a433c209..9e89a51e 100644
--- a/src/diffusers/dynamic_modules_utils.py
+++ b/src/diffusers/dynamic_modules_utils.py
@@ -22,16 +22,8 @@ import sys
 from pathlib import Path
 from typing import Dict, Optional, Union
 
-from huggingface_hub import HfFolder, model_info
-
-from transformers.utils import (
-    HF_MODULES_CACHE,
-    TRANSFORMERS_DYNAMIC_MODULE_NAME,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    logging,
-)
+from huggingface_hub import cached_download
+
+from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -219,7 +211,7 @@ def get_cached_module_file(
 
     try:
         # Load from URL or cache if already cached
-        resolved_module_file = cached_path(
+        resolved_module_file = cached_download(
             module_file_or_url,
             cache_dir=cache_dir,
             force_download=force_download,
@@ -237,7 +229,7 @@ def get_cached_module_file(
     modules_needed = check_imports(resolved_module_file)
 
     # Now we move the module inside our cached dynamic modules.
-    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
     create_dynamic_module(full_submodule)
     submodule_path = Path(HF_MODULES_CACHE) / full_submodule
     # We always copy local files (we could hash the file to see if there was a change, and give them the name of
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 54437fa5..20870e34 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -21,18 +21,15 @@ import torch
 from torch import Tensor, device
 
 from requests import HTTPError
+from huggingface_hub import hf_hub_download
 
-# CHANGE to diffusers.utils
-from transformers.utils import (
+from .utils import (
     CONFIG_NAME,
+    DIFFUSERS_CACHE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     EntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    is_remote_url,
     logging,
 )
 
@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
 
         """
-        cache_dir = kwargs.pop("cache_dir", None)
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
-        mirror = kwargs.pop("mirror", None)
         from_auto_class = kwargs.pop("_from_auto", False)
 
         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
 
-        if is_offline_mode() and not local_files_only:
-            logger.info("Offline mode: forcing local_files_only=True")
-            local_files_only = True
-
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
         model, unused_kwargs = cls.from_config(
@@ -353,79 +345,67 @@ class ModelMixin(torch.nn.Module):
         if os.path.isdir(pretrained_model_name_or_path):
             if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
             else:
                 raise EnvironmentError(
                     f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
                 )
-        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-            archive_file = pretrained_model_name_or_path
         else:
-            filename = WEIGHTS_NAME
+            try:
+                # Load from URL or cache if already cached
+                model_file = hf_hub_download(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
 
-            archive_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
-            )
+            except RepositoryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                    "login` and pass `use_auth_token=True`."
+                )
+            except RevisionNotFoundError:
+                raise EnvironmentError(
+                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+                    "this model name. Check the model page at "
+                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+                )
+            except EntryNotFoundError:
+                raise EnvironmentError(
+                    f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
+                )
+            except HTTPError as err:
+                raise EnvironmentError(
+                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+                )
+            except ValueError:
+                raise EnvironmentError(
+                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                    f" directory containing a file named {WEIGHTS_NAME}."
+                    "\nCheck your internet connection or see how to run the library in"
+                    " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+                )
+            except EnvironmentError:
+                raise EnvironmentError(
+                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                    f"containing a file named {WEIGHTS_NAME}"
+                )
 
-        try:
-            # Load from URL or cache if already cached
-            resolved_archive_file = cached_path(
-                archive_file,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                user_agent=user_agent,
-            )
-
-        except RepositoryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                "login` and pass `use_auth_token=True`."
-            )
-        except RevisionNotFoundError:
-            raise EnvironmentError(
-                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                "this model name. Check the model page at "
-                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-            )
-        except EntryNotFoundError:
-            raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.")
-        except HTTPError as err:
-            raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
-            )
-        except ValueError:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
-                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
-                f" directory containing a file named {WEIGHTS_NAME} or"
-                " \nCheckout your internet connection or see how to run the library in"
-                " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
-            )
-        except EnvironmentError:
-            raise EnvironmentError(
-                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
" - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {WEIGHTS_NAME}" - ) - - if resolved_archive_file == archive_file: - logger.info(f"loading weights file {archive_file}") - else: - logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}") - - # restore default dtype - state_dict = load_state_dict(resolved_archive_file) + # restore default dtype + state_dict = load_state_dict(model_file) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, state_dict, - resolved_archive_file, + model_file, pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, ) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index cd69b9cf..e836b765 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -20,8 +20,7 @@ from typing import Optional, Union from huggingface_hub import snapshot_download -# CHANGE to diffusers.utils -from transformers.utils import logging +from .utils import logging, DIFFUSERS_CACHE from .configuration_utils import ConfigMixin from .dynamic_modules_utils import get_class_from_dynamic_module @@ -76,11 +75,12 @@ class DiffusionPipeline(ConfigMixin): def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.save_config(save_directory) - model_index_dict = self._dict_to_save + model_index_dict = self.config model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") model_index_dict.pop("_module") - for name, (library_name, class_name) in self._dict_to_save.items(): + for name, (library_name, class_name) in model_index_dict.items(): importable_classes = LOADABLE_CLASSES[library_name] # TODO: Suraj @@ -101,14 +101,36 @@ class DiffusionPipeline(ConfigMixin): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Add docstrings + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - cached_folder = snapshot_download(pretrained_model_name_or_path) + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + output_loading_info=output_loading_info, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) else: cached_folder = pretrained_model_name_or_path config_dict = cls.get_config_dict(cached_folder) + module = config_dict["_module"] + class_name_ = config_dict["_class_name"] module_candidate = config_dict["_module"] module_candidate_name = module_candidate.replace(".py", "") @@ -126,13 +148,12 @@ class DiffusionPipeline(ConfigMixin): init_kwargs = {} # get all importable classes to get the load method name for custom models/components - # here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers + # here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers 
         all_importable_classes = {}
         for library in LOADABLE_CLASSES:
             all_importable_classes.update(LOADABLE_CLASSES[library])
 
         for name, (library_name, class_name) in init_dict.items():
-            # if the model is not in diffusers or transformers, we need to load it from the hub
             # assumes that it's a subclass of ModelMixin
             if library_name == module_candidate_name:
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
new file mode 100644
index 00000000..45b9f64e
--- /dev/null
+++ b/src/diffusers/utils/__init__.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from requests.exceptions import HTTPError
+
+
+hf_cache_home = os.path.expanduser(
+    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "diffusers")
+
+
+CONFIG_NAME = "config.json"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
+
+class RepositoryNotFoundError(HTTPError):
+    """
+    Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user
+    does not have access to.
+    """
+
+
+class EntryNotFoundError(HTTPError):
+    """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
+
+
+class RevisionNotFoundError(HTTPError):
+    """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
new file mode 100644
index 00000000..09abb570
--- /dev/null
+++ b/src/diffusers/utils/logging.py
@@ -0,0 +1,344 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Optional + +from tqdm import auto as tqdm_lib + + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level(): + """ + If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, " + f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom diffusers module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the 🤗 Transformers's root logger as an int. + + Returns: + `int`: The logging level. + + + + 🤗 Transformers has following logging levels: + + - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` + - 40: `diffusers.logging.ERROR` + - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN` + - 20: `diffusers.logging.INFO` + - 10: `diffusers.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 Transformers's root logger. 
+
+    Args:
+        verbosity (`int`):
+            Logging level, e.g., one of:
+
+            - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+            - `diffusers.logging.ERROR`
+            - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+            - `diffusers.logging.INFO`
+            - `diffusers.logging.DEBUG`
+    """
+
+    _configure_library_root_logger()
+    _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+    """Set the verbosity to the `INFO` level."""
+    return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+    """Set the verbosity to the `WARNING` level."""
+    return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+    """Set the verbosity to the `DEBUG` level."""
+    return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+    """Set the verbosity to the `ERROR` level."""
+    return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+    """Disable the default handler of the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert _default_handler is not None
+    _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+    """Enable the default handler of the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert _default_handler is not None
+    _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+    """Adds a handler to the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert handler is not None
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    """Removes the given handler from the HuggingFace Diffusers' root logger."""
+
+    _configure_library_root_logger()
+
+    assert handler is not None and handler in _get_library_root_logger().handlers
+    _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+    """
+    Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+    """
+
+    _configure_library_root_logger()
+    _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+    """
+    Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to
+    prevent double logging if the root logger has been configured.
+    """
+
+    _configure_library_root_logger()
+    _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+    """
+    Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
+    ```
+    [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+    ```
+    All handlers currently bound to the root logger are affected by this method.
+    """
+    handlers = _get_library_root_logger().handlers
+
+    for handler in handlers:
+        formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+        handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+    """
+    Resets the formatting for HuggingFace Diffusers' loggers.
+
+    All handlers currently bound to the root logger are affected by this method.
+ """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index c5b18e4a..332d75ed 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -24,6 +24,7 @@ import torch from diffusers import GaussianDDPMScheduler, UNetModel from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.configuration_utils import ConfigMixin from models.vision.ddpm.modeling_ddpm import DDPM from models.vision.ddim.modeling_ddim import DDIM @@ -78,6 +79,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() +class ConfigTester(unittest.TestCase): + def test_load_not_from_mixin(self): + with self.assertRaises(ValueError): + ConfigMixin.from_config("dummy_path") + + def test_save_load(self): + + class SampleObject(ConfigMixin): + config_name = "config.json" + + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + ): + self.register(a=a, b=b, c=c, d=d, e=e) + + obj = SampleObject() + config = obj.config + + assert config["a"] == 2 + assert config["b"] == 5 + assert config["c"] == (2, 5) + assert config["d"] == "for diffusion" + assert config["e"] == [1, 3] + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + new_obj = SampleObject.from_config(tmpdirname) + new_config = new_obj.config + + assert config.pop("c") == (2, 5) # instantiated as tuple + assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json + assert config == new_config + + class ModelTesterMixin(unittest.TestCase): @property def dummy_input(self):