rename modeling code
This commit is contained in:
parent
80b865878c
commit
07b6d0e71e
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
__version__ = "0.0.1"
|
__version__ = "0.0.1"
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import ModelMixin
|
||||||
from .models.unet import UNetModel
|
from .models.unet import UNetModel
|
||||||
from .pipeline_utils import DiffusionPipeline
|
from .pipeline_utils import DiffusionPipeline
|
||||||
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
|
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Configuration base class and utilities."""
|
""" Configuration base class and utilities."""
|
||||||
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)
|
||||||
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class ConfigMixin:
|
||||||
r"""
|
r"""
|
||||||
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
||||||
methods for loading/downloading/saving configurations.
|
methods for loading/downloading/saving configurations.
|
||||||
|
@ -71,7 +71,7 @@ class Config:
|
||||||
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
||||||
[`~Config.from_config`] class method.
|
[`~ConfigMixin.from_config`] class method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
save_directory (`str` or `os.PathLike`):
|
save_directory (`str` or `os.PathLike`):
|
||||||
|
@ -88,7 +88,7 @@ class Config:
|
||||||
output_config_file = os.path.join(save_directory, self.config_name)
|
output_config_file = os.path.join(save_directory, self.config_name)
|
||||||
|
|
||||||
self.to_json_file(output_config_file)
|
self.to_json_file(output_config_file)
|
||||||
logger.info(f"Configuration saved in {output_config_file}")
|
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_dict(
|
def get_config_dict(
|
||||||
|
|
|
@ -122,11 +122,11 @@ def _load_state_dict_into_model(model_to_load, state_dict):
|
||||||
return error_msgs
|
return error_msgs
|
||||||
|
|
||||||
|
|
||||||
class PreTrainedModel(torch.nn.Module):
|
class ModelMixin(torch.nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Base class for all models.
|
Base class for all models.
|
||||||
|
|
||||||
[`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
|
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
|
||||||
downloading and saving models as well as a few methods common to all models to:
|
downloading and saving models as well as a few methods common to all models to:
|
||||||
|
|
||||||
- resize the input embeddings,
|
- resize the input embeddings,
|
||||||
|
@ -134,13 +134,13 @@ class PreTrainedModel(torch.nn.Module):
|
||||||
|
|
||||||
Class attributes (overridden by derived classes):
|
Class attributes (overridden by derived classes):
|
||||||
|
|
||||||
- **config_class** ([`Config`]) -- A subclass of [`Config`] to use as configuration class
|
- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class
|
||||||
for this model architecture.
|
for this model architecture.
|
||||||
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
|
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
|
||||||
taking as arguments:
|
taking as arguments:
|
||||||
|
|
||||||
- **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
|
- **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
|
||||||
- **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
|
- **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
|
||||||
- **path** (`str`) -- A path to the TensorFlow checkpoint.
|
- **path** (`str`) -- A path to the TensorFlow checkpoint.
|
||||||
|
|
||||||
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
|
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
|
||||||
|
@ -163,7 +163,7 @@ class PreTrainedModel(torch.nn.Module):
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||||
`[`~PreTrainedModel.from_pretrained`]` class method.
|
`[`~ModelMixin.from_pretrained`]` class method.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
save_directory (`str` or `os.PathLike`):
|
save_directory (`str` or `os.PathLike`):
|
||||||
|
@ -231,20 +231,20 @@ class PreTrainedModel(torch.nn.Module):
|
||||||
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
||||||
user or organization name, like `dbmdz/bert-base-german-cased`.
|
user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||||
- A path to a *directory* containing model weights saved using
|
- A path to a *directory* containing model weights saved using
|
||||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
[`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`.
|
||||||
|
|
||||||
config (`Union[Config, str, os.PathLike]`, *optional*):
|
config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
|
||||||
Can be either:
|
Can be either:
|
||||||
|
|
||||||
- an instance of a class derived from [`Config`],
|
- an instance of a class derived from [`ConfigMixin`],
|
||||||
- a string or path valid as input to [`~Config.from_pretrained`].
|
- a string or path valid as input to [`~ConfigMixin.from_pretrained`].
|
||||||
|
|
||||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||||
be automatically loaded when:
|
be automatically loaded when:
|
||||||
|
|
||||||
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
||||||
model).
|
model).
|
||||||
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
|
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the
|
||||||
save directory.
|
save directory.
|
||||||
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
||||||
configuration JSON file named *config.json* is found in the directory.
|
configuration JSON file named *config.json* is found in the directory.
|
||||||
|
@ -295,7 +295,7 @@ class PreTrainedModel(torch.nn.Module):
|
||||||
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
||||||
already been done)
|
already been done)
|
||||||
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
||||||
initialization function ([`~Config.from_pretrained`]). Each key of `kwargs` that
|
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that
|
||||||
corresponds to a configuration attribute will be used to override said attribute with the
|
corresponds to a configuration attribute will be used to override said attribute with the
|
||||||
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
||||||
will be passed to the underlying model's `__init__` function.
|
will be passed to the underlying model's `__init__` function.
|
||||||
|
|
|
@ -29,8 +29,8 @@ from torchvision import transforms, utils
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..configuration_utils import Config
|
from ..configuration_utils import ConfigMixin
|
||||||
from ..modeling_utils import PreTrainedModel
|
from ..modeling_utils import ModelMixin
|
||||||
|
|
||||||
|
|
||||||
def get_timestep_embedding(timesteps, embedding_dim):
|
def get_timestep_embedding(timesteps, embedding_dim):
|
||||||
|
@ -175,7 +175,7 @@ class AttnBlock(nn.Module):
|
||||||
return x + h_
|
return x + h_
|
||||||
|
|
||||||
|
|
||||||
class UNetModel(PreTrainedModel, Config):
|
class UNetModel(ModelMixin, ConfigMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ch=128,
|
ch=128,
|
||||||
|
|
|
@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
|
||||||
# CHANGE to diffusers.utils
|
# CHANGE to diffusers.utils
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
from .configuration_utils import Config
|
from .configuration_utils import ConfigMixin
|
||||||
|
|
||||||
|
|
||||||
INDEX_FILE = "diffusion_model.pt"
|
INDEX_FILE = "diffusion_model.pt"
|
||||||
|
@ -33,16 +33,16 @@ logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
LOADABLE_CLASSES = {
|
LOADABLE_CLASSES = {
|
||||||
"diffusers": {
|
"diffusers": {
|
||||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||||
"GaussianDDPMScheduler": ["save_config", "from_config"],
|
"GaussianDDPMScheduler": ["save_config", "from_config"],
|
||||||
},
|
},
|
||||||
"transformers": {
|
"transformers": {
|
||||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPipeline(Config):
|
class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
config_name = "model_index.json"
|
config_name = "model_index.json"
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ..configuration_utils import Config
|
from ..configuration_utils import ConfigMixin
|
||||||
|
|
||||||
|
|
||||||
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
||||||
|
@ -24,7 +24,7 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
|
||||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||||
|
|
||||||
|
|
||||||
class GaussianDDPMScheduler(nn.Module, Config):
|
class GaussianDDPMScheduler(nn.Module, ConfigMixin):
|
||||||
|
|
||||||
config_name = SAMPLING_CONFIG_NAME
|
config_name = SAMPLING_CONFIG_NAME
|
||||||
|
|
||||||
|
|
|
@ -40,13 +40,13 @@ _re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")
|
||||||
|
|
||||||
|
|
||||||
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
||||||
"CLIPConfig",
|
"CLIPConfigMixin",
|
||||||
"DecisionTransformerConfig",
|
"DecisionTransformerConfigMixin",
|
||||||
"EncoderDecoderConfig",
|
"EncoderDecoderConfigMixin",
|
||||||
"RagConfig",
|
"RagConfigMixin",
|
||||||
"SpeechEncoderDecoderConfig",
|
"SpeechEncoderDecoderConfigMixin",
|
||||||
"VisionEncoderDecoderConfig",
|
"VisionEncoderDecoderConfigMixin",
|
||||||
"VisionTextDualEncoderConfig",
|
"VisionTextDualEncoderConfigMixin",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||||
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
||||||
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
|
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"TFDPREncoder", # Building part of bigger (tested) model.
|
"TFDPREncoder", # Building part of bigger (tested) model.
|
||||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||||
"TFRobertaForMultipleChoice", # TODO: fix
|
"TFRobertaForMultipleChoice", # TODO: fix
|
||||||
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||||
|
@ -271,7 +271,7 @@ def get_model_modules():
|
||||||
def get_models(module, include_pretrained=False):
|
def get_models(module, include_pretrained=False):
|
||||||
"""Get the objects in module that are models."""
|
"""Get the objects in module that are models."""
|
||||||
models = []
|
models = []
|
||||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
|
||||||
for attr_name in dir(module):
|
for attr_name in dir(module):
|
||||||
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
|
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
|
||||||
continue
|
continue
|
||||||
|
@ -372,7 +372,7 @@ def find_tested_models(test_file):
|
||||||
|
|
||||||
def check_models_are_tested(module, test_file):
|
def check_models_are_tested(module, test_file):
|
||||||
"""Check models defined in module are tested in test_file."""
|
"""Check models defined in module are tested in test_file."""
|
||||||
# XxxPreTrainedModel are not tested
|
# XxxPreTrainedModel are not tested
|
||||||
defined_models = get_models(module)
|
defined_models = get_models(module)
|
||||||
tested_models = find_tested_models(test_file)
|
tested_models = find_tested_models(test_file)
|
||||||
if tested_models is None:
|
if tested_models is None:
|
||||||
|
@ -625,9 +625,9 @@ def ignore_undocumented(name):
|
||||||
# Constants uppercase are not documented.
|
# Constants uppercase are not documented.
|
||||||
if name.isupper():
|
if name.isupper():
|
||||||
return True
|
return True
|
||||||
# PreTrainedModels / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
|
# ModelMixins / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
|
||||||
if (
|
if (
|
||||||
name.endswith("PreTrainedModel")
|
name.endswith("ModelMixin")
|
||||||
or name.endswith("Decoder")
|
or name.endswith("Decoder")
|
||||||
or name.endswith("Encoder")
|
or name.endswith("Encoder")
|
||||||
or name.endswith("Layer")
|
or name.endswith("Layer")
|
||||||
|
|
|
@ -94,7 +94,7 @@ def get_model_table_from_auto_modules():
|
||||||
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
|
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
|
||||||
if code in config_maping_names
|
if code in config_maping_names
|
||||||
}
|
}
|
||||||
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
|
||||||
|
|
||||||
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
|
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
|
||||||
slow_tokenizers = collections.defaultdict(bool)
|
slow_tokenizers = collections.defaultdict(bool)
|
||||||
|
@ -190,7 +190,7 @@ def has_onnx(model_type):
|
||||||
for part in config_module.split(".")[1:]:
|
for part in config_module.split(".")[1:]:
|
||||||
module = getattr(module, part)
|
module = getattr(module, part)
|
||||||
config_name = config.__name__
|
config_name = config.__name__
|
||||||
onnx_config_name = config_name.replace("Config", "OnnxConfig")
|
onnx_config_name = config_name.replace("ConfigMixin", "OnnxConfigMixin")
|
||||||
return hasattr(module, onnx_config_name)
|
return hasattr(module, onnx_config_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue