diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5b136419..d3e6113d 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -26,6 +26,7 @@ import torch import diffusers import PIL from huggingface_hub import snapshot_download +from packaging import version from PIL import Image from tqdm.auto import tqdm @@ -45,6 +46,7 @@ from .utils import ( if is_transformers_available(): + import transformers from transformers import PreTrainedModel @@ -505,11 +507,14 @@ class DiffusionPipeline(ConfigMixin): loading_kwargs["provider"] = provider loading_kwargs["sess_options"] = sess_options - if ( - issubclass(class_obj, diffusers.ModelMixin) - or is_transformers_available() + is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + is_transformers_model = ( + is_transformers_available() and issubclass(class_obj, PreTrainedModel) - ): + and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") + ) + + if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map # check if the module is in a subdirectory