Merge branch 'main' of https://github.com/huggingface/diffusers into main
This commit is contained in:
commit
4f761e95c7
51
setup.py
51
setup.py
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -52,11 +52,11 @@ To create the package for pypi.
|
|||
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
|
||||
|
||||
Check that you can install it in a virtualenv by running:
|
||||
pip install -i https://testpypi.python.org/pypi transformers
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
|
||||
Check you can run the following commands:
|
||||
python -c "from transformers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
|
||||
python -c "from transformers import *"
|
||||
python -c "from diffusers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
|
||||
python -c "from diffusers import *"
|
||||
|
||||
9. Upload the final version to actual pypi:
|
||||
twine upload dist/* -r pypi
|
||||
|
@ -77,36 +77,21 @@ from setuptools import find_packages, setup
|
|||
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
|
||||
_deps = [
|
||||
"Pillow",
|
||||
"accelerate>=0.9.0",
|
||||
"black~=22.0,>=22.3",
|
||||
"codecarbon==1.2.0",
|
||||
"dataclasses",
|
||||
"datasets",
|
||||
"GitPython<3.1.19",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub>=0.1.0,<1.0",
|
||||
"importlib_metadata",
|
||||
"filelock",
|
||||
"flake8>=3.8.3",
|
||||
"huggingface-hub",
|
||||
"isort>=5.5.4",
|
||||
"numpy>=1.17",
|
||||
"numpy",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"python>=3.7.0",
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"sagemaker>=2.31.0",
|
||||
"tokenizers>=0.11.1,!=0.11.3,<0.13",
|
||||
"torch>=1.4",
|
||||
"torchaudio",
|
||||
"tqdm>=4.27",
|
||||
"unidic>=1.0.2",
|
||||
"unidic_lite>=1.0.7",
|
||||
"uvicorn",
|
||||
"torchvision",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
#
|
||||
# tokenizers: "tokenizers==0.9.4"
|
||||
# tokenizers: "huggingface-hub==0.8.0"
|
||||
# packaging: "packaging"
|
||||
#
|
||||
# some of the values are versioned whereas others aren't.
|
||||
|
@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
|
|||
extras["docs"] = []
|
||||
extras["test"] = [
|
||||
"pytest",
|
||||
"pytest-xdist",
|
||||
"pytest-subtests",
|
||||
"datasets",
|
||||
"transformers",
|
||||
]
|
||||
extras["dev"] = extras["quality"] + extras["test"]
|
||||
|
||||
extras["sagemaker"] = [
|
||||
"sagemaker", # boto3 is a required package in sagemaker
|
||||
install_requires = [
|
||||
deps["filelock"],
|
||||
deps["huggingface-hub"],
|
||||
deps["numpy"],
|
||||
deps["requests"],
|
||||
deps["torch"],
|
||||
deps["torchvision"],
|
||||
deps["Pillow"],
|
||||
]
|
||||
|
||||
setup(
|
||||
|
@ -201,7 +188,7 @@ setup(
|
|||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
python_requires=">=3.6.0",
|
||||
install_requires=["numpy>=1.17", "packaging>=20.0", "pyyaml", "torch>=1.4.0"],
|
||||
install_requires=install_requires,
|
||||
extras_require=extras,
|
||||
classifiers=[
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
|
|
|
@ -24,18 +24,19 @@ import re
|
|||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from requests import HTTPError
|
||||
from transformers.utils import (
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
from .utils import (
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
DIFFUSERS_CACHE,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
from . import __version__
|
||||
|
||||
|
||||
|
@ -56,6 +57,8 @@ class ConfigMixin:
|
|||
if self.config_name is None:
|
||||
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
||||
kwargs["_class_name"] = self.__class__.__name__
|
||||
kwargs["_diffusers_version"] = __version__
|
||||
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
|
@ -90,11 +93,26 @@ class ConfigMixin:
|
|||
self.to_json_file(output_config_file)
|
||||
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||
config_dict = cls.get_config_dict(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
|
||||
)
|
||||
|
||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
model = cls(**init_dict)
|
||||
|
||||
if return_unused_kwargs:
|
||||
return model, unused_kwargs
|
||||
else:
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
|
@ -104,85 +122,83 @@ class ConfigMixin:
|
|||
|
||||
user_agent = {"file_type": "config"}
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
configuration_file = cls.config_name
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||
if cls.config_name is None:
|
||||
raise ValueError(
|
||||
"`self.config_name` is not defined. Note that one should not load a config from "
|
||||
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
||||
)
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
||||
# Load from a PyTorch checkpoint
|
||||
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
||||
else:
|
||||
config_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
|
||||
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {cls.config_name} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
|
||||
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {configuration_file} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||
f"It looks like the config file at '{config_file}' is not a valid JSON file."
|
||||
)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
logger.info(f"loading configuration file {config_file}")
|
||||
else:
|
||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
||||
|
||||
return config_dict
|
||||
|
||||
@classmethod
|
||||
|
@ -208,19 +224,6 @@ class ConfigMixin:
|
|||
|
||||
return init_dict, unused_kwargs
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
model = cls(**init_dict)
|
||||
|
||||
if return_unused_kwargs:
|
||||
return model, unused_kwargs
|
||||
else:
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
|
@ -233,18 +236,9 @@ class ConfigMixin:
|
|||
def __repr__(self):
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
|
||||
# Diffusion version when serializing the model
|
||||
output["diffusers_version"] = __version__
|
||||
|
||||
@property
|
||||
def config(self) -> Dict[str, Any]:
|
||||
output = copy.deepcopy(self._dict_to_save)
|
||||
return output
|
||||
|
||||
def to_json_string(self) -> str:
|
||||
|
|
|
@ -22,16 +22,8 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from huggingface_hub import HfFolder, model_info
|
||||
|
||||
from transformers.utils import (
|
||||
HF_MODULES_CACHE,
|
||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
logging,
|
||||
)
|
||||
from huggingface_hub import cached_download
|
||||
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
@ -219,7 +211,7 @@ def get_cached_module_file(
|
|||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_module_file = cached_path(
|
||||
resolved_module_file = cached_download(
|
||||
module_file_or_url,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
|
@ -237,7 +229,7 @@ def get_cached_module_file(
|
|||
modules_needed = check_imports(resolved_module_file)
|
||||
|
||||
# Now we move the module inside our cached dynamic modules.
|
||||
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
||||
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
||||
create_dynamic_module(full_submodule)
|
||||
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
||||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
||||
|
|
|
@ -21,18 +21,15 @@ import torch
|
|||
from torch import Tensor, device
|
||||
|
||||
from requests import HTTPError
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
# CHANGE to diffusers.utils
|
||||
from transformers.utils import (
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
|
|||
</Tip>
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
|
@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
|
|||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
model, unused_kwargs = cls.from_config(
|
||||
|
@ -353,79 +345,67 @@ class ModelMixin(torch.nn.Module):
|
|||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
else:
|
||||
filename = WEIGHTS_NAME
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
model_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.")
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
|
||||
# restore default dtype
|
||||
state_dict = load_state_dict(resolved_archive_file)
|
||||
# restore default dtype
|
||||
state_dict = load_state_dict(model_file)
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
resolved_archive_file,
|
||||
model_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
|
|
|
@ -20,8 +20,7 @@ from typing import Optional, Union
|
|||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# CHANGE to diffusers.utils
|
||||
from transformers.utils import logging
|
||||
from .utils import logging, DIFFUSERS_CACHE
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
|
@ -80,11 +79,12 @@ class DiffusionPipeline(ConfigMixin):
|
|||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
self.save_config(save_directory)
|
||||
|
||||
model_index_dict = self._dict_to_save
|
||||
model_index_dict = self.config
|
||||
model_index_dict.pop("_class_name")
|
||||
model_index_dict.pop("_diffusers_version")
|
||||
model_index_dict.pop("_module")
|
||||
|
||||
for name, (library_name, class_name) in self._dict_to_save.items():
|
||||
for name, (library_name, class_name) in model_index_dict.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
||||
# TODO: Suraj
|
||||
|
@ -105,14 +105,36 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Add docstrings
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
cached_folder = snapshot_download(pretrained_model_name_or_path)
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
output_loading_info=output_loading_info,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
|
||||
module = config_dict["_module"]
|
||||
class_name_ = config_dict["_class_name"]
|
||||
module_candidate = config_dict["_module"]
|
||||
module_candidate_name = module_candidate.replace(".py", "")
|
||||
|
||||
|
@ -130,7 +152,6 @@ class DiffusionPipeline(ConfigMixin):
|
|||
init_kwargs = {}
|
||||
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
|
||||
# if the model is not in diffusers or transformers, we need to load it from the hub
|
||||
# assumes that it's a subclass of ModelMixin
|
||||
if library_name == module_candidate_name:
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from requests.exceptions import HTTPError
|
||||
import os
|
||||
|
||||
hf_cache_home = os.path.expanduser(
|
||||
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
||||
)
|
||||
default_cache_path = os.path.join(hf_cache_home, "diffusers")
|
||||
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
||||
DIFFUSERS_CACHE = default_cache_path
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
|
||||
|
||||
class RepositoryNotFoundError(HTTPError):
|
||||
"""
|
||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||
not have access to.
|
||||
"""
|
||||
|
||||
|
||||
class EntryNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
|
||||
|
||||
|
||||
class RevisionNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
|
@ -0,0 +1,344 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 Optuna, Hugging Face
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Logging utilities."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from logging import CRITICAL # NOQA
|
||||
from logging import DEBUG # NOQA
|
||||
from logging import ERROR # NOQA
|
||||
from logging import FATAL # NOQA
|
||||
from logging import INFO # NOQA
|
||||
from logging import NOTSET # NOQA
|
||||
from logging import WARN # NOQA
|
||||
from logging import WARNING # NOQA
|
||||
from typing import Optional
|
||||
|
||||
from tqdm import auto as tqdm_lib
|
||||
|
||||
|
||||
_lock = threading.Lock()
|
||||
_default_handler: Optional[logging.Handler] = None
|
||||
|
||||
log_levels = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}
|
||||
|
||||
_default_log_level = logging.WARNING
|
||||
|
||||
_tqdm_active = True
|
||||
|
||||
|
||||
def _get_default_logging_level():
|
||||
"""
|
||||
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
||||
not - fall back to `_default_log_level`
|
||||
"""
|
||||
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
|
||||
if env_level_str:
|
||||
if env_level_str in log_levels:
|
||||
return log_levels[env_level_str]
|
||||
else:
|
||||
logging.getLogger().warning(
|
||||
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
|
||||
f"has to be one of: { ', '.join(log_levels.keys()) }"
|
||||
)
|
||||
return _default_log_level
|
||||
|
||||
|
||||
def _get_library_name() -> str:
|
||||
|
||||
return __name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_library_root_logger() -> logging.Logger:
|
||||
|
||||
return logging.getLogger(_get_library_name())
|
||||
|
||||
|
||||
def _configure_library_root_logger() -> None:
|
||||
|
||||
global _default_handler
|
||||
|
||||
with _lock:
|
||||
if _default_handler:
|
||||
# This library has already configured the library root logger.
|
||||
return
|
||||
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
||||
_default_handler.flush = sys.stderr.flush
|
||||
|
||||
# Apply our default configuration to the library root logger.
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.addHandler(_default_handler)
|
||||
library_root_logger.setLevel(_get_default_logging_level())
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def _reset_library_root_logger() -> None:
|
||||
|
||||
global _default_handler
|
||||
|
||||
with _lock:
|
||||
if not _default_handler:
|
||||
return
|
||||
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.removeHandler(_default_handler)
|
||||
library_root_logger.setLevel(logging.NOTSET)
|
||||
_default_handler = None
|
||||
|
||||
|
||||
def get_log_levels_dict():
|
||||
return log_levels
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
||||
"""
|
||||
Return a logger with the specified name.
|
||||
|
||||
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
|
||||
"""
|
||||
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
_configure_library_root_logger()
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def get_verbosity() -> int:
|
||||
"""
|
||||
Return the current level for the 🤗 Transformers's root logger as an int.
|
||||
|
||||
Returns:
|
||||
`int`: The logging level.
|
||||
|
||||
<Tip>
|
||||
|
||||
🤗 Transformers has following logging levels:
|
||||
|
||||
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
||||
- 40: `diffusers.logging.ERROR`
|
||||
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
||||
- 20: `diffusers.logging.INFO`
|
||||
- 10: `diffusers.logging.DEBUG`
|
||||
|
||||
</Tip>"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
return _get_library_root_logger().getEffectiveLevel()
|
||||
|
||||
|
||||
def set_verbosity(verbosity: int) -> None:
|
||||
"""
|
||||
Set the verbosity level for the 🤗 Transformers's root logger.
|
||||
|
||||
Args:
|
||||
verbosity (`int`):
|
||||
Logging level, e.g., one of:
|
||||
|
||||
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
||||
- `diffusers.logging.ERROR`
|
||||
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
||||
- `diffusers.logging.INFO`
|
||||
- `diffusers.logging.DEBUG`
|
||||
"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().setLevel(verbosity)
|
||||
|
||||
|
||||
def set_verbosity_info():
|
||||
"""Set the verbosity to the `INFO` level."""
|
||||
return set_verbosity(INFO)
|
||||
|
||||
|
||||
def set_verbosity_warning():
|
||||
"""Set the verbosity to the `WARNING` level."""
|
||||
return set_verbosity(WARNING)
|
||||
|
||||
|
||||
def set_verbosity_debug():
|
||||
"""Set the verbosity to the `DEBUG` level."""
|
||||
return set_verbosity(DEBUG)
|
||||
|
||||
|
||||
def set_verbosity_error():
|
||||
"""Set the verbosity to the `ERROR` level."""
|
||||
return set_verbosity(ERROR)
|
||||
|
||||
|
||||
def disable_default_handler() -> None:
|
||||
"""Disable the default handler of the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert _default_handler is not None
|
||||
_get_library_root_logger().removeHandler(_default_handler)
|
||||
|
||||
|
||||
def enable_default_handler() -> None:
|
||||
"""Enable the default handler of the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert _default_handler is not None
|
||||
_get_library_root_logger().addHandler(_default_handler)
|
||||
|
||||
|
||||
def add_handler(handler: logging.Handler) -> None:
|
||||
"""adds a handler to the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert handler is not None
|
||||
_get_library_root_logger().addHandler(handler)
|
||||
|
||||
|
||||
def remove_handler(handler: logging.Handler) -> None:
|
||||
"""removes given handler from the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert handler is not None and handler not in _get_library_root_logger().handlers
|
||||
_get_library_root_logger().removeHandler(handler)
|
||||
|
||||
|
||||
def disable_propagation() -> None:
|
||||
"""
|
||||
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
||||
"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().propagate = False
|
||||
|
||||
|
||||
def enable_propagation() -> None:
|
||||
"""
|
||||
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
|
||||
prevent double logging if the root logger has been configured.
|
||||
"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().propagate = True
|
||||
|
||||
|
||||
def enable_explicit_format() -> None:
|
||||
"""
|
||||
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
|
||||
```
|
||||
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
||||
```
|
||||
All handlers currently bound to the root logger are affected by this method.
|
||||
"""
|
||||
handlers = _get_library_root_logger().handlers
|
||||
|
||||
for handler in handlers:
|
||||
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
|
||||
def reset_format() -> None:
|
||||
"""
|
||||
Resets the formatting for HuggingFace Transformers's loggers.
|
||||
|
||||
All handlers currently bound to the root logger are affected by this method.
|
||||
"""
|
||||
handlers = _get_library_root_logger().handlers
|
||||
|
||||
for handler in handlers:
|
||||
handler.setFormatter(None)
|
||||
|
||||
|
||||
def warning_advice(self, *args, **kwargs):
|
||||
"""
|
||||
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
|
||||
warning will not be printed
|
||||
"""
|
||||
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
|
||||
if no_advisory_warnings:
|
||||
return
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
logging.Logger.warning_advice = warning_advice
|
||||
|
||||
|
||||
class EmptyTqdm:
|
||||
"""Dummy tqdm which doesn't do anything."""
|
||||
|
||||
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||
self._iterator = args[0] if args else None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._iterator)
|
||||
|
||||
def __getattr__(self, _):
|
||||
"""Return empty function."""
|
||||
|
||||
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return
|
||||
|
||||
return empty_fn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type_, value, traceback):
|
||||
return
|
||||
|
||||
|
||||
class _tqdm_cls:
|
||||
def __call__(self, *args, **kwargs):
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm(*args, **kwargs)
|
||||
else:
|
||||
return EmptyTqdm(*args, **kwargs)
|
||||
|
||||
def set_lock(self, *args, **kwargs):
|
||||
self._lock = None
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
||||
|
||||
def get_lock(self):
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm.get_lock()
|
||||
|
||||
|
||||
tqdm = _tqdm_cls()
|
||||
|
||||
|
||||
def is_progress_bar_enabled() -> bool:
|
||||
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
||||
global _tqdm_active
|
||||
return bool(_tqdm_active)
|
||||
|
||||
|
||||
def enable_progress_bar():
|
||||
"""Enable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = True
|
||||
|
||||
|
||||
def disable_progress_bar():
|
||||
"""Disable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = False
|
|
@ -24,6 +24,7 @@ import torch
|
|||
|
||||
from diffusers import GaussianDDPMScheduler, UNetModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from models.vision.ddpm.modeling_ddpm import DDPM
|
||||
from models.vision.ddim.modeling_ddim import DDIM
|
||||
|
||||
|
@ -78,6 +79,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
|||
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
|
||||
|
||||
|
||||
class ConfigTester(unittest.TestCase):
|
||||
def test_load_not_from_mixin(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ConfigMixin.from_config("dummy_path")
|
||||
|
||||
def test_save_load(self):
|
||||
|
||||
class SampleObject(ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
a=2,
|
||||
b=5,
|
||||
c=(2, 5),
|
||||
d="for diffusion",
|
||||
e=[1, 3],
|
||||
):
|
||||
self.register(a=a, b=b, c=c, d=d, e=e)
|
||||
|
||||
obj = SampleObject()
|
||||
config = obj.config
|
||||
|
||||
assert config["a"] == 2
|
||||
assert config["b"] == 5
|
||||
assert config["c"] == (2, 5)
|
||||
assert config["d"] == "for diffusion"
|
||||
assert config["e"] == [1, 3]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
new_obj = SampleObject.from_config(tmpdirname)
|
||||
new_config = new_obj.config
|
||||
|
||||
assert config.pop("c") == (2, 5) # instantiated as tuple
|
||||
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
|
||||
assert config == new_config
|
||||
|
||||
|
||||
class ModelTesterMixin(unittest.TestCase):
|
||||
@property
|
||||
def dummy_input(self):
|
||||
|
|
Loading…
Reference in New Issue