[Loading] allow modules to be loaded in fp16 (#230)

Patrick von Platen, 2022-08-22 18:27:17 +02:00, committed by GitHub
parent 0ab948568d
commit db5fa43079
2 changed files with 14 additions and 2 deletions


@@ -315,6 +315,7 @@ class ModelMixin(torch.nn.Module):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)

         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -334,6 +335,12 @@ class ModelMixin(torch.nn.Module):
             subfolder=subfolder,
             **kwargs,
         )
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(
+                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+            )
+        elif torch_dtype is not None:
+            model = model.to(torch_dtype)
         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # Load model
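
The new torch_dtype argument casts a model right after its weights are loaded. A minimal usage sketch, assuming any ModelMixin subclass; the class name and checkpoint id below are illustrative, not taken from this diff:

import torch

from diffusers import UNet2DModel  # assumed example; any ModelMixin subclass exposes the same kwarg

# Weights are loaded normally, then the whole module is cast to half precision.
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", torch_dtype=torch.float16)
print(next(unet.parameters()).dtype)  # torch.float16

# Passing anything that is not a torch.dtype (e.g. the string "float16") raises the ValueError added above.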


@@ -146,6 +146,7 @@ class DiffusionPipeline(ConfigMixin):
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -237,12 +238,16 @@ class DiffusionPipeline(ConfigMixin):
             load_method = getattr(class_obj, load_method_name)

+            loading_kwargs = {}
+            if issubclass(class_obj, torch.nn.Module):
+                loading_kwargs["torch_dtype"] = torch_dtype
+
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
-                loaded_sub_model = load_method(os.path.join(cached_folder, name))
+                loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
             else:
                 # else load from the root directory
-                loaded_sub_model = load_method(cached_folder)
+                loaded_sub_model = load_method(cached_folder, **loading_kwargs)

             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
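
At the pipeline level, torch_dtype is only forwarded to components that are torch.nn.Module subclasses (UNet, VAE, text encoder, ...); schedulers, tokenizers, and other non-module components are loaded unchanged. A hedged usage sketch, with an illustrative checkpoint id that is not part of this diff:

import torch

from diffusers import DiffusionPipeline

# Each nn.Module component of the pipeline (here, the UNet) is loaded and cast to fp16;
# the scheduler is not an nn.Module, so it is left untouched.
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", torch_dtype=torch.float16)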