rename modeling code
This commit is contained in:
parent
80b865878c
commit
07b6d0e71e
|
@ -4,7 +4,7 @@
|
|||
|
||||
__version__ = "0.0.1"
|
||||
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models.unet import UNetModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Configuration base class and utilities."""
|
||||
""" Configuration base class and utilities."""
|
||||
|
||||
|
||||
import copy
|
||||
|
@ -44,7 +44,7 @@ logger = logging.get_logger(__name__)
|
|||
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
||||
|
||||
|
||||
class Config:
|
||||
class ConfigMixin:
|
||||
r"""
|
||||
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
||||
methods for loading/downloading/saving configurations.
|
||||
|
@ -71,7 +71,7 @@ class Config:
|
|||
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
[`~Config.from_config`] class method.
|
||||
[`~ConfigMixin.from_config`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
|
@ -88,7 +88,7 @@ class Config:
|
|||
output_config_file = os.path.join(save_directory, self.config_name)
|
||||
|
||||
self.to_json_file(output_config_file)
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
|
|
|
@ -122,11 +122,11 @@ def _load_state_dict_into_model(model_to_load, state_dict):
|
|||
return error_msgs
|
||||
|
||||
|
||||
class PreTrainedModel(torch.nn.Module):
|
||||
class ModelMixin(torch.nn.Module):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
[`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
|
||||
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
|
||||
downloading and saving models as well as a few methods common to all models to:
|
||||
|
||||
- resize the input embeddings,
|
||||
|
@ -134,13 +134,13 @@ class PreTrainedModel(torch.nn.Module):
|
|||
|
||||
Class attributes (overridden by derived classes):
|
||||
|
||||
- **config_class** ([`Config`]) -- A subclass of [`Config`] to use as configuration class
|
||||
- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class
|
||||
for this model architecture.
|
||||
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
|
||||
taking as arguments:
|
||||
|
||||
- **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
|
||||
- **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
|
||||
- **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
|
||||
- **config** ([`ConfigMixin`]) -- An instance of the configuration associated to the model.
|
||||
- **path** (`str`) -- A path to the TensorFlow checkpoint.
|
||||
|
||||
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
|
||||
|
@ -163,7 +163,7 @@ class PreTrainedModel(torch.nn.Module):
|
|||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`[`~PreTrainedModel.from_pretrained`]` class method.
|
||||
`[`~ModelMixin.from_pretrained`]` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
|
@ -231,20 +231,20 @@ class PreTrainedModel(torch.nn.Module):
|
|||
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
||||
user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
[`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
|
||||
config (`Union[Config, str, os.PathLike]`, *optional*):
|
||||
config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- an instance of a class derived from [`Config`],
|
||||
- a string or path valid as input to [`~Config.from_pretrained`].
|
||||
- an instance of a class derived from [`ConfigMixin`],
|
||||
- a string or path valid as input to [`~ConfigMixin.from_pretrained`].
|
||||
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
||||
model).
|
||||
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
|
||||
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the
|
||||
save directory.
|
||||
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
||||
configuration JSON file named *config.json* is found in the directory.
|
||||
|
@ -295,7 +295,7 @@ class PreTrainedModel(torch.nn.Module):
|
|||
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
||||
initialization function ([`~Config.from_pretrained`]). Each key of `kwargs` that
|
||||
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that
|
||||
corresponds to a configuration attribute will be used to override said attribute with the
|
||||
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
||||
will be passed to the underlying model's `__init__` function.
|
||||
|
|
|
@ -29,8 +29,8 @@ from torchvision import transforms, utils
|
|||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..configuration_utils import Config
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
|
@ -175,7 +175,7 @@ class AttnBlock(nn.Module):
|
|||
return x + h_
|
||||
|
||||
|
||||
class UNetModel(PreTrainedModel, Config):
|
||||
class UNetModel(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
|
|
|
@ -22,7 +22,7 @@ from huggingface_hub import snapshot_download
|
|||
# CHANGE to diffusers.utils
|
||||
from transformers.utils import logging
|
||||
|
||||
from .configuration_utils import Config
|
||||
from .configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_model.pt"
|
||||
|
@ -33,16 +33,16 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"GaussianDDPMScheduler": ["save_config", "from_config"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DiffusionPipeline(Config):
|
||||
class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
config_name = "model_index.json"
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import Config
|
||||
from ..configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
||||
|
@ -24,7 +24,7 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
|
|||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||
|
||||
|
||||
class GaussianDDPMScheduler(nn.Module, Config):
|
||||
class GaussianDDPMScheduler(nn.Module, ConfigMixin):
|
||||
|
||||
config_name = SAMPLING_CONFIG_NAME
|
||||
|
||||
|
|
|
@ -40,13 +40,13 @@ _re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")
|
|||
|
||||
|
||||
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
||||
"CLIPConfig",
|
||||
"DecisionTransformerConfig",
|
||||
"EncoderDecoderConfig",
|
||||
"RagConfig",
|
||||
"SpeechEncoderDecoderConfig",
|
||||
"VisionEncoderDecoderConfig",
|
||||
"VisionTextDualEncoderConfig",
|
||||
"CLIPConfigMixin",
|
||||
"DecisionTransformerConfigMixin",
|
||||
"EncoderDecoderConfigMixin",
|
||||
"RagConfigMixin",
|
||||
"SpeechEncoderDecoderConfigMixin",
|
||||
"VisionEncoderDecoderConfigMixin",
|
||||
"VisionTextDualEncoderConfigMixin",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
|||
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
||||
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
|
||||
"TFDPREncoder", # Building part of bigger (tested) model.
|
||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||
"TFRobertaForMultipleChoice", # TODO: fix
|
||||
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||
|
@ -271,7 +271,7 @@ def get_model_modules():
|
|||
def get_models(module, include_pretrained=False):
|
||||
"""Get the objects in module that are models."""
|
||||
models = []
|
||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||
model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
|
||||
for attr_name in dir(module):
|
||||
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
|
||||
continue
|
||||
|
@ -372,7 +372,7 @@ def find_tested_models(test_file):
|
|||
|
||||
def check_models_are_tested(module, test_file):
|
||||
"""Check models defined in module are tested in test_file."""
|
||||
# XxxPreTrainedModel are not tested
|
||||
# XxxPreTrainedModel are not tested
|
||||
defined_models = get_models(module)
|
||||
tested_models = find_tested_models(test_file)
|
||||
if tested_models is None:
|
||||
|
@ -625,9 +625,9 @@ def ignore_undocumented(name):
|
|||
# Constants uppercase are not documented.
|
||||
if name.isupper():
|
||||
return True
|
||||
# PreTrainedModels / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
|
||||
# ModelMixins / Encoders / Decoders / Layers / Embeddings / Attention are not documented.
|
||||
if (
|
||||
name.endswith("PreTrainedModel")
|
||||
name.endswith("ModelMixin")
|
||||
or name.endswith("Decoder")
|
||||
or name.endswith("Encoder")
|
||||
or name.endswith("Layer")
|
||||
|
|
|
@ -94,7 +94,7 @@ def get_model_table_from_auto_modules():
|
|||
for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
|
||||
if code in config_maping_names
|
||||
}
|
||||
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
|
||||
model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
|
||||
|
||||
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
|
||||
slow_tokenizers = collections.defaultdict(bool)
|
||||
|
@ -190,7 +190,7 @@ def has_onnx(model_type):
|
|||
for part in config_module.split(".")[1:]:
|
||||
module = getattr(module, part)
|
||||
config_name = config.__name__
|
||||
onnx_config_name = config_name.replace("Config", "OnnxConfig")
|
||||
onnx_config_name = config_name.replace("ConfigMixin", "OnnxConfigMixin")
|
||||
return hasattr(module, onnx_config_name)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue