rename modeling code

Patrick von Platen 2022-06-07 10:35:53 +02:00
parent 80b865878c
commit 07b6d0e71e
9 changed files with 41 additions and 41 deletions

View File

@@ -4,7 +4,7 @@
__version__ = "0.0.1"
-from .modeling_utils import PreTrainedModel
+from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
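After this commit, downstream code imports the renamed classes from the package root. A minimal sketch (only the exported names are taken from the diff; the `issubclass` check is illustrative):

```python
# Sketch of the renamed public API; nothing beyond the import names comes from this commit.
from diffusers import ModelMixin, UNetModel            # was: PreTrainedModel
from diffusers import DiffusionPipeline, GaussianDDPMScheduler

# models shipped by the library now derive from ModelMixin instead of PreTrainedModel
assert issubclass(UNetModel, ModelMixin)
```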

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Configuration base class and utilities."""
+""" ConfigMixinuration base class and utilities."""
import copy
@@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
-class Config:
+class ConfigMixin:
r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
@@ -71,7 +71,7 @@ class Config:
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
-[`~Config.from_config`] class method.
+[`~ConfigMixin.from_config`] class method.
Args:
save_directory (`str` or `os.PathLike`):
@@ -88,7 +88,7 @@ class Config:
output_config_file = os.path.join(save_directory, self.config_name)
self.to_json_file(output_config_file)
-logger.info(f"Configuration saved in {output_config_file}")
+logger.info(f"ConfigMixinuration saved in {output_config_file}")
@classmethod
def get_config_dict(
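The save/re-load cycle described in the docstring above looks roughly like this; the constructor defaults and the target directory are assumptions, not values from the commit:

```python
# Hedged sketch of the ConfigMixin round trip using a class touched by this
# commit; the scheduler's defaults and the directory name are assumed.
from diffusers import GaussianDDPMScheduler

scheduler = GaussianDDPMScheduler()                  # assuming usable defaults
scheduler.save_config("./ddpm-scheduler")            # writes scheduler_config.json
restored = GaussianDDPMScheduler.from_config("./ddpm-scheduler")
```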

View File

@@ -122,11 +122,11 @@ def _load_state_dict_into_model(model_to_load, state_dict):
return error_msgs
-class PreTrainedModel(torch.nn.Module):
+class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
-[`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
downloading and saving models as well as a few methods common to all models to:
- resize the input embeddings,
@@ -134,13 +134,13 @@ class PreTrainedModel(torch.nn.Module):
Class attributes (overridden by derived classes):
-- **config_class** ([`Config`]) -- A subclass of [`Config`] to use as configuration class
+- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class
for this model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:
-- **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
-- **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
+- **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+- **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
- **path** (`str`) -- A path to the TensorFlow checkpoint.
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
@@ -163,7 +163,7 @@ class PreTrainedModel(torch.nn.Module):
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
-`[`~PreTrainedModel.from_pretrained`]` class method.
+`[`~ModelMixin.from_pretrained`]` class method.
Arguments:
save_directory (`str` or `os.PathLike`):
@@ -231,20 +231,20 @@ class PreTrainedModel(torch.nn.Module):
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
-[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+[`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`.
-config (`Union[Config, str, os.PathLike]`, *optional*):
+config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
Can be either:
-- an instance of a class derived from [`Config`],
-- a string or path valid as input to [`~Config.from_pretrained`].
+- an instance of a class derived from [`ConfigMixin`],
+- a string or path valid as input to [`~ConfigMixin.from_pretrained`].
-Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ConfigMixinuration for the model to use instead of an automatically loaded configuration. ConfigMixinuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
model).
-- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the
save directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
@@ -295,7 +295,7 @@ class PreTrainedModel(torch.nn.Module):
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
-initialization function ([`~Config.from_pretrained`]). Each key of `kwargs` that
+initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that
corresponds to a configuration attribute will be used to override said attribute with the
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
will be passed to the underlying model's `__init__` function.
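Read together, these loading rules amount to the usage pattern sketched below; the directory path mirrors the docstring's example and the `ch` override is only illustrative:

```python
# Hedged sketch of ModelMixin.save_pretrained / from_pretrained as documented
# above; the directory name and the overridden attribute are placeholders.
from diffusers import UNetModel

model = UNetModel()                                    # assuming usable defaults
model.save_pretrained("./my_model_directory")          # weights + configuration
reloaded = UNetModel.from_pretrained("./my_model_directory", ch=64)  # kwarg matching a config attribute overrides it
```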

View File

@@ -29,8 +29,8 @@ from torchvision import transforms, utils
from PIL import Image
from tqdm import tqdm
-from ..configuration_utils import Config
-from ..modeling_utils import PreTrainedModel
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
def get_timestep_embedding(timesteps, embedding_dim):
@@ -175,7 +175,7 @@ class AttnBlock(nn.Module):
return x + h_
-class UNetModel(PreTrainedModel, Config):
+class UNetModel(ModelMixin, ConfigMixin):
def __init__(
self,
ch=128,
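UNetModel now follows the mixin pattern: ModelMixin provides weight handling, ConfigMixin the configuration handling. A minimal sketch of a custom model built the same way (the class name, field, and config_name value are hypothetical):

```python
# Illustrative only; mirrors the UNetModel(ModelMixin, ConfigMixin) pattern above.
import torch
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin

class TinyModel(ModelMixin, ConfigMixin):
    config_name = "config.json"  # assumed; the scheduler/pipeline classes in this diff set config_name explicitly

    def __init__(self, hidden_size=32):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)
```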

View File

@@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
# CHANGE to diffusers.utils
from transformers.utils import logging
-from .configuration_utils import Config
+from .configuration_utils import ConfigMixin
INDEX_FILE = "diffusion_model.pt"
@@ -33,16 +33,16 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
-"PreTrainedModel": ["save_pretrained", "from_pretrained"],
+"ModelMixin": ["save_pretrained", "from_pretrained"],
"GaussianDDPMScheduler": ["save_config", "from_config"],
},
"transformers": {
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
"ModelMixin": ["save_pretrained", "from_pretrained"],
},
}
-class DiffusionPipeline(Config):
+class DiffusionPipeline(ConfigMixin):
config_name = "model_index.json"

View File

@@ -14,7 +14,7 @@
import torch
from torch import nn
-from ..configuration_utils import Config
+from ..configuration_utils import ConfigMixin
SAMPLING_CONFIG_NAME = "scheduler_config.json"
@@ -24,7 +24,7 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
-class GaussianDDPMScheduler(nn.Module, Config):
+class GaussianDDPMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME
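The linear schedule shown in this hunk is a plain torch.linspace over [beta_start, beta_end]; a quick standalone check with illustrative values (the library's actual defaults are not shown here):

```python
import torch

betas = torch.linspace(1e-4, 2e-2, 1000, dtype=torch.float64)   # assumed example values
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)               # cumulative product used by DDPM sampling
```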

View File

@@ -40,13 +40,13 @@ _re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
-"CLIPConfig",
-"DecisionTransformerConfig",
-"EncoderDecoderConfig",
-"RagConfig",
-"SpeechEncoderDecoderConfig",
-"VisionEncoderDecoderConfig",
-"VisionTextDualEncoderConfig",
+"CLIPConfigMixin",
+"DecisionTransformerConfigMixin",
+"EncoderDecoderConfigMixin",
+"RagConfigMixin",
+"SpeechEncoderDecoderConfigMixin",
+"VisionEncoderDecoderConfigMixin",
+"VisionTextDualEncoderConfigMixin",
}

View File

@@ -87,7 +87,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"ReformerForMaskedLM", # Needs to be setup as decoder.
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
-"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
+"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFModelMixin ?)
"TFRobertaForMultipleChoice", # TODO: fix
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
"SeparableConv1D", # Building part of bigger (tested) model.
@@ -271,7 +271,7 @@ def get_model_modules():
def get_models(module, include_pretrained=False):
"""Get the objects in module that are models."""
models = []
-model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
+model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
for attr_name in dir(module):
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
continue
@@ -372,7 +372,7 @@ def find_tested_models(test_file):
def check_models_are_tested(module, test_file):
"""Check models defined in module are tested in test_file."""
-# XxxPreTrainedModel are not tested
+# XxxModelMixin are not tested
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
if tested_models is None:
@@ -625,9 +625,9 @@ def ignore_undocumented(name):
# Constants uppercase are not documented.
if name.isupper():
return True
-# PreTrainedModels / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
+# ModelMixins / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
if (
-name.endswith("PreTrainedModel")
+name.endswith("ModelMixin")
or name.endswith("Decoder")
or name.endswith("Encoder")
or name.endswith("Layer")

View File

@@ -94,7 +94,7 @@ def get_model_table_from_auto_modules():
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
if code in config_maping_names
}
-model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
+model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
slow_tokenizers = collections.defaultdict(bool)
@@ -190,7 +190,7 @@ def has_onnx(model_type):
for part in config_module.split(".")[1:]:
module = getattr(module, part)
config_name = config.__name__
-onnx_config_name = config_name.replace("Config", "OnnxConfig")
+onnx_config_name = config_name.replace("ConfigMixin", "OnnxConfigMixin")
return hasattr(module, onnx_config_name)