remove transformers dependency
This commit is contained in:
parent
5a784f98a6
commit
09e1b0b46f
|
@ -24,18 +24,19 @@ import re
|
|||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from requests import HTTPError
|
||||
from transformers.utils import (
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
from .utils import (
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
DIFFUSERS_CACHE,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
from . import __version__
|
||||
|
||||
|
||||
|
@ -89,13 +90,12 @@ class ConfigMixin:
|
|||
|
||||
self.to_json_file(output_config_file)
|
||||
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
|
@ -105,85 +105,77 @@ class ConfigMixin:
|
|||
|
||||
user_agent = {"file_type": "config"}
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
configuration_file = cls.config_name
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
||||
# Load from a PyTorch checkpoint
|
||||
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
||||
else:
|
||||
config_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
config_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=cls.config_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
|
||||
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {cls.config_name} file"
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
|
||||
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {configuration_file} file"
|
||||
)
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{config_file}' is not a valid JSON file."
|
||||
)
|
||||
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||
)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
logger.info(f"loading configuration file {config_file}")
|
||||
else:
|
||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
||||
|
||||
return config_dict
|
||||
|
||||
@classmethod
|
||||
|
@ -199,9 +191,7 @@ class ConfigMixin:
|
|||
# use value from config dict
|
||||
init_dict[key] = config_dict.pop(key)
|
||||
|
||||
|
||||
unused_kwargs = config_dict.update(kwargs)
|
||||
|
||||
passed_keys = set(init_dict.keys())
|
||||
if len(expected_keys - passed_keys) > 0:
|
||||
logger.warn(
|
||||
|
|
|
@ -22,16 +22,8 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from huggingface_hub import HfFolder, model_info
|
||||
|
||||
from transformers.utils import (
|
||||
HF_MODULES_CACHE,
|
||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
logging,
|
||||
)
|
||||
from huggingface_hub import cached_download
|
||||
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
@ -219,7 +211,7 @@ def get_cached_module_file(
|
|||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_module_file = cached_path(
|
||||
resolved_module_file = cached_download(
|
||||
module_file_or_url,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
|
@ -237,7 +229,7 @@ def get_cached_module_file(
|
|||
modules_needed = check_imports(resolved_module_file)
|
||||
|
||||
# Now we move the module inside our cached dynamic modules.
|
||||
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
||||
full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
||||
create_dynamic_module(full_submodule)
|
||||
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
||||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
||||
|
|
|
@ -21,18 +21,15 @@ import torch
|
|||
from torch import Tensor, device
|
||||
|
||||
from requests import HTTPError
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
# CHANGE to diffusers.utils
|
||||
from transformers.utils import (
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
|
|||
</Tip>
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
|
@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
|
|||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
model, unused_kwargs = cls.from_config(
|
||||
|
@ -353,83 +345,71 @@ class ModelMixin(torch.nn.Module):
|
|||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
else:
|
||||
filename = WEIGHTS_NAME
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
model_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
# restore default dtype
|
||||
state_dict = load_state_dict(model_file)
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
model_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.")
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
|
||||
# restore default dtype
|
||||
state_dict = load_state_dict(resolved_archive_file)
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
resolved_archive_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
|
|
|
@ -19,8 +19,7 @@ import os
|
|||
from typing import Optional, Union
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# CHANGE to diffusers.utils
|
||||
from transformers.utils import logging
|
||||
from .utils import logging, DIFFUSERS_CACHE
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
|
@ -55,14 +54,13 @@ class DiffusionPipeline(ConfigMixin):
|
|||
class_name = module.__class__.__name__
|
||||
|
||||
register_dict = {name: (library, class_name)}
|
||||
|
||||
|
||||
# save model index config
|
||||
self.register(**register_dict)
|
||||
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
|
||||
register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
|
||||
self.register(**register_dict)
|
||||
|
||||
|
@ -94,22 +92,41 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Add docstrings
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
cached_folder = snapshot_download(pretrained_model_name_or_path)
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
output_loading_info=output_loading_info,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
|
||||
|
||||
module = config_dict["_module"]
|
||||
class_name_ = config_dict["_class_name"]
|
||||
|
||||
|
||||
if class_name_ == cls.__name__:
|
||||
pipeline_class = cls
|
||||
else:
|
||||
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
|
||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
hf_cache_home = os.path.expanduser(
|
||||
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
||||
)
|
||||
default_cache_path = os.path.join(hf_cache_home, "diffusers")
|
||||
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
||||
DIFFUSERS_CACHE = default_cache_path
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
|
||||
|
||||
class RepositoryNotFoundError(HTTPError):
|
||||
"""
|
||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||
not have access to.
|
||||
"""
|
||||
|
||||
|
||||
class EntryNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
|
||||
|
||||
|
||||
class RevisionNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
|
@ -0,0 +1,344 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 Optuna, Hugging Face
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Logging utilities."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from logging import CRITICAL # NOQA
|
||||
from logging import DEBUG # NOQA
|
||||
from logging import ERROR # NOQA
|
||||
from logging import FATAL # NOQA
|
||||
from logging import INFO # NOQA
|
||||
from logging import NOTSET # NOQA
|
||||
from logging import WARN # NOQA
|
||||
from logging import WARNING # NOQA
|
||||
from typing import Optional
|
||||
|
||||
from tqdm import auto as tqdm_lib
|
||||
|
||||
|
||||
_lock = threading.Lock()
|
||||
_default_handler: Optional[logging.Handler] = None
|
||||
|
||||
log_levels = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}
|
||||
|
||||
_default_log_level = logging.WARNING
|
||||
|
||||
_tqdm_active = True
|
||||
|
||||
|
||||
def _get_default_logging_level():
|
||||
"""
|
||||
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
||||
not - fall back to `_default_log_level`
|
||||
"""
|
||||
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
|
||||
if env_level_str:
|
||||
if env_level_str in log_levels:
|
||||
return log_levels[env_level_str]
|
||||
else:
|
||||
logging.getLogger().warning(
|
||||
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, "
|
||||
f"has to be one of: { ', '.join(log_levels.keys()) }"
|
||||
)
|
||||
return _default_log_level
|
||||
|
||||
|
||||
def _get_library_name() -> str:
|
||||
|
||||
return __name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_library_root_logger() -> logging.Logger:
|
||||
|
||||
return logging.getLogger(_get_library_name())
|
||||
|
||||
|
||||
def _configure_library_root_logger() -> None:
|
||||
|
||||
global _default_handler
|
||||
|
||||
with _lock:
|
||||
if _default_handler:
|
||||
# This library has already configured the library root logger.
|
||||
return
|
||||
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
||||
_default_handler.flush = sys.stderr.flush
|
||||
|
||||
# Apply our default configuration to the library root logger.
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.addHandler(_default_handler)
|
||||
library_root_logger.setLevel(_get_default_logging_level())
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def _reset_library_root_logger() -> None:
|
||||
|
||||
global _default_handler
|
||||
|
||||
with _lock:
|
||||
if not _default_handler:
|
||||
return
|
||||
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.removeHandler(_default_handler)
|
||||
library_root_logger.setLevel(logging.NOTSET)
|
||||
_default_handler = None
|
||||
|
||||
|
||||
def get_log_levels_dict():
|
||||
return log_levels
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
||||
"""
|
||||
Return a logger with the specified name.
|
||||
|
||||
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
|
||||
"""
|
||||
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
_configure_library_root_logger()
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def get_verbosity() -> int:
|
||||
"""
|
||||
Return the current level for the 🤗 Transformers's root logger as an int.
|
||||
|
||||
Returns:
|
||||
`int`: The logging level.
|
||||
|
||||
<Tip>
|
||||
|
||||
🤗 Transformers has following logging levels:
|
||||
|
||||
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
||||
- 40: `diffusers.logging.ERROR`
|
||||
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
||||
- 20: `diffusers.logging.INFO`
|
||||
- 10: `diffusers.logging.DEBUG`
|
||||
|
||||
</Tip>"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
return _get_library_root_logger().getEffectiveLevel()
|
||||
|
||||
|
||||
def set_verbosity(verbosity: int) -> None:
|
||||
"""
|
||||
Set the verbosity level for the 🤗 Transformers's root logger.
|
||||
|
||||
Args:
|
||||
verbosity (`int`):
|
||||
Logging level, e.g., one of:
|
||||
|
||||
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
||||
- `diffusers.logging.ERROR`
|
||||
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
||||
- `diffusers.logging.INFO`
|
||||
- `diffusers.logging.DEBUG`
|
||||
"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().setLevel(verbosity)
|
||||
|
||||
|
||||
def set_verbosity_info():
|
||||
"""Set the verbosity to the `INFO` level."""
|
||||
return set_verbosity(INFO)
|
||||
|
||||
|
||||
def set_verbosity_warning():
|
||||
"""Set the verbosity to the `WARNING` level."""
|
||||
return set_verbosity(WARNING)
|
||||
|
||||
|
||||
def set_verbosity_debug():
|
||||
"""Set the verbosity to the `DEBUG` level."""
|
||||
return set_verbosity(DEBUG)
|
||||
|
||||
|
||||
def set_verbosity_error():
|
||||
"""Set the verbosity to the `ERROR` level."""
|
||||
return set_verbosity(ERROR)
|
||||
|
||||
|
||||
def disable_default_handler() -> None:
|
||||
"""Disable the default handler of the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert _default_handler is not None
|
||||
_get_library_root_logger().removeHandler(_default_handler)
|
||||
|
||||
|
||||
def enable_default_handler() -> None:
|
||||
"""Enable the default handler of the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert _default_handler is not None
|
||||
_get_library_root_logger().addHandler(_default_handler)
|
||||
|
||||
|
||||
def add_handler(handler: logging.Handler) -> None:
|
||||
"""adds a handler to the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert handler is not None
|
||||
_get_library_root_logger().addHandler(handler)
|
||||
|
||||
|
||||
def remove_handler(handler: logging.Handler) -> None:
|
||||
"""removes given handler from the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert handler is not None and handler not in _get_library_root_logger().handlers
|
||||
_get_library_root_logger().removeHandler(handler)
|
||||
|
||||
|
||||
def disable_propagation() -> None:
|
||||
"""
|
||||
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
||||
"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().propagate = False
|
||||
|
||||
|
||||
def enable_propagation() -> None:
|
||||
"""
|
||||
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
|
||||
prevent double logging if the root logger has been configured.
|
||||
"""
|
||||
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().propagate = True
|
||||
|
||||
|
||||
def enable_explicit_format() -> None:
|
||||
"""
|
||||
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
|
||||
```
|
||||
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
||||
```
|
||||
All handlers currently bound to the root logger are affected by this method.
|
||||
"""
|
||||
handlers = _get_library_root_logger().handlers
|
||||
|
||||
for handler in handlers:
|
||||
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
|
||||
def reset_format() -> None:
|
||||
"""
|
||||
Resets the formatting for HuggingFace Transformers's loggers.
|
||||
|
||||
All handlers currently bound to the root logger are affected by this method.
|
||||
"""
|
||||
handlers = _get_library_root_logger().handlers
|
||||
|
||||
for handler in handlers:
|
||||
handler.setFormatter(None)
|
||||
|
||||
|
||||
def warning_advice(self, *args, **kwargs):
|
||||
"""
|
||||
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
|
||||
warning will not be printed
|
||||
"""
|
||||
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
|
||||
if no_advisory_warnings:
|
||||
return
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
logging.Logger.warning_advice = warning_advice
|
||||
|
||||
|
||||
class EmptyTqdm:
|
||||
"""Dummy tqdm which doesn't do anything."""
|
||||
|
||||
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||
self._iterator = args[0] if args else None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._iterator)
|
||||
|
||||
def __getattr__(self, _):
|
||||
"""Return empty function."""
|
||||
|
||||
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return
|
||||
|
||||
return empty_fn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type_, value, traceback):
|
||||
return
|
||||
|
||||
|
||||
class _tqdm_cls:
|
||||
def __call__(self, *args, **kwargs):
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm(*args, **kwargs)
|
||||
else:
|
||||
return EmptyTqdm(*args, **kwargs)
|
||||
|
||||
def set_lock(self, *args, **kwargs):
|
||||
self._lock = None
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
||||
|
||||
def get_lock(self):
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm.get_lock()
|
||||
|
||||
|
||||
tqdm = _tqdm_cls()
|
||||
|
||||
|
||||
def is_progress_bar_enabled() -> bool:
|
||||
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
||||
global _tqdm_active
|
||||
return bool(_tqdm_active)
|
||||
|
||||
|
||||
def enable_progress_bar():
|
||||
"""Enable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = True
|
||||
|
||||
|
||||
def disable_progress_bar():
|
||||
"""Disable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = False
|
Loading…
Reference in New Issue