[DeviceMap] Make sure stable diffusion can be loaded from older trans… (#860)
[DeviceMap] Make sure stable diffusion can be loaded from older transformers versiosn
This commit is contained in:
parent
93a81a3f5a
commit
2b7d4a5c21
|
@ -26,6 +26,7 @@ import torch
|
|||
import diffusers
|
||||
import PIL
|
||||
from huggingface_hub import snapshot_download
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
@ -45,6 +46,7 @@ from .utils import (
|
|||
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
|
@ -505,11 +507,14 @@ class DiffusionPipeline(ConfigMixin):
|
|||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
if (
|
||||
issubclass(class_obj, diffusers.ModelMixin)
|
||||
or is_transformers_available()
|
||||
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
):
|
||||
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
|
|
Loading…
Reference in New Issue