[Bad dependencies] Fix imports (#1382)
* fix imports * better error * up * finish
This commit is contained in:
parent
1524122532
commit
35d8186172
|
@ -6,7 +6,14 @@ import numpy as np
|
|||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
is_flax_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -30,12 +37,16 @@ class StableDiffusionPipelineOutput(BaseOutput):
|
|||
if is_transformers_available() and is_torch_available():
|
||||
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
|
||||
from .pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0"):
|
||||
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
|
||||
else:
|
||||
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
|
||||
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
|
||||
|
|
|
@ -1,9 +1,16 @@
|
|||
from ...utils import is_torch_available, is_transformers_available
|
||||
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
|
||||
|
||||
|
||||
if is_transformers_available() and is_torch_available():
|
||||
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0"):
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
|
||||
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
|
||||
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
|
||||
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
|
||||
else:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
)
|
||||
|
|
|
@ -33,6 +33,7 @@ from .import_utils import (
|
|||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
is_unidecode_available,
|
||||
requires_backends,
|
||||
)
|
||||
|
|
|
@ -303,6 +303,17 @@ def requires_backends(obj, backends):
|
|||
if failed:
|
||||
raise ImportError("".join(failed))
|
||||
|
||||
if name in [
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
"VersatileDiffusionPipeline",
|
||||
"VersatileDiffusionDualGuidedPipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
] and is_transformers_version("<", "4.25.0"):
|
||||
raise ImportError(
|
||||
f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install"
|
||||
" git+https://github.com/huggingface/transformers \n```"
|
||||
)
|
||||
|
||||
|
||||
class DummyObject(type):
|
||||
"""
|
||||
|
@ -347,3 +358,17 @@ def is_torch_version(operation: str, version: str):
|
|||
A string version of PyTorch
|
||||
"""
|
||||
return compare_versions(parse(_torch_version), operation, version)
|
||||
|
||||
|
||||
def is_transformers_version(operation: str, version: str):
|
||||
"""
|
||||
Args:
|
||||
Compares the current Transformers version to a given reference with an operation.
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A string version of PyTorch
|
||||
"""
|
||||
if not _transformers_available:
|
||||
return False
|
||||
return compare_versions(parse(_transformers_version), operation, version)
|
||||
|
|
Loading…
Reference in New Issue