Make repo structure consistent (#1862)
* move files a bit
* more refactors
* fix more
* more fixes
* fix more onnx
* make style
* upload
* fix
* up
* fix more
* up again
* up
* small fix
* Update src/diffusers/__init__.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* correct

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

This commit is contained in:
parent ab0e92fdc8
commit 29b2c93c90
@@ -155,9 +155,9 @@ adds a link to its documentation with this syntax: [`XXXClass`] or [`function`]
 function to be in the main package.
 
 If you want to create a link to some internal class or function, you need to
-provide its path. For instance: [`pipeline_utils.ImagePipelineOutput`]. This will be converted into a link with
-`pipeline_utils.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are
-linking to in the description, add a ~: [`~pipeline_utils.ImagePipelineOutput`] will generate a link with `ImagePipelineOutput` in the description.
+provide its path. For instance: [`pipelines.ImagePipelineOutput`]. This will be converted into a link with
+`pipelines.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are
+linking to in the description, add a ~: [`~pipelines.ImagePipelineOutput`] will generate a link with `ImagePipelineOutput` in the description.
 
 The same works for methods so you can either use [`XXXClass.method`] or [~`XXXClass.method`].
 
@@ -39,4 +39,4 @@ Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrained`]
 ## ImagePipelineOutput
 By default diffusion pipelines return an object of class
 
-[[autodoc]] pipeline_utils.ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
@@ -41,13 +41,13 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module`
 [[autodoc]] models.vae.DecoderOutput
 
 ## VQEncoderOutput
-[[autodoc]] models.vae.VQEncoderOutput
+[[autodoc]] models.vq_model.VQEncoderOutput
 
 ## VQModel
 [[autodoc]] VQModel
 
 ## AutoencoderKLOutput
-[[autodoc]] models.vae.AutoencoderKLOutput
+[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
 
 ## AutoencoderKL
 [[autodoc]] AutoencoderKL
@@ -25,7 +25,7 @@ pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
 outputs = pipeline()
 ```
 
-The `outputs` object is a [`~pipeline_utils.ImagePipelineOutput`], as we can see in the
+The `outputs` object is a [`~pipelines.ImagePipelineOutput`], as we can see in the
 documentation of that class below, it means it has an image attribute.
 
 You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you will get `None`:
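For readers updating their own code against the renamed class, here is a minimal sketch based on the snippet above; the `images` attribute name is taken from the `ImagePipelineOutput` definition:

```python
from diffusers import DDIMPipeline

pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
outputs = pipeline()

# `outputs` is a pipelines.ImagePipelineOutput; its `images` field holds the
# generated images, and fields the pipeline did not return come back as None.
image = outputs.images[0]
```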
@@ -2,8 +2,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 
-from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.pipeline_utils import ImagePipelineOutput
+from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
 from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
 from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
 from einops import rearrange, reduce
@@ -5,13 +5,7 @@ from typing import Dict, List, Union
 import torch
 
 from diffusers import DiffusionPipeline, __version__
-from diffusers.pipeline_utils import (
-    CONFIG_NAME,
-    DIFFUSERS_CACHE,
-    ONNX_WEIGHTS_NAME,
-    SCHEDULER_CONFIG_NAME,
-    WEIGHTS_NAME,
-)
+from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, WEIGHTS_NAME
 from huggingface_hub import snapshot_download
 
 
@@ -17,14 +17,10 @@ from typing import Callable, List, Optional, Union
 
 import torch
 
-from diffusers.utils import is_accelerate_available
-from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-from ...configuration_utils import FrozenDict
-from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import (
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,

@@ -32,6 +28,10 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -12,8 +12,8 @@ import torch.nn.functional as F
 
 import PIL
 from accelerate import Accelerator
+from diffusers import DiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -5,9 +5,9 @@ import numpy as np
 import torch
 
 import PIL
+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -6,9 +6,9 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch
 
+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -7,8 +7,7 @@ import torch
 
 import diffusers
 import PIL
-from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
-from diffusers.onnx_utils import OnnxRuntimeModel
+from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from packaging import version

@@ -16,7 +15,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 
 try:
-    from diffusers.onnx_utils import ORT_TO_NP_TYPE
+    from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
 except ImportError:
     ORT_TO_NP_TYPE = {
         "tensor(bool)": np.bool_,
@@ -3,9 +3,9 @@ from typing import Callable, List, Optional, Union
 
 import torch
 
+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -18,8 +18,7 @@ from typing import Callable, List, Optional, Union
 
 import torch
 
-from diffusers import LMSDiscreteScheduler
-from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers import DiffusionPipeline, LMSDiscreteScheduler
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import is_accelerate_available, logging
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
@@ -6,8 +6,8 @@ from typing import Callable, List, Optional, Union
 
 import torch
 
+from diffusers import DiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -3,9 +3,9 @@ from typing import Callable, List, Optional, Union
 import torch
 
 import PIL
+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -7,9 +7,9 @@ from typing import Callable, Dict, List, Optional, Union
 
 import torch
 
+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -21,8 +21,7 @@ import torch
 from torch.onnx import export
 
 import onnx
-from diffusers import OnnxStableDiffusionPipeline, StableDiffusionPipeline
-from diffusers.onnx_utils import OnnxRuntimeModel
+from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline
 from packaging import version
 
 
@@ -1,7 +1,6 @@
 __version__ = "0.12.0.dev0"
 
 from .configuration_utils import ConfigMixin
-from .onnx_utils import OnnxRuntimeModel
 from .utils import (
     OptionalDependencyNotAvailable,
     is_flax_available,

@@ -18,15 +17,23 @@ from .utils import (
 )
 
 
+try:
+    if not is_onnx_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_onnx_objects import *  # noqa F403
+else:
+    from .pipelines import OnnxRuntimeModel
+
 try:
     if not is_torch_available():
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
     from .utils.dummy_pt_objects import *  # noqa F403
 else:
-    from .modeling_utils import ModelMixin
     from .models import (
         AutoencoderKL,
+        ModelMixin,
         PriorTransformer,
         Transformer2DModel,
         UNet1DModel,

@@ -43,11 +50,13 @@ else:
         get_polynomial_decay_schedule_with_warmup,
         get_scheduler,
     )
-    from .pipeline_utils import DiffusionPipeline
     from .pipelines import (
+        AudioPipelineOutput,
         DanceDiffusionPipeline,
         DDIMPipeline,
         DDPMPipeline,
+        DiffusionPipeline,
+        ImagePipelineOutput,
         KarrasVePipeline,
         LDMPipeline,
         LDMSuperResolutionPipeline,

@@ -150,10 +159,10 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_flax_objects import *  # noqa F403
 else:
-    from .modeling_flax_utils import FlaxModelMixin
+    from .models.modeling_flax_utils import FlaxModelMixin
     from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .models.vae_flax import FlaxAutoencoderKL
-    from .pipeline_flax_utils import FlaxDiffusionPipeline
+    from .pipelines import FlaxDiffusionPipeline
     from .schedulers import (
         FlaxDDIMScheduler,
         FlaxDDPMScheduler,
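Taken together, the `__init__.py` hunks above keep the public namespace stable while the private modules move; a hedged sketch of what user code sees after this change:

```python
# These continue to resolve from the package root after the restructuring;
# internally they are now re-exported from .models and .pipelines.
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    ImagePipelineOutput,
    ModelMixin,
)

# OnnxRuntimeModel is only re-exported when is_onnx_available() is true, as
# the new try/except block above shows; otherwise a dummy object stands in.
from diffusers import OnnxRuntimeModel
```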
@@ -18,7 +18,7 @@ import torch
 import tqdm
 
 from ...models.unet_1d import UNet1DModel
-from ...pipeline_utils import DiffusionPipeline
+from ...pipelines import DiffusionPipeline
 from ...utils.dummy_pt_objects import DDPMScheduler
 
 
@@ -16,12 +16,15 @@ from ..utils import is_flax_available, is_torch_available
 
 
 if is_torch_available():
-    from .attention import Transformer2DModel
+    from .autoencoder_kl import AutoencoderKL
+    from .dual_transformer_2d import DualTransformer2DModel
+    from .modeling_utils import ModelMixin
     from .prior_transformer import PriorTransformer
+    from .transformer_2d import Transformer2DModel
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
-    from .vae import AutoencoderKL, VQModel
+    from .vq_model import VQModel
 
 if is_flax_available():
     from .unet_2d_condition_flax import FlaxUNet2DConditionModel
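A sketch of the resulting `diffusers.models` namespace, with the module moves annotated (module names taken from the hunk above):

```python
from diffusers.models import (
    AutoencoderKL,           # now defined in models/autoencoder_kl.py
    DualTransformer2DModel,  # new module models/dual_transformer_2d.py
    ModelMixin,              # moved from diffusers/modeling_utils.py
    Transformer2DModel,      # split out of models/attention.py
    VQModel,                 # now defined in models/vq_model.py
)
```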
@@ -20,11 +20,11 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..modeling_utils import ModelMixin
 from ..models.embeddings import ImagePositionalEmbeddings
 from ..utils import BaseOutput
 from ..utils.import_utils import is_xformers_available
 from .cross_attention import CrossAttention
+from .modeling_utils import ModelMixin
 
 
 @dataclass
@@ -0,0 +1,177 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .modeling_utils import ModelMixin
+from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+    """
+    Output of AutoencoderKL encoding method.
+
+    Args:
+        latent_dist (`DiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+    """
+
+    latent_dist: "DiagonalGaussianDistribution"
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+    and Max Welling.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the model (such as downloading or saving, etc.)
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to :
+            obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to :
+            obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to :
+            obj:`(64,)`): Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): TODO
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 4,
+        norm_num_groups: int = 32,
+        sample_size: int = 32,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=True,
+        )
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            norm_num_groups=norm_num_groups,
+            act_fn=act_fn,
+        )
+
+        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+        self.use_slicing = False
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): Input sample.
+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the posterior.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
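To make the new module concrete, here is a small usage sketch of the `AutoencoderKL` API defined above; the tiny config values are illustrative only (real checkpoints use larger `block_out_channels`):

```python
import torch
from diffusers.models import AutoencoderKL

# Single down/up block, so the spatial size is preserved end to end.
vae = AutoencoderKL(block_out_channels=(32,), norm_num_groups=32, latent_channels=4)

x = torch.randn(2, 3, 64, 64)
posterior = vae.encode(x).latent_dist   # AutoencoderKLOutput.latent_dist
z = posterior.sample()                  # stochastic latents from the posterior
recon = vae.decode(z).sample            # DecoderOutput.sample

vae.enable_slicing()                    # decode one batch element at a time
recon_sliced = vae.decode(z).sample     # lower peak memory, same shapes
assert recon.shape == x.shape == recon_sliced.shape
```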
@@ -0,0 +1,151 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from torch import nn
+
+from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
+
+
+class DualTransformer2DModel(nn.Module):
+    """
+    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            Pass if the input is continuous. The number of channels in the input and output.
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+            `ImagePositionalEmbeddings`.
+        num_vector_embeds (`int`, *optional*):
+            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+            Includes the class for the masked latent pixel.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            up to but not more than steps than `num_embeds_ada_norm`.
+        attention_bias (`bool`, *optional*):
+            Configure if the TransformerBlocks' attention should contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        sample_size: Optional[int] = None,
+        num_vector_embeds: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+    ):
+        super().__init__()
+        self.transformers = nn.ModuleList(
+            [
+                Transformer2DModel(
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    in_channels=in_channels,
+                    num_layers=num_layers,
+                    dropout=dropout,
+                    norm_num_groups=norm_num_groups,
+                    cross_attention_dim=cross_attention_dim,
+                    attention_bias=attention_bias,
+                    sample_size=sample_size,
+                    num_vector_embeds=num_vector_embeds,
+                    activation_fn=activation_fn,
+                    num_embeds_ada_norm=num_embeds_ada_norm,
+                )
+                for _ in range(2)
+            ]
+        )
+
+        # Variables that can be set by a pipeline:
+
+        # The ratio of transformer1 to transformer2's output states to be combined during inference
+        self.mix_ratio = 0.5
+
+        # The shape of `encoder_hidden_states` is expected to be
+        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
+        self.condition_lengths = [77, 257]
+
+        # Which transformer to use to encode which condition.
+        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
+        self.transformer_index_for_condition = [1, 0]
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states,
+        timestep=None,
+        attention_mask=None,
+        cross_attention_kwargs=None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+                hidden_states
+            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep ( `torch.long`, *optional*):
+                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Optional attention mask to be applied in CrossAttention
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+            if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+            tensor.
+        """
+        input_states = hidden_states
+
+        encoded_states = []
+        tokens_start = 0
+        # attention_mask is not used yet
+        for i in range(2):
+            # for each of the two transformers, pass the corresponding condition tokens
+            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
+            transformer_index = self.transformer_index_for_condition[i]
+            encoded_state = self.transformers[transformer_index](
+                input_states,
+                encoder_hidden_states=condition_state,
+                timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+            encoded_states.append(encoded_state - input_states)
+            tokens_start += self.condition_lengths[i]
+
+        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
+        output_states = output_states + input_states
+
+        if not return_dict:
+            return (output_states,)
+
+        return Transformer2DModelOutput(sample=output_states)
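A shape-oriented sketch of the mixing described above, with toy dimensions; the 77 + 257 token split mirrors the `condition_lengths` default, and `norm_num_groups` is lowered so GroupNorm divides the toy channel count:

```python
import torch
from diffusers.models import DualTransformer2DModel

model = DualTransformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,    # inner_dim = 16
    in_channels=16,
    norm_num_groups=16,
    cross_attention_dim=16,
)

hidden_states = torch.randn(1, 16, 8, 8)
# Both conditions concatenated along the token axis: 77 + 257 tokens,
# matching model.condition_lengths.
encoder_hidden_states = torch.randn(1, 77 + 257, 16)

out = model(hidden_states, encoder_hidden_states).sample
assert out.shape == hidden_states.shape  # residual mix of the two branches
```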
@@ -19,7 +19,7 @@ import jax.numpy as jnp
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
 
-from .utils import logging
+from ..utils import logging
 
 
 logger = logging.get_logger(__name__)
@@ -27,9 +27,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
-from . import __version__, is_torch_available
-from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
-from .utils import (
+from .. import __version__, is_torch_available
+from ..utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,

@@ -37,6 +36,7 @@ from .utils import (
     WEIGHTS_NAME,
     logging,
 )
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
 
 
 logger = logging.get_logger(__name__)
@@ -26,11 +26,11 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
-from . import __version__
-from .hub_utils import HF_HUB_OFFLINE
-from .utils import (
+from .. import __version__
+from ..utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
+    HF_HUB_OFFLINE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,

@@ -149,7 +149,7 @@ class ModelMixin(torch.nn.Module):
     and saving models.
 
         - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
-          [`~modeling_utils.ModelMixin.save_pretrained`].
+          [`~models.ModelMixin.save_pretrained`].
     """
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]

@@ -231,7 +231,7 @@ class ModelMixin(torch.nn.Module):
     ):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
-        `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
+        `[`~models.ModelMixin.from_pretrained`]` class method.
 
         Arguments:
             save_directory (`str` or `os.PathLike`):
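The updated cross-references point at the new home of `ModelMixin`; a brief save/load sketch using any `ModelMixin` subclass (the directory name is illustrative):

```python
from diffusers.models import UNet2DModel

model = UNet2DModel()                       # any ModelMixin subclass works
model.save_pretrained("./my-unet")          # writes config.json plus the weights file
reloaded = UNet2DModel.from_pretrained("./my-unet")
```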
@@ -6,10 +6,10 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
 from .attention import BasicTransformerBlock
 from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
 
 
 @dataclass
@@ -0,0 +1,244 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..models.embeddings import ImagePositionalEmbeddings
+from ..utils import BaseOutput
+from .attention import BasicTransformerBlock
+from .modeling_utils import ModelMixin
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+            Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
+            for the unnoised latent pixels.
+    """
+
+    sample: torch.FloatTensor
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+    """
+    Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
+    embeddings) inputs.
+
+    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
+    transformer action. Finally, reshape to image.
+
+    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
+    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
+    classes of unnoised image.
+
+    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
+    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            Pass if the input is continuous. The number of channels in the input and output.
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+            `ImagePositionalEmbeddings`.
+        num_vector_embeds (`int`, *optional*):
+            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+            Includes the class for the masked latent pixel.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            up to but not more than steps than `num_embeds_ada_norm`.
+        attention_bias (`bool`, *optional*):
+            Configure if the TransformerBlocks' attention should contain a bias parameter.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        sample_size: Optional[int] = None,
+        num_vector_embeds: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+        self.use_linear_projection = use_linear_projection
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+        # Define whether input is continuous or discrete depending on configuration
+        self.is_input_continuous = in_channels is not None
+        self.is_input_vectorized = num_vector_embeds is not None
+
+        if self.is_input_continuous and self.is_input_vectorized:
+            raise ValueError(
+                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+                " sure that either `in_channels` or `num_vector_embeds` is None."
+            )
+        elif not self.is_input_continuous and not self.is_input_vectorized:
+            raise ValueError(
+                f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
+                " sure that either `in_channels` or `num_vector_embeds` is not None."
+            )
+
+        # 2. Define input layers
+        if self.is_input_continuous:
+            self.in_channels = in_channels
+
+            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+            if use_linear_projection:
+                self.proj_in = nn.Linear(in_channels, inner_dim)
+            else:
+                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+        elif self.is_input_vectorized:
+            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+            self.height = sample_size
+            self.width = sample_size
+            self.num_vector_embeds = num_vector_embeds
+            self.num_latent_pixels = self.height * self.width
+
+            self.latent_image_embedding = ImagePositionalEmbeddings(
+                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+            )
+
+        # 3. Define transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    cross_attention_dim=cross_attention_dim,
+                    activation_fn=activation_fn,
+                    num_embeds_ada_norm=num_embeds_ada_norm,
+                    attention_bias=attention_bias,
+                    only_cross_attention=only_cross_attention,
+                    upcast_attention=upcast_attention,
+                )
+                for d in range(num_layers)
+            ]
+        )
+
+        # 4. Define output layers
+        if self.is_input_continuous:
+            if use_linear_projection:
+                self.proj_out = nn.Linear(in_channels, inner_dim)
+            else:
+                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+        elif self.is_input_vectorized:
+            self.norm_out = nn.LayerNorm(inner_dim)
+            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states=None,
+        timestep=None,
+        cross_attention_kwargs=None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+                When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+                hidden_states
+            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep ( `torch.long`, *optional*):
+                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+            if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+            tensor.
+        """
+        # 1. Input
+        if self.is_input_continuous:
+            batch, channel, height, width = hidden_states.shape
+            residual = hidden_states
+
+            hidden_states = self.norm(hidden_states)
+            if not self.use_linear_projection:
+                hidden_states = self.proj_in(hidden_states)
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            else:
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+                hidden_states = self.proj_in(hidden_states)
+        elif self.is_input_vectorized:
+            hidden_states = self.latent_image_embedding(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
+            )
+
+        # 3. Output
+        if self.is_input_continuous:
+            if not self.use_linear_projection:
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                hidden_states = self.proj_out(hidden_states)
+            else:
+                hidden_states = self.proj_out(hidden_states)
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+            output = hidden_states + residual
+        elif self.is_input_vectorized:
+            hidden_states = self.norm_out(hidden_states)
+            logits = self.out(hidden_states)
+            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+            logits = logits.permute(0, 2, 1)
+
+            # log(p(x_0))
+            output = F.log_softmax(logits.double(), dim=1).float()
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
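The continuous branch of `forward` above is easiest to follow with concrete shapes; a toy sketch (GroupNorm requires `norm_num_groups` to divide `in_channels`):

```python
import torch
from diffusers.models import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,    # inner_dim = 2 * 8 = 16
    in_channels=16,
    norm_num_groups=16,
)

x = torch.randn(1, 16, 8, 8)    # (batch, channel, height, width)
out = model(x).sample           # projected, attended, reshaped, residual added
assert out.shape == x.shape
```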
@@ -19,9 +19,9 @@ import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
 from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
 
 
@@ -18,9 +18,9 @@ import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
@@ -19,10 +19,10 @@ import torch.nn as nn
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput, logging
 from .cross_attention import AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
     CrossAttnDownBlock2D,
     CrossAttnUpBlock2D,
@@ -20,9 +20,9 @@ import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
 from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..modeling_flax_utils import FlaxModelMixin
 from ..utils import BaseOutput
 from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from .modeling_flax_utils import FlaxModelMixin
 from .unet_2d_blocks_flax import (
     FlaxCrossAttnDownBlock2D,
     FlaxCrossAttnUpBlock2D,
@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Optional

import numpy as np
import torch
import torch.nn as nn

-from ..configuration_utils import ConfigMixin, register_to_config
-from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
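With `ConfigMixin` and `ModelMixin` gone from its imports, `vae.py` is left holding only building blocks such as `Encoder`, `Decoder`, `VectorQuantizer`, and `DiagonalGaussianDistribution`. A minimal sketch of the distribution that stays behind (random stand-in values, not a trained encoder output):

```py
import torch
from diffusers.models.vae import DiagonalGaussianDistribution

# The AutoencoderKL encoder packs mean and logvar along the channel axis,
# so 2 * latent_channels = 8 here.
parameters = torch.randn(1, 8, 4, 4)
dist = DiagonalGaussianDistribution(parameters)

z_mode = dist.mode()      # deterministic latent: the mean
z_sample = dist.sample()  # stochastic latent
print(z_mode.shape, z_sample.shape)  # twice torch.Size([1, 4, 4, 4])
```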
@ -37,33 +35,6 @@ class DecoderOutput(BaseOutput):
    sample: torch.FloatTensor


-@dataclass
-class VQEncoderOutput(BaseOutput):
-    """
-    Output of VQModel encoding method.
-
-    Args:
-        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Encoded output sample of the model. Output of the last layer of the model.
-    """
-
-    latents: torch.FloatTensor
-
-
-@dataclass
-class AutoencoderKLOutput(BaseOutput):
-    """
-    Output of AutoencoderKL encoding method.
-
-    Args:
-        latent_dist (`DiagonalGaussianDistribution`):
-            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
-            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
-    """
-
-    latent_dist: "DiagonalGaussianDistribution"
-
-
class Encoder(nn.Module):
    def __init__(
        self,
@ -384,255 +355,3 @@ class DiagonalGaussianDistribution(object):

    def mode(self):
        return self.mean
-
-
-class VQModel(ModelMixin, ConfigMixin):
-    r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
-    Kavukcuoglu.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
-
-    Parameters:
-        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
-        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-            obj:`(64,)`): Tuple of block output channels.
-        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
-        sample_size (`int`, *optional*, defaults to `32`): TODO
-        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
-        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        in_channels: int = 3,
-        out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
-        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
-        block_out_channels: Tuple[int] = (64,),
-        layers_per_block: int = 1,
-        act_fn: str = "silu",
-        latent_channels: int = 3,
-        sample_size: int = 32,
-        num_vq_embeddings: int = 256,
-        norm_num_groups: int = 32,
-        vq_embed_dim: Optional[int] = None,
-    ):
-        super().__init__()
-
-        # pass init params to Encoder
-        self.encoder = Encoder(
-            in_channels=in_channels,
-            out_channels=latent_channels,
-            down_block_types=down_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            act_fn=act_fn,
-            norm_num_groups=norm_num_groups,
-            double_z=False,
-        )
-
-        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
-
-        self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
-        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
-        self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
-
-        # pass init params to Decoder
-        self.decoder = Decoder(
-            in_channels=latent_channels,
-            out_channels=out_channels,
-            up_block_types=up_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            act_fn=act_fn,
-            norm_num_groups=norm_num_groups,
-        )
-
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
-        h = self.encoder(x)
-        h = self.quant_conv(h)
-
-        if not return_dict:
-            return (h,)
-
-        return VQEncoderOutput(latents=h)
-
-    def decode(
-        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
-        # also go through quantization layer
-        if not force_not_quantize:
-            quant, emb_loss, info = self.quantize(h)
-        else:
-            quant = h
-        quant = self.post_quant_conv(quant)
-        dec = self.decoder(quant)
-
-        if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
-
-    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        r"""
-        Args:
-            sample (`torch.FloatTensor`): Input sample.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
-        """
-        x = sample
-        h = self.encode(x).latents
-        dec = self.decode(h).sample
-
-        if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
-
-
-class AutoencoderKL(ModelMixin, ConfigMixin):
-    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
-    and Max Welling.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
-
-    Parameters:
-        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
-        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-            obj:`(64,)`): Tuple of block output channels.
-        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
-        sample_size (`int`, *optional*, defaults to `32`): TODO
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        in_channels: int = 3,
-        out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
-        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
-        block_out_channels: Tuple[int] = (64,),
-        layers_per_block: int = 1,
-        act_fn: str = "silu",
-        latent_channels: int = 4,
-        norm_num_groups: int = 32,
-        sample_size: int = 32,
-    ):
-        super().__init__()
-
-        # pass init params to Encoder
-        self.encoder = Encoder(
-            in_channels=in_channels,
-            out_channels=latent_channels,
-            down_block_types=down_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            act_fn=act_fn,
-            norm_num_groups=norm_num_groups,
-            double_z=True,
-        )
-
-        # pass init params to Decoder
-        self.decoder = Decoder(
-            in_channels=latent_channels,
-            out_channels=out_channels,
-            up_block_types=up_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            norm_num_groups=norm_num_groups,
-            act_fn=act_fn,
-        )
-
-        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
-        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
-        self.use_slicing = False
-
-    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
-        h = self.encoder(x)
-        moments = self.quant_conv(h)
-        posterior = DiagonalGaussianDistribution(moments)
-
-        if not return_dict:
-            return (posterior,)
-
-        return AutoencoderKLOutput(latent_dist=posterior)
-
-    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        z = self.post_quant_conv(z)
-        dec = self.decoder(z)
-
-        if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
-
-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding.
-
-        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
-        steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
-    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        if self.use_slicing and z.shape[0] > 1:
-            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
-            decoded = torch.cat(decoded_slices)
-        else:
-            decoded = self._decode(z).sample
-
-        if not return_dict:
-            return (decoded,)
-
-        return DecoderOutput(sample=decoded)
-
-    def forward(
-        self,
-        sample: torch.FloatTensor,
-        sample_posterior: bool = False,
-        return_dict: bool = True,
-        generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
-        r"""
-        Args:
-            sample (`torch.FloatTensor`): Input sample.
-            sample_posterior (`bool`, *optional*, defaults to `False`):
-                Whether to sample from the posterior.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
-        """
-        x = sample
-        posterior = self.encode(x).latent_dist
-        if sample_posterior:
-            z = posterior.sample(generator=generator)
-        else:
-            z = posterior.mode()
-        dec = self.decode(z).sample
-
-        if not return_dict:
-            return (dec,)
-
-        return DecoderOutput(sample=dec)
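The two classes removed here reappear as `models/vq_model.py` (added below) and `models/autoencoder_kl.py`; their public import paths are untouched. A small sketch of the sliced decoding that the removed `AutoencoderKL.decode` implements (tiny default config, not a trained model):

```py
import torch
from diffusers import AutoencoderKL  # still exported at the top level

vae = AutoencoderKL()  # default config: latent_channels=4, block_out_channels=(64,)

# With slicing enabled, each latent in the batch is decoded on its own and the
# results are concatenated, trading speed for peak memory.
vae.enable_slicing()
latents = torch.randn(2, 4, 8, 8)
images = vae.decode(latents).sample
vae.disable_slicing()
print(images.shape)  # torch.Size([2, 3, 8, 8])
```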
@ -25,8 +25,8 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput
+from .modeling_flax_utils import FlaxModelMixin


@flax.struct.dataclass
@ -0,0 +1,148 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .modeling_utils import ModelMixin
+from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Encoded output sample of the model. Output of the last layer of the model.
+    """
+
+    latents: torch.FloatTensor
+
+
+class VQModel(ModelMixin, ConfigMixin):
+    r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
+    Kavukcuoglu.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the model (such as downloading or saving, etc.)
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to :
+            obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to :
+            obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to :
+            obj:`(64,)`): Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): TODO
+        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
+        norm_num_groups: int = 32,
+        vq_embed_dim: Optional[int] = None,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=False,
+        )
+
+        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
+
+        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
+        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
+        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+        )
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    def decode(
+        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
+        # also go through quantization layer
+        if not force_not_quantize:
+            quant, emb_loss, info = self.quantize(h)
+        else:
+            quant = h
+        quant = self.post_quant_conv(quant)
+        dec = self.decoder(quant)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): Input sample.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        x = sample
+        h = self.encode(x).latents
+        dec = self.decode(h).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
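Round-tripping a sample through the relocated class works as before; only the module that hosts `VQEncoderOutput` changes (a sketch with the default, untrained config):

```py
import torch
from diffusers import VQModel  # top-level export is unchanged
from diffusers.models.vq_model import VQEncoderOutput  # new home of the output class

model = VQModel()
sample = torch.randn(1, 3, 32, 32)

enc = model.encode(sample)
assert isinstance(enc, VQEncoderOutput)

# decode() runs the latents through the quantizer before the decoder.
rec = model.decode(enc.latents).sample
print(rec.shape)  # torch.Size([1, 3, 32, 32])
```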
@ -1,6 +1,4 @@
-# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team.
-# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,869 +10,10 @@
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
|
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import importlib
|
# NOTE: This file is deprecated and will be removed in a future version.
|
||||||
import inspect
|
# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
|
||||||
import torch
|
|
||||||
|
|
||||||
import diffusers
|
|
||||||
import PIL
|
|
||||||
from huggingface_hub import model_info, snapshot_download
|
|
||||||
from packaging import version
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
from .configuration_utils import ConfigMixin
|
|
||||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
|
||||||
from .hub_utils import HF_HUB_OFFLINE, http_user_agent
|
|
||||||
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
|
||||||
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
|
||||||
from .utils import (
|
|
||||||
CONFIG_NAME,
|
|
||||||
DIFFUSERS_CACHE,
|
|
||||||
ONNX_WEIGHTS_NAME,
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
BaseOutput,
|
|
||||||
deprecate,
|
|
||||||
is_accelerate_available,
|
|
||||||
is_safetensors_available,
|
|
||||||
is_torch_version,
|
|
||||||
is_transformers_available,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_available():
|
|
||||||
import transformers
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
|
|
||||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
|
||||||
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
|
|
||||||
DUMMY_MODULES_FOLDER = "diffusers.utils"
|
|
||||||
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
LOADABLE_CLASSES = {
|
|
||||||
"diffusers": {
|
|
||||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
|
||||||
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
|
||||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
|
||||||
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
|
||||||
},
|
|
||||||
"transformers": {
|
|
||||||
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
|
||||||
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
|
||||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
|
||||||
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
|
||||||
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
|
||||||
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
|
||||||
},
|
|
||||||
"onnxruntime.training": {
|
|
||||||
"ORTModule": ["save_pretrained", "from_pretrained"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ALL_IMPORTABLE_CLASSES = {}
|
|
||||||
for library in LOADABLE_CLASSES:
|
|
||||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ImagePipelineOutput(BaseOutput):
|
|
||||||
"""
|
|
||||||
Output class for image pipelines.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
|
||||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
|
||||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
|
||||||
"""
|
|
||||||
|
|
||||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AudioPipelineOutput(BaseOutput):
|
|
||||||
"""
|
|
||||||
Output class for audio pipelines.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audios (`np.ndarray`)
|
|
||||||
List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the
|
|
||||||
denoised audio samples of the diffusion pipeline.
|
|
||||||
"""
|
|
||||||
|
|
||||||
audios: np.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
def is_safetensors_compatible(info) -> bool:
|
|
||||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
|
||||||
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
|
|
||||||
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
|
|
||||||
for pt_filename in pt_filenames:
|
|
||||||
prefix, raw = os.path.split(pt_filename)
|
|
||||||
if raw == "pytorch_model.bin":
|
|
||||||
# transformers specific
|
|
||||||
sf_filename = os.path.join(prefix, "model.safetensors")
|
|
||||||
else:
|
|
||||||
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
|
|
||||||
if is_safetensors_compatible and sf_filename not in filenames:
|
|
||||||
logger.warning(f"{sf_filename} not found")
|
|
||||||
is_safetensors_compatible = False
|
|
||||||
return is_safetensors_compatible
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPipeline(ConfigMixin):
|
|
||||||
r"""
|
|
||||||
Base class for all models.
|
|
||||||
|
|
||||||
[`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
|
|
||||||
and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
|
|
||||||
|
|
||||||
- move all PyTorch modules to the device of your choice
|
|
||||||
- enabling/disabling the progress bar for the denoising iteration
|
|
||||||
|
|
||||||
Class attributes:
|
|
||||||
|
|
||||||
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
|
|
||||||
components of the diffusion pipeline.
|
|
||||||
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
|
|
||||||
passed for the pipeline to function (should be overridden by subclasses).
|
|
||||||
"""
|
|
||||||
config_name = "model_index.json"
|
|
||||||
_optional_components = []
|
|
||||||
|
|
||||||
def register_modules(self, **kwargs):
|
|
||||||
# import it here to avoid circular import
|
|
||||||
from diffusers import pipelines
|
|
||||||
|
|
||||||
for name, module in kwargs.items():
|
|
||||||
# retrieve library
|
|
||||||
if module is None:
|
|
||||||
register_dict = {name: (None, None)}
|
|
||||||
else:
|
|
||||||
library = module.__module__.split(".")[0]
|
|
||||||
|
|
||||||
# check if the module is a pipeline module
|
|
||||||
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
|
|
||||||
path = module.__module__.split(".")
|
|
||||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
|
||||||
|
|
||||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
|
||||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
|
||||||
# folder so we set the library to module name.
|
|
||||||
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
|
||||||
library = pipeline_dir
|
|
||||||
|
|
||||||
# retrieve class_name
|
|
||||||
class_name = module.__class__.__name__
|
|
||||||
|
|
||||||
register_dict = {name: (library, class_name)}
|
|
||||||
|
|
||||||
# save model index config
|
|
||||||
self.register_to_config(**register_dict)
|
|
||||||
|
|
||||||
# set models
|
|
||||||
setattr(self, name, module)
|
|
||||||
|
|
||||||
def save_pretrained(
|
|
||||||
self,
|
|
||||||
save_directory: Union[str, os.PathLike],
|
|
||||||
safe_serialization: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
|
|
||||||
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
|
|
||||||
method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
save_directory (`str` or `os.PathLike`):
|
|
||||||
Directory to which to save. Will be created if it doesn't exist.
|
|
||||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
|
||||||
"""
|
|
||||||
self.save_config(save_directory)
|
|
||||||
|
|
||||||
model_index_dict = dict(self.config)
|
|
||||||
model_index_dict.pop("_class_name")
|
|
||||||
model_index_dict.pop("_diffusers_version")
|
|
||||||
model_index_dict.pop("_module", None)
|
|
||||||
|
|
||||||
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
|
||||||
|
|
||||||
def is_saveable_module(name, value):
|
|
||||||
if name not in expected_modules:
|
|
||||||
return False
|
|
||||||
if name in self._optional_components and value[0] is None:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
|
|
||||||
|
|
||||||
for pipeline_component_name in model_index_dict.keys():
|
|
||||||
sub_model = getattr(self, pipeline_component_name)
|
|
||||||
model_cls = sub_model.__class__
|
|
||||||
|
|
||||||
save_method_name = None
|
|
||||||
# search for the model's base class in LOADABLE_CLASSES
|
|
||||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
|
||||||
library = importlib.import_module(library_name)
|
|
||||||
for base_class, save_load_methods in library_classes.items():
|
|
||||||
class_candidate = getattr(library, base_class, None)
|
|
||||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
|
||||||
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
|
|
||||||
save_method_name = save_load_methods[0]
|
|
||||||
break
|
|
||||||
if save_method_name is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
save_method = getattr(sub_model, save_method_name)
|
|
||||||
|
|
||||||
# Call the save method with the argument safe_serialization only if it's supported
|
|
||||||
save_method_signature = inspect.signature(save_method)
|
|
||||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
|
||||||
if save_method_accept_safe:
|
|
||||||
save_method(
|
|
||||||
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
save_method(os.path.join(save_directory, pipeline_component_name))
|
|
||||||
|
|
||||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
|
||||||
if torch_device is None:
|
|
||||||
return self
|
|
||||||
|
|
||||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
|
||||||
for name in module_names.keys():
|
|
||||||
module = getattr(self, name)
|
|
||||||
if isinstance(module, torch.nn.Module):
|
|
||||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
|
|
||||||
logger.warning(
|
|
||||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
|
||||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
|
||||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
|
||||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
|
||||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
|
||||||
)
|
|
||||||
module.to(torch_device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
`torch.device`: The torch device on which the pipeline is located.
|
|
||||||
"""
|
|
||||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
|
||||||
for name in module_names.keys():
|
|
||||||
module = getattr(self, name)
|
|
||||||
if isinstance(module, torch.nn.Module):
|
|
||||||
return module.device
|
|
||||||
return torch.device("cpu")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
|
||||||
r"""
|
|
||||||
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
|
|
||||||
|
|
||||||
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
|
||||||
|
|
||||||
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
|
||||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
|
||||||
task.
|
|
||||||
|
|
||||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
|
||||||
weights are discarded.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
|
||||||
Can be either:
|
|
||||||
|
|
||||||
- A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
|
|
||||||
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
|
|
||||||
`CompVis/ldm-text2im-large-256`.
|
|
||||||
- A path to a *directory* containing pipeline weights saved using
|
|
||||||
[`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
|
|
||||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
|
||||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
|
||||||
will be automatically derived from the model's weights.
|
|
||||||
custom_pipeline (`str`, *optional*):
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
This is an experimental feature and is likely to change in the future.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Can be either:
|
|
||||||
|
|
||||||
- A string, the *repo id* of a custom pipeline hosted inside a model repo on
|
|
||||||
https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
|
|
||||||
like `hf-internal-testing/diffusers-dummy-pipeline`.
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
It is required that the model repo has a file, called `pipeline.py` that defines the custom
|
|
||||||
pipeline.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
- A string, the *file name* of a community pipeline hosted on GitHub under
|
|
||||||
https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
|
|
||||||
match exactly the file name without `.py` located under the above link, *e.g.*
|
|
||||||
`clip_guided_stable_diffusion`.
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
Community pipelines are always loaded from the current `main` branch of GitHub.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
- A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
It is required that the directory has a file, called `pipeline.py` that defines the custom
|
|
||||||
pipeline.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
For more information on how to load and create custom pipelines, please have a look at [Loading and
|
|
||||||
Adding Custom
|
|
||||||
Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
|
|
||||||
|
|
||||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
|
||||||
force_download (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
|
||||||
cached versions if they exist.
|
|
||||||
resume_download (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
|
||||||
file exists.
|
|
||||||
proxies (`Dict[str, str]`, *optional*):
|
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
|
||||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
|
||||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
|
||||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
|
||||||
use_auth_token (`str` or *bool*, *optional*):
|
|
||||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
|
||||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
|
||||||
revision (`str`, *optional*, defaults to `"main"`):
|
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
|
||||||
identifier allowed by git.
|
|
||||||
custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub):
|
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
|
||||||
`revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a
|
|
||||||
custom pipeline from GitHub.
|
|
||||||
mirror (`str`, *optional*):
|
|
||||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
|
||||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
|
||||||
Please refer to the mirror site for more information. specify the folder name here.
|
|
||||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
|
||||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
|
||||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
|
||||||
same device.
|
|
||||||
|
|
||||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
|
||||||
more information about each option see [designing a device
|
|
||||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
|
||||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
|
||||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
|
||||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
|
||||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
|
||||||
setting this argument to `True` will raise an error.
|
|
||||||
return_cached_folder (`bool`, *optional*, defaults to `False`):
|
|
||||||
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
|
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
|
||||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
|
||||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
|
||||||
`__init__` method. See example below for more information.
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
|
||||||
models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
|
||||||
this method in a firewalled environment.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
```py
|
|
||||||
>>> from diffusers import DiffusionPipeline
|
|
||||||
|
|
||||||
>>> # Download pipeline from huggingface.co and cache.
|
|
||||||
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
|
||||||
|
|
||||||
>>> # Download pipeline that requires an authorization token
|
|
||||||
>>> # For more information on access tokens, please refer to this section
|
|
||||||
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
|
|
||||||
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
|
||||||
|
|
||||||
>>> # Use a different scheduler
|
|
||||||
>>> from diffusers import LMSDiscreteScheduler
|
|
||||||
|
|
||||||
>>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
|
|
||||||
>>> pipeline.scheduler = scheduler
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
|
||||||
force_download = kwargs.pop("force_download", False)
|
|
||||||
proxies = kwargs.pop("proxies", None)
|
|
||||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
|
||||||
revision = kwargs.pop("revision", None)
|
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
|
||||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
|
||||||
custom_revision = kwargs.pop("custom_revision", None)
|
|
||||||
provider = kwargs.pop("provider", None)
|
|
||||||
sess_options = kwargs.pop("sess_options", None)
|
|
||||||
device_map = kwargs.pop("device_map", None)
|
|
||||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
|
||||||
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
|
||||||
|
|
||||||
# 1. Download the checkpoints and configs
|
|
||||||
# use snapshot download here to get it working from from_pretrained
|
|
||||||
if not os.path.isdir(pretrained_model_name_or_path):
|
|
||||||
config_dict = cls.load_config(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
resume_download=resume_download,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
# make sure we only download sub-folders and `diffusers` filenames
|
|
||||||
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
|
|
||||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
|
||||||
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
|
||||||
|
|
||||||
# make sure we don't download flax weights
|
|
||||||
ignore_patterns = ["*.msgpack"]
|
|
||||||
|
|
||||||
if custom_pipeline is not None:
|
|
||||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
|
||||||
|
|
||||||
if cls != DiffusionPipeline:
|
|
||||||
requested_pipeline_class = cls.__name__
|
|
||||||
else:
|
|
||||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
|
||||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
|
||||||
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
|
|
||||||
user_agent["custom_pipeline"] = custom_pipeline
|
|
||||||
|
|
||||||
user_agent = http_user_agent(user_agent)
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
|
||||||
info = model_info(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
if is_safetensors_compatible(info):
|
|
||||||
ignore_patterns.append("*.bin")
|
|
||||||
else:
|
|
||||||
ignore_patterns.append("*.safetensors")
|
|
||||||
|
|
||||||
# download all allow_patterns
|
|
||||||
cached_folder = snapshot_download(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
resume_download=resume_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
revision=revision,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
ignore_patterns=ignore_patterns,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cached_folder = pretrained_model_name_or_path
|
|
||||||
|
|
||||||
config_dict = cls.load_config(cached_folder)
|
|
||||||
|
|
||||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
|
||||||
# if we load from explicit class, let's use it
|
|
||||||
if custom_pipeline is not None:
|
|
||||||
if custom_pipeline.endswith(".py"):
|
|
||||||
path = Path(custom_pipeline)
|
|
||||||
# decompose into folder & file
|
|
||||||
file_name = path.name
|
|
||||||
custom_pipeline = path.parent.absolute()
|
|
||||||
else:
|
|
||||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
|
||||||
|
|
||||||
pipeline_class = get_class_from_dynamic_module(
|
|
||||||
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
|
|
||||||
)
|
|
||||||
elif cls != DiffusionPipeline:
|
|
||||||
pipeline_class = cls
|
|
||||||
else:
|
|
||||||
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
|
||||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
|
||||||
|
|
||||||
# To be removed in 1.0.0
|
|
||||||
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
|
||||||
version.parse(config_dict["_diffusers_version"]).base_version
|
|
||||||
) <= version.parse("0.5.1"):
|
|
||||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
|
|
||||||
|
|
||||||
pipeline_class = StableDiffusionInpaintPipelineLegacy
|
|
||||||
|
|
||||||
deprecation_message = (
|
|
||||||
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
|
|
||||||
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
|
|
||||||
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
|
|
||||||
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
|
|
||||||
f" checkpoint {pretrained_model_name_or_path} to the format of"
|
|
||||||
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
|
|
||||||
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
|
|
||||||
)
|
|
||||||
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
|
||||||
|
|
||||||
# some modules can be passed directly to the init
|
|
||||||
# in this case they are already instantiated in `kwargs`
|
|
||||||
# extract them here
|
|
||||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
|
||||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
|
||||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
|
||||||
|
|
||||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
|
||||||
|
|
||||||
# define init kwargs
|
|
||||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
|
||||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
|
||||||
|
|
||||||
# remove `null` components
|
|
||||||
def load_module(name, value):
|
|
||||||
if value[0] is None:
|
|
||||||
return False
|
|
||||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
|
||||||
|
|
||||||
if len(unused_kwargs) > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage and not is_accelerate_available():
|
|
||||||
low_cpu_mem_usage = False
|
|
||||||
logger.warning(
|
|
||||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
|
||||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
|
||||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
|
||||||
" install accelerate\n```\n."
|
|
||||||
)
|
|
||||||
|
|
||||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `device_map=None`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `low_cpu_mem_usage=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is False and device_map is not None:
|
|
||||||
raise ValueError(
|
|
||||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
|
||||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# import it here to avoid circular import
|
|
||||||
from diffusers import pipelines
|
|
||||||
|
|
||||||
# 3. Load each module in the pipeline
|
|
||||||
for name, (library_name, class_name) in init_dict.items():
|
|
||||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
|
||||||
if class_name.startswith("Flax"):
|
|
||||||
class_name = class_name[4:]
|
|
||||||
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
|
||||||
loaded_sub_model = None
|
|
||||||
|
|
||||||
# if the model is in a pipeline module, then we load it from the pipeline
|
|
||||||
if name in passed_class_obj:
|
|
||||||
# 1. check that passed_class_obj has correct parent class
|
|
||||||
if not is_pipeline_module:
|
|
||||||
library = importlib.import_module(library_name)
|
|
||||||
class_obj = getattr(library, class_name)
|
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
|
||||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
|
||||||
|
|
||||||
expected_class_obj = None
|
|
||||||
for class_name, class_candidate in class_candidates.items():
|
|
||||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
|
||||||
expected_class_obj = class_candidate
|
|
||||||
|
|
||||||
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
|
||||||
raise ValueError(
|
|
||||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
|
||||||
f" {expected_class_obj}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
|
||||||
" has the correct type"
|
|
||||||
)
|
|
||||||
|
|
||||||
# set passed class object
|
|
||||||
loaded_sub_model = passed_class_obj[name]
|
|
||||||
elif is_pipeline_module:
|
|
||||||
pipeline_module = getattr(pipelines, library_name)
|
|
||||||
class_obj = getattr(pipeline_module, class_name)
|
|
||||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
|
||||||
                class_candidates = {c: class_obj for c in importable_classes.keys()}
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)

                class_obj = getattr(library, class_name)
                importable_classes = LOADABLE_CLASSES[library_name]
                class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

            if loaded_sub_model is None:
                load_method_name = None
                for class_name, class_candidate in class_candidates.items():
                    if class_candidate is not None and issubclass(class_obj, class_candidate):
                        load_method_name = importable_classes[class_name][1]

                if load_method_name is None:
                    none_module = class_obj.__module__
                    is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
                        TRANSFORMERS_DUMMY_MODULES_FOLDER
                    )
                    if is_dummy_path and "dummy" in none_module:
                        # call class_obj for nice error message of missing requirements
                        class_obj()

                    raise ValueError(
                        f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
                        f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
                    )

                load_method = getattr(class_obj, load_method_name)
                loading_kwargs = {}

                if issubclass(class_obj, torch.nn.Module):
                    loading_kwargs["torch_dtype"] = torch_dtype
                if issubclass(class_obj, diffusers.OnnxRuntimeModel):
                    loading_kwargs["provider"] = provider
                    loading_kwargs["sess_options"] = sess_options

                is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
                is_transformers_model = (
                    is_transformers_available()
                    and issubclass(class_obj, PreTrainedModel)
                    and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
                )

                # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
                # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
                # This makes sure that the weights won't be initialized which significantly speeds up loading.
                if is_diffusers_model or is_transformers_model:
                    loading_kwargs["device_map"] = device_map
                    loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

                # check if the module is in a subdirectory
                if os.path.isdir(os.path.join(cached_folder, name)):
                    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
                else:
                    # else load from the root directory
                    loaded_sub_model = load_method(cached_folder, **loading_kwargs)

            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

        # 4. Potentially add passed objects if expected
        missing_modules = set(expected_modules) - set(init_kwargs.keys())
        passed_modules = list(passed_class_obj.keys())
        optional_modules = pipeline_class._optional_components
        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
            for module in missing_modules:
                init_kwargs[module] = passed_class_obj.get(module, None)
        elif len(missing_modules) > 0:
            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
            )

        # 5. Instantiate the pipeline
        model = pipeline_class(**init_kwargs)

        if return_cached_folder:
            return model, cached_folder
        return model

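A minimal sketch of the `return_cached_folder` flag handled above (model id is illustrative; any Hub pipeline works):

```py
from diffusers import DiffusionPipeline

# When return_cached_folder=True, from_pretrained also hands back the local
# snapshot directory the weights were resolved to.
pipe, cached_folder = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", return_cached_folder=True
)
print(cached_folder)  # a path under the Hugging Face cache
```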
    @staticmethod
    def _get_signature_keys(obj):
        parameters = inspect.signature(obj.__init__).parameters
        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
        expected_modules = set(required_parameters.keys()) - set(["self"])
        return expected_modules, optional_parameters

    @property
    def components(self) -> Dict[str, Any]:
        r"""

        The `self.components` property can be useful to run different pipelines with the same weights and
        configurations to not have to re-allocate memory.

        Examples:

        ```py
        >>> from diffusers import (
        ...     StableDiffusionPipeline,
        ...     StableDiffusionImg2ImgPipeline,
        ...     StableDiffusionInpaintPipeline,
        ... )

        >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
        >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
        ```

        Returns:
            A dictionary containing all the modules needed to initialize the pipeline.
        """
        expected_modules, optional_parameters = self._get_signature_keys(self)
        components = {
            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
        }

        if set(components.keys()) != expected_modules:
            raise ValueError(
                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
                f" {expected_modules} to be defined, but {components} are defined."
            )

        return components

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

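A quick sketch of `numpy_to_pil` on a dummy batch; values are assumed to lie in [0, 1], as the pipelines produce:

```py
import numpy as np

from diffusers import DiffusionPipeline

batch = np.random.rand(2, 64, 64, 3)  # (batch, height, width, channels) in [0, 1]
pil_images = DiffusionPipeline.numpy_to_pil(batch)  # static method, no instance needed
print(len(pil_images), pil_images[0].size)  # 2 (64, 64)
```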
    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        if iterable is not None:
            return tqdm(iterable, **self._progress_bar_config)
        elif total is not None:
            return tqdm(total=total, **self._progress_bar_config)
        else:
            raise ValueError("Either `total` or `iterable` has to be defined.")

    def set_progress_bar_config(self, **kwargs):
        self._progress_bar_config = kwargs

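Because `progress_bar` forwards `self._progress_bar_config` straight to `tqdm`, any `tqdm` keyword can be set once per pipeline; a small sketch:

```py
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
pipe.set_progress_bar_config(disable=True)  # silence the denoising bar
# or, e.g., pipe.set_progress_bar_config(desc="denoising", leave=False)
image = pipe(num_inference_steps=25).images[0]
```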
    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        module_names, _, _ = self.extract_init_dict(dict(self.config))
        for module_name in module_names:
            module = getattr(self, module_name)
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)

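Usage is a one-liner on a loaded pipeline; a hedged sketch (requires the `xformers` package and a CUDA GPU):

```py
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# ... run inference ...
pipe.disable_xformers_memory_efficient_attention()
```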
    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        self.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def set_attention_slice(self, slice_size: Optional[int]):
        module_names, _, _ = self.extract_init_dict(dict(self.config))
        for module_name in module_names:
            module = getattr(self, module_name)
            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size)

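A minimal sketch of the slicing knobs described in the docstrings above:

```py
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_attention_slicing()       # "auto": attention computed in two steps
pipe.enable_attention_slicing("max")  # one slice at a time, lowest memory
pipe.disable_attention_slicing()      # back to a single step
```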
@@ -20,6 +20,7 @@ else:
     from .ddpm import DDPMPipeline
     from .latent_diffusion import LDMSuperResolutionPipeline
     from .latent_diffusion_uncond import LDMPipeline
+    from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
     from .pndm import PNDMPipeline
     from .repaint import RePaintPipeline
     from .score_sde_ve import ScoreSdeVePipeline

@@ -62,6 +63,14 @@ else:
     )
     from .vq_diffusion import VQDiffusionPipeline
 
+try:
+    if not is_onnx_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_onnx_objects import *  # noqa F403
+else:
+    from .onnx_utils import OnnxRuntimeModel
+
 try:
     if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
         raise OptionalDependencyNotAvailable()

@@ -84,6 +93,14 @@ except OptionalDependencyNotAvailable:
 else:
     from .stable_diffusion import StableDiffusionKDiffusionPipeline
 
+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_flax_objects import *  # noqa F403
+else:
+    from .pipeline_flax_utils import FlaxDiffusionPipeline
+
 try:
     if not (is_flax_available() and is_transformers_available()):
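A sketch of the import surface these hunks establish (paths as added above; the dummy fallback objects stand in when the optional backend is missing):

```py
from diffusers.pipelines import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
from diffusers.pipelines import OnnxRuntimeModel  # dummy object if onnxruntime is absent
```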
@@ -23,7 +23,6 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,

@@ -33,6 +32,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
 

@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,

@@ -35,6 +34,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
 
@@ -22,8 +22,8 @@ import torch
 from PIL import Image
 
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, DDPMScheduler
+from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
 from .mel import Mel
 
 

@@ -17,8 +17,8 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 
-from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from ...utils import logging
+from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -63,12 +63,11 @@ class DanceDiffusionPipeline(DiffusionPipeline):
                 The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
                 `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
 
         if audio_length_in_s is None:
@@ -16,8 +16,8 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...utils import deprecate
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class DDIMPipeline(DiffusionPipeline):

@@ -66,12 +66,11 @@ class DDIMPipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
 
         if (

@@ -18,8 +18,8 @@ from typing import List, Optional, Tuple, Union
 import torch
 
 from ...configuration_utils import FrozenDict
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...utils import deprecate
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class DDPMPipeline(DiffusionPipeline):

@@ -62,12 +62,11 @@ class DDPMPipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
         message = (
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
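A hedged sketch of the `return_dict` contract these docstrings describe (model id is illustrative):

```py
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
out = pipe(num_inference_steps=25)  # ImagePipelineOutput with an .images attribute
print(out.images[0])                # list of PIL images
(images,) = pipe(num_inference_steps=25, return_dict=False)  # plain one-element tuple
```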
@@ -19,16 +19,14 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 
+from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
-from transformers.modeling_utils import PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import logging
 
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class LDMTextToImagePipeline(DiffusionPipeline):

@@ -105,12 +103,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor

@@ -8,7 +8,6 @@ import torch.utils.checkpoint
 import PIL
 
 from ...models import UNet2DModel, VQModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,

@@ -18,6 +17,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import PIL_INTERPOLATION, deprecate
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 def preprocess(image):

@@ -95,12 +95,11 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
         message = "Please use `image` instead of `init_image`."
         init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)

@@ -18,8 +18,8 @@ from typing import List, Optional, Tuple, Union
 import torch
 
 from ...models import UNet2DModel, VQModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class LDMPipeline(DiffusionPipeline):

@@ -64,12 +64,11 @@ class LDMPipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
 
         latents = torch.randn(
@@ -24,7 +24,7 @@ import numpy as np
 
 from huggingface_hub import hf_hub_download
 
-from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
+from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
 
 
 if is_onnx_available():

@@ -23,9 +23,9 @@ from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor
 
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import logging
+from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .image_encoder import PaintByExampleImageEncoder

@@ -28,11 +28,10 @@ from huggingface_hub import snapshot_download
 from PIL import Image
 from tqdm.auto import tqdm
 
-from .configuration_utils import ConfigMixin
-from .hub_utils import http_user_agent
-from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
-from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
-from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
+from ..configuration_utils import ConfigMixin
+from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
+from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
+from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging
 
 
 if is_transformers_available():
@@ -0,0 +1,881 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

import diffusers
import PIL
from huggingface_hub import model_info, snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm

from ..configuration_utils import ConfigMixin
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
    ONNX_WEIGHTS_NAME,
    WEIGHTS_NAME,
    BaseOutput,
    deprecate,
    get_class_from_dynamic_module,
    http_user_agent,
    is_accelerate_available,
    is_safetensors_available,
    is_torch_version,
    is_transformers_available,
    logging,
)


if is_transformers_available():
    import transformers
    from transformers import PreTrainedModel


INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"


logger = logging.get_logger(__name__)


LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
    "onnxruntime.training": {
        "ORTModule": ["save_pretrained", "from_pretrained"],
    },
}

ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


@dataclass
class ImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]


@dataclass
class AudioPipelineOutput(BaseOutput):
    """
    Output class for audio pipelines.

    Args:
        audios (`np.ndarray`)
            List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the
            denoised audio samples of the diffusion pipeline.
    """

    audios: np.ndarray

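Both output classes derive from `BaseOutput`, so fields can be read by attribute, by key, or flattened to a tuple; a minimal sketch (instance constructed by hand, purely for illustration):

```py
import PIL.Image

from diffusers.pipelines import ImagePipelineOutput

out = ImagePipelineOutput(images=[PIL.Image.new("RGB", (8, 8))])
print(out.images)      # attribute access
print(out["images"])   # dict-style access
print(out.to_tuple())  # plain tuple, as returned when return_dict=False
```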
def is_safetensors_compatible(info) -> bool:
    filenames = set(sibling.rfilename for sibling in info.siblings)
    pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
    is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
    for pt_filename in pt_filenames:
        prefix, raw = os.path.split(pt_filename)
        if raw == "pytorch_model.bin":
            # transformers specific
            sf_filename = os.path.join(prefix, "model.safetensors")
        else:
            sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
        if is_safetensors_compatible and sf_filename not in filenames:
            logger.warning(f"{sf_filename} not found")
            is_safetensors_compatible = False
    return is_safetensors_compatible

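A sketch of the check above against a Hub repo (repo id illustrative; needs network access):

```py
from huggingface_hub import model_info

info = model_info("runwayml/stable-diffusion-v1-5")
# True only if every .bin weight file has a matching .safetensors sibling
print(is_safetensors_compatible(info))
```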
class DiffusionPipeline(ConfigMixin):
    r"""
    Base class for all models.

    [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
    and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:

        - move all PyTorch modules to the device of your choice
        - enabling/disabling the progress bar for the denoising iteration

    Class attributes:

        - **config_name** (`str`) -- name of the config file that will store the class and module names of all
          components of the diffusion pipeline.
        - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
          passed for the pipeline to function (should be overridden by subclasses).
    """
    config_name = "model_index.json"
    _optional_components = []

    def register_modules(self, **kwargs):
        # import it here to avoid circular import
        from diffusers import pipelines

        for name, module in kwargs.items():
            # retrieve library
            if module is None:
                register_dict = {name: (None, None)}
            else:
                library = module.__module__.split(".")[0]

                # check if the module is a pipeline module
                pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
                path = module.__module__.split(".")
                is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

                # if library is not in LOADABLE_CLASSES, then it is a custom module.
                # Or if it's a pipeline module, then the module is inside the pipeline
                # folder so we set the library to module name.
                if library not in LOADABLE_CLASSES or is_pipeline_module:
                    library = pipeline_dir

                # retrieve class_name
                class_name = module.__class__.__name__

                register_dict = {name: (library, class_name)}

            # save model index config
            self.register_to_config(**register_dict)

            # set models
            setattr(self, name, module)

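For context, subclasses are expected to call `register_modules` from their `__init__`; a minimal sketch with a hypothetical `MyPipeline`:

```py
from diffusers import DDPMScheduler, DiffusionPipeline, UNet2DModel


class MyPipeline(DiffusionPipeline):
    def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
        super().__init__()
        # records ("unet", ("diffusers", "UNet2DModel")) etc. in model_index.json
        # and sets self.unet / self.scheduler
        self.register_modules(unet=unet, scheduler=scheduler)
```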
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        safe_serialization: bool = False,
    ):
        """
        Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
        a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
        method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            safe_serialization (`bool`, *optional*, defaults to `False`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        """
        self.save_config(save_directory)

        model_index_dict = dict(self.config)
        model_index_dict.pop("_class_name")
        model_index_dict.pop("_diffusers_version")
        model_index_dict.pop("_module", None)

        expected_modules, optional_kwargs = self._get_signature_keys(self)

        def is_saveable_module(name, value):
            if name not in expected_modules:
                return False
            if name in self._optional_components and value[0] is None:
                return False
            return True

        model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}

        for pipeline_component_name in model_index_dict.keys():
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__

            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                library = importlib.import_module(library_name)
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class, None)
                    if class_candidate is not None and issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            save_method = getattr(sub_model, save_method_name)

            # Call the save method with the argument safe_serialization only if it's supported
            save_method_signature = inspect.signature(save_method)
            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
            if save_method_accept_safe:
                save_method(
                    os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
                )
            else:
                save_method(os.path.join(save_directory, pipeline_component_name))

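The save/load round trip this method enables, sketched with an illustrative local path:

```py
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
pipe.save_pretrained("./my-ddpm")  # one subfolder per component + model_index.json
restored = DDPMPipeline.from_pretrained("./my-ddpm")
```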
    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
        if torch_device is None:
            return self

        module_names, _, _ = self.extract_init_dict(dict(self.config))
        for name in module_names.keys():
            module = getattr(self, name)
            if isinstance(module, torch.nn.Module):
                if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
                    logger.warning(
                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
                        " is not recommended to move them to `cpu` as running them will fail. Please make"
                        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
                        " support for`float16` operations on this device in PyTorch. Please, remove the"
                        " `torch_dtype=torch.float16` argument, or use another device for inference."
                    )
                module.to(torch_device)
        return self

    @property
    def device(self) -> torch.device:
        r"""
        Returns:
            `torch.device`: The torch device on which the pipeline is located.
        """
        module_names, _, _ = self.extract_init_dict(dict(self.config))
        for name in module_names.keys():
            module = getattr(self, name)
            if isinstance(module, torch.nn.Module):
                return module.device
        return torch.device("cpu")

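A short usage sketch for `to` and the `device` property (assumes a CUDA device is available):

```py
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
pipe = pipe.to("cuda")  # moves every torch.nn.Module component
print(pipe.device)      # device of the first torch module found, e.g. cuda:0
```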
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.

        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
                      https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
                      `CompVis/ldm-text2im-large-256`.
                    - A path to a *directory* containing pipeline weights saved using
                      [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                will be automatically derived from the model's weights.
            custom_pipeline (`str`, *optional*):

                <Tip warning={true}>

                This is an experimental feature and is likely to change in the future.

                </Tip>

                Can be either:

                    - A string, the *repo id* of a custom pipeline hosted inside a model repo on
                      https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
                      like `hf-internal-testing/diffusers-dummy-pipeline`.

                      <Tip>

                      It is required that the model repo has a file, called `pipeline.py` that defines the custom
                      pipeline.

                      </Tip>

                    - A string, the *file name* of a community pipeline hosted on GitHub under
                      https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
                      match exactly the file name without `.py` located under the above link, *e.g.*
                      `clip_guided_stable_diffusion`.

                      <Tip>

                      Community pipelines are always loaded from the current `main` branch of GitHub.

                      </Tip>

                    - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.

                      <Tip>

                      It is required that the directory has a file, called `pipeline.py` that defines the custom
                      pipeline.

                      </Tip>

                For more information on how to load and create custom pipelines, please have a look at [Loading and
                Adding Custom
                Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)

            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a
                custom pipeline from GitHub.
            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information. specify the folder name here.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            return_cached_folder (`bool`, *optional*, defaults to `False`):
                If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
                specific pipeline class. The overwritten components are then directly passed to the pipelines
                `__init__` method. See example below for more information.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
        this method in a firewalled environment.

        </Tip>

        Examples:

        ```py
        >>> from diffusers import DiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        >>> # Download pipeline that requires an authorization token
        >>> # For more information on access tokens, please refer to this section
        >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
        >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

        >>> # Use a different scheduler
        >>> from diffusers import LMSDiscreteScheduler

        >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
        >>> pipeline.scheduler = scheduler
        ```
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        provider = kwargs.pop("provider", None)
        sess_options = kwargs.pop("sess_options", None)
        device_map = kwargs.pop("device_map", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        return_cached_folder = kwargs.pop("return_cached_folder", False)

        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
            config_dict = cls.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
            )
            # make sure we only download sub-folders and `diffusers` filenames
            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
            allow_patterns = [os.path.join(k, "*") for k in folder_names]
            allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]

            # make sure we don't download flax weights
            ignore_patterns = ["*.msgpack"]

            if custom_pipeline is not None:
                allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]

            if cls != DiffusionPipeline:
                requested_pipeline_class = cls.__name__
            else:
                requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
            user_agent = {"pipeline_class": requested_pipeline_class}
            if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
                user_agent["custom_pipeline"] = custom_pipeline

            user_agent = http_user_agent(user_agent)

            if is_safetensors_available():
                info = model_info(
                    pretrained_model_name_or_path,
                    use_auth_token=use_auth_token,
                    revision=revision,
                )
                if is_safetensors_compatible(info):
                    ignore_patterns.append("*.bin")
                else:
                    ignore_patterns.append("*.safetensors")

            # download all allow_patterns
            cached_folder = snapshot_download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                user_agent=user_agent,
            )
        else:
            cached_folder = pretrained_model_name_or_path

        config_dict = cls.load_config(cached_folder)

        # 2. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
        if custom_pipeline is not None:
            if custom_pipeline.endswith(".py"):
                path = Path(custom_pipeline)
                # decompose into folder & file
                file_name = path.name
                custom_pipeline = path.parent.absolute()
            else:
                file_name = CUSTOM_PIPELINE_FILE_NAME

            pipeline_class = get_class_from_dynamic_module(
                custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
            )
        elif cls != DiffusionPipeline:
            pipeline_class = cls
        else:
            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

        # To be removed in 1.0.0
        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
            version.parse(config_dict["_diffusers_version"]).base_version
        ) <= version.parse("0.5.1"):
            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

            pipeline_class = StableDiffusionInpaintPipelineLegacy

            deprecation_message = (
                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
                f" checkpoint {pretrained_model_name_or_path} to the format of"
                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
                " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
            )
            deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)

        # some modules can be passed directly to the init
        # in this case they are already instantiated in `kwargs`
        # extract them here
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        # define init kwargs
        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

        # remove `null` components
        def load_module(name, value):
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            return True

        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        if len(unused_kwargs) > 0:
            logger.warning(
                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
            )

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        # import it here to avoid circular import
        from diffusers import pipelines

        # 3. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
            if class_name.startswith("Flax"):
                class_name = class_name[4:]

            is_pipeline_module = hasattr(pipelines, library_name)
            loaded_sub_model = None

            # if the model is in a pipeline module, then we load it from the pipeline
            if name in passed_class_obj:
                # 1. check that passed_class_obj has correct parent class
                if not is_pipeline_module:
                    library = importlib.import_module(library_name)
                    class_obj = getattr(library, class_name)
                    importable_classes = LOADABLE_CLASSES[library_name]
                    class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

                    expected_class_obj = None
                    for class_name, class_candidate in class_candidates.items():
                        if class_candidate is not None and issubclass(class_obj, class_candidate):
                            expected_class_obj = class_candidate

                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
                        raise ValueError(
                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                            f" {expected_class_obj}"
                        )
                else:
                    logger.warning(
                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
                        " has the correct type"
                    )

                # set passed class object
                loaded_sub_model = passed_class_obj[name]
            elif is_pipeline_module:
                pipeline_module = getattr(pipelines, library_name)
                class_obj = getattr(pipeline_module, class_name)
                importable_classes = ALL_IMPORTABLE_CLASSES
                class_candidates = {c: class_obj for c in importable_classes.keys()}
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)

                class_obj = getattr(library, class_name)
                importable_classes = LOADABLE_CLASSES[library_name]
                class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

            if loaded_sub_model is None:
                load_method_name = None
|
for class_name, class_candidate in class_candidates.items():
|
||||||
|
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||||
|
load_method_name = importable_classes[class_name][1]
|
||||||
|
|
||||||
|
if load_method_name is None:
|
||||||
|
none_module = class_obj.__module__
|
||||||
|
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
||||||
|
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
||||||
|
)
|
||||||
|
if is_dummy_path and "dummy" in none_module:
|
||||||
|
# call class_obj for nice error message of missing requirements
|
||||||
|
class_obj()
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
||||||
|
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
||||||
|
)
|
||||||
|
|
||||||
|
load_method = getattr(class_obj, load_method_name)
|
||||||
|
loading_kwargs = {}
|
||||||
|
|
||||||
|
if issubclass(class_obj, torch.nn.Module):
|
||||||
|
loading_kwargs["torch_dtype"] = torch_dtype
|
||||||
|
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
|
||||||
|
loading_kwargs["provider"] = provider
|
||||||
|
loading_kwargs["sess_options"] = sess_options
|
||||||
|
|
||||||
|
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
|
||||||
|
is_transformers_model = (
|
||||||
|
is_transformers_available()
|
||||||
|
and issubclass(class_obj, PreTrainedModel)
|
||||||
|
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
||||||
|
)
|
||||||
|
|
||||||
|
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||||
|
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||||
|
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||||
|
if is_diffusers_model or is_transformers_model:
|
||||||
|
loading_kwargs["device_map"] = device_map
|
||||||
|
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||||
|
|
||||||
|
# check if the module is in a subdirectory
|
||||||
|
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||||
|
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
||||||
|
else:
|
||||||
|
# else load from the root directory
|
||||||
|
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||||
|
|
||||||
|
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||||
|
|
||||||
|
# 4. Potentially add passed objects if expected
|
||||||
|
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||||
|
passed_modules = list(passed_class_obj.keys())
|
||||||
|
optional_modules = pipeline_class._optional_components
|
||||||
|
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||||
|
for module in missing_modules:
|
||||||
|
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||||
|
elif len(missing_modules) > 0:
|
||||||
|
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||||
|
raise ValueError(
|
||||||
|
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Instantiate the pipeline
|
||||||
|
model = pipeline_class(**init_kwargs)
|
||||||
|
|
||||||
|
if return_cached_folder:
|
||||||
|
return model, cached_folder
|
||||||
|
return model
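    # Illustrative usage (assumes the public "runwayml/stable-diffusion-v1-5" checkpoint):
    # already-instantiated components passed as keyword arguments are routed through
    # `passed_class_obj` above and used as-is instead of being loaded from the checkpoint:
    #
    #     from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
    #
    #     scheduler = DPMSolverMultistepScheduler.from_pretrained(
    #         "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
    #     )
    #     pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)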

    @staticmethod
    def _get_signature_keys(obj):
        parameters = inspect.signature(obj.__init__).parameters
        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
        expected_modules = set(required_parameters.keys()) - set(["self"])
        return expected_modules, optional_parameters
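    # Illustrative sketch of what `_get_signature_keys` computes: for an `__init__` such as
    # `def __init__(self, unet, scheduler, feature_extractor=None)` it returns
    # `({"unet", "scheduler"}, {"feature_extractor"})`: the required parameters minus
    # `self`, and the set of parameters that carry defaults.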

    @property
    def components(self) -> Dict[str, Any]:
        r"""

        The `self.components` property can be useful to run different pipelines with the same weights and
        configurations to not have to re-allocate memory.

        Examples:

        ```py
        >>> from diffusers import (
        ...     StableDiffusionPipeline,
        ...     StableDiffusionImg2ImgPipeline,
        ...     StableDiffusionInpaintPipeline,
        ... )

        >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
        >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
        ```

        Returns:
            A dictionary containing all the modules needed to initialize the pipeline.
        """
        expected_modules, optional_parameters = self._get_signature_keys(self)
        components = {
            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
        }

        if set(components.keys()) != expected_modules:
            raise ValueError(
                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
                f" {expected_modules} to be defined, but {components} are defined."
            )

        return components

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images
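    # Illustrative pairing (hypothetical `pipe` and `prompt`): `output_type="np"` makes a
    # pipeline return float arrays in [0, 1], which this helper converts back to PIL images:
    #
    #     images = pipe(prompt, output_type="np").images
    #     pil_images = pipe.numpy_to_pil(images)
    #     pil_images[0].save("sample.png")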

    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        if iterable is not None:
            return tqdm(iterable, **self._progress_bar_config)
        elif total is not None:
            return tqdm(total=total, **self._progress_bar_config)
        else:
            raise ValueError("Either `total` or `iterable` has to be defined.")

    def set_progress_bar_config(self, **kwargs):
        self._progress_bar_config = kwargs
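    # Illustrative usage: any keyword that `tqdm` accepts can be stored here and is then
    # applied to every subsequent progress bar, e.g.:
    #
    #     pipe.set_progress_bar_config(disable=True)                  # silence the bar
    #     pipe.set_progress_bar_config(desc="denoising", leave=False)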

    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        module_names, _, _ = self.extract_init_dict(dict(self.config))
        for module_name in module_names:
            module = getattr(self, module_name)
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)
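    # Illustrative usage (requires the optional `xformers` package to be installed):
    #
    #     pipe.enable_xformers_memory_efficient_attention()
    #     image = pipe(prompt).images[0]  # hypothetical call; lower memory use at inference
    #     pipe.disable_xformers_memory_efficient_attention()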

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number
                is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        self.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
        go back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def set_attention_slice(self, slice_size: Optional[int]):
        module_names, _, _ = self.extract_init_dict(dict(self.config))
        for module_name in module_names:
            module = getattr(self, module_name)
            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size)
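    # Illustrative usage: slicing trades a little speed for memory, which helps on small
    # GPUs; `"auto"` computes attention in two steps, `None` restores single-step attention:
    #
    #     pipe.enable_attention_slicing()    # slice_size="auto"
    #     pipe.enable_attention_slicing(1)   # most memory-frugal setting
    #     pipe.disable_attention_slicing()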
@@ -18,8 +18,8 @@ from typing import List, Optional, Tuple, Union
 import torch
 
 from ...models import UNet2DModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import PNDMScheduler
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class PNDMPipeline(DiffusionPipeline):
@@ -62,12 +62,11 @@ class PNDMPipeline(DiffusionPipeline):
             output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose
                 between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a
-                [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
 
         """
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
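The docstring hunks in this and the following files all make the same rename: the `ImagePipelineOutput` cross-reference now lives under `pipelines` rather than the removed top-level `pipeline_utils` module, while the documented behavior is unchanged. As a rough sketch of the two return modes (assuming a hub checkpoint such as `google/ddpm-cifar10-32`, which loads as an image pipeline):

```py
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32")

# return_dict=True (the default) yields an ImagePipelineOutput with an `images` attribute
output = pipeline(num_inference_steps=50)
images = output.images

# return_dict=False yields a plain tuple; its first element is the list of generated images
(images,) = pipeline(num_inference_steps=50, return_dict=False)
```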
@@ -21,9 +21,9 @@ import torch
 import PIL
 
 from ...models import UNet2DModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import RePaintScheduler
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -118,12 +118,11 @@ class RePaintPipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
 
         """
 
         message = "Please use `image` instead of `original_image`."

@@ -17,8 +17,8 @@ from typing import List, Optional, Tuple, Union
 import torch
 
 from ...models import UNet2DModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import ScoreSdeVeScheduler
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class ScoreSdeVePipeline(DiffusionPipeline):
@@ -57,12 +57,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
 
         """
 
         img_size = self.unet.config.sample_size

@@ -25,9 +25,9 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker

@@ -28,7 +28,6 @@ from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
 
 from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
-from ...pipeline_flax_utils import FlaxDiffusionPipeline
 from ...schedulers import (
     FlaxDDIMScheduler,
     FlaxDPMSolverMultistepScheduler,
@@ -36,6 +35,7 @@ from ...schedulers import (
     FlaxPNDMScheduler,
 )
 from ...utils import deprecate, logging
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
 from . import FlaxStableDiffusionPipelineOutput
 from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

@@ -27,7 +27,6 @@ from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
 
 from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
-from ...pipeline_flax_utils import FlaxDiffusionPipeline
 from ...schedulers import (
     FlaxDDIMScheduler,
     FlaxDPMSolverMultistepScheduler,
@@ -35,6 +34,7 @@ from ...schedulers import (
     FlaxPNDMScheduler,
 )
 from ...utils import PIL_INTERPOLATION, logging
+from ..pipeline_flax_utils import FlaxDiffusionPipeline
 from . import FlaxStableDiffusionPipelineOutput
 from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

@@ -21,10 +21,10 @@ import torch
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput

@@ -22,10 +22,10 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput

@@ -22,10 +22,10 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput

@@ -8,10 +8,10 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
+from ..onnx_utils import OnnxRuntimeModel
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput

@@ -23,7 +23,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
@@ -33,6 +32,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker

@@ -26,7 +26,6 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTF
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
@@ -36,6 +35,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -24,7 +24,6 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
@@ -34,6 +33,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker

@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
@@ -35,6 +34,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker

@@ -25,9 +25,9 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker

@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
@@ -35,6 +34,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
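The import hunks in this stretch are one mechanical change repeated across the Stable Diffusion, ONNX, and Flax pipeline files: `pipeline_utils`, `pipeline_flax_utils`, and `onnx_utils` moved from the package root into the `pipelines` subpackage, so a module sitting two levels deep under `pipelines/` now reaches them with two leading dots instead of three. A sketch of how the relative imports resolve (the file path is illustrative):

```py
# as seen from e.g. src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
from ...pipeline_utils import DiffusionPipeline  # old: resolves to diffusers.pipeline_utils
from ..pipeline_utils import DiffusionPipeline   # new: resolves to diffusers.pipelines.pipeline_utils
```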
@@ -19,7 +19,7 @@ import torch
 
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
 
-from ... import DiffusionPipeline
+from ...pipelines import DiffusionPipeline
 from ...schedulers import LMSDiscreteScheduler
 from ...utils import is_accelerate_available, logging
 from . import StableDiffusionPipelineOutput

@@ -23,9 +23,9 @@ from diffusers.utils import is_accelerate_available
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -10,7 +10,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import (
     DDIMScheduler,
     DPMSolverMultistepScheduler,
@@ -20,6 +19,7 @@ from ...schedulers import (
     PNDMScheduler,
 )
 from ...utils import deprecate, is_accelerate_available, logging
+from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionSafePipelineOutput
 from .safety_checker import SafeStableDiffusionSafetyChecker

@@ -17,8 +17,8 @@ from typing import List, Optional, Tuple, Union
 import torch
 
 from ...models import UNet2DModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import KarrasVeScheduler
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class KarrasVePipeline(DiffusionPipeline):
@@ -68,12 +68,11 @@ class KarrasVePipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+            True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
 
         """
 
         img_size = self.unet.config.sample_size
@@ -18,11 +18,11 @@ from typing import List, Optional, Union
 import torch
 from torch.nn import functional as F
 
-from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel
-from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from diffusers.schedulers import UnCLIPScheduler
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
 
+from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
+from ...pipelines import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import UnCLIPScheduler
 from ...utils import is_accelerate_available, logging
 from .text_proj import UnCLIPTextProjModel
@@ -291,7 +291,7 @@ class UnCLIPPipeline(DiffusionPipeline):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
         if isinstance(prompt, str):
             batch_size = 1

@@ -19,9 +19,6 @@ import torch
 from torch.nn import functional as F
 
 import PIL
-from diffusers import UNet2DConditionModel, UNet2DModel
-from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from diffusers.schedulers import UnCLIPScheduler
 from transformers import (
     CLIPFeatureExtractor,
     CLIPTextModelWithProjection,
@@ -29,6 +26,9 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )
 
+from ...models import UNet2DConditionModel, UNet2DModel
+from ...pipelines import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import UnCLIPScheduler
 from ...utils import is_accelerate_available, logging
 from .text_proj import UnCLIPTextProjModel
@@ -303,7 +303,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
         """
         if isinstance(image, PIL.Image.Image):
             batch_size = 1

@@ -15,9 +15,8 @@
 import torch
 from torch import nn
 
-from diffusers.modeling_utils import ModelMixin
-
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...models import ModelMixin
 
 
 class UnCLIPTextProjModel(ModelMixin, ConfigMixin):

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...modeling_utils import ModelMixin
+from ...models import ModelMixin
 from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
 from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
 from ...models.embeddings import TimestepEmbedding, Timesteps
@@ -7,9 +7,9 @@ import PIL.Image
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import logging
+from ..pipeline_utils import DiffusionPipeline
 from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
 from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
 from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline

@@ -27,11 +27,10 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )
 
-from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention import DualTransformer2DModel, Transformer2DModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import is_accelerate_available, logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel

@@ -23,9 +23,9 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
 
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import is_accelerate_available, logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -20,11 +20,10 @@ import torch.utils.checkpoint
 
 from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention import Transformer2DModel
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import is_accelerate_available, logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel

@@ -16,14 +16,13 @@ from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 
-from diffusers import Transformer2DModel, VQModel
-from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...modeling_utils import ModelMixin
-from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...models import ModelMixin, Transformer2DModel, VQModel
+from ...schedulers import VQDiffusionScheduler
 from ...utils import logging
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -212,7 +211,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -221,9 +220,8 @@ class VQDiffusionPipeline(DiffusionPipeline):
                 called at every step.
 
         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if
-            `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
-            generated images.
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict`
+            is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
         if isinstance(prompt, str):
            batch_size = 1
@@ -18,7 +18,22 @@ import os
 from packaging import version
 
 from .. import __version__
+from .constants import (
+    _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    DIFFUSERS_DYNAMIC_MODULE_NAME,
+    FLAX_WEIGHTS_NAME,
+    HF_MODULES_CACHE,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    ONNX_EXTERNAL_WEIGHTS_NAME,
+    ONNX_WEIGHTS_NAME,
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+)
 from .deprecation_utils import deprecate
+from .dynamic_modules_utils import get_class_from_dynamic_module
+from .hub_utils import HF_HUB_OFFLINE, http_user_agent
 from .import_utils import (
     ENV_VARS_TRUE_AND_AUTO_VALUES,
     ENV_VARS_TRUE_VALUES,
@@ -67,36 +82,6 @@ if is_torch_available():
 logger = get_logger(__name__)
 
 
-hf_cache_home = os.path.expanduser(
-    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
-)
-default_cache_path = os.path.join(hf_cache_home, "diffusers")
-
-
-CONFIG_NAME = "config.json"
-WEIGHTS_NAME = "diffusion_pytorch_model.bin"
-FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
-ONNX_WEIGHTS_NAME = "model.onnx"
-SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
-ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
-HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
-DIFFUSERS_CACHE = default_cache_path
-DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
-HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
-
-_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
-    "DDIMScheduler",
-    "DDPMScheduler",
-    "PNDMScheduler",
-    "LMSDiscreteScheduler",
-    "EulerDiscreteScheduler",
-    "HeunDiscreteScheduler",
-    "EulerAncestralDiscreteScheduler",
-    "DPMSolverMultistepScheduler",
-    "DPMSolverSinglestepScheduler",
-]
-
-
 def check_min_version(min_version):
     if version.parse(__version__) < version.parse(min_version):
         if "dev" in min_version:
@@ -0,0 +1,44 @@
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+
+hf_cache_home = os.path.expanduser(
+    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "diffusers")
+
+
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
+ONNX_WEIGHTS_NAME = "model.onnx"
+SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
+_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
+    "DDIMScheduler",
+    "DDPMScheduler",
+    "PNDMScheduler",
+    "LMSDiscreteScheduler",
+    "EulerDiscreteScheduler",
+    "HeunDiscreteScheduler",
+    "EulerAncestralDiscreteScheduler",
+    "DPMSolverMultistepScheduler",
+    "DPMSolverSinglestepScheduler",
+]
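Because `utils/__init__.py` now re-exports everything from this new `constants` module (the `from .constants import (...)` hunk above), existing `from diffusers.utils import ...` call sites keep working unchanged. A quick sketch using the values defined above:

```py
# both import paths resolve to the same objects after this commit
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
from diffusers.utils.constants import CONFIG_NAME as DIRECT_CONFIG_NAME

assert CONFIG_NAME is DIRECT_CONFIG_NAME
assert WEIGHTS_NAME == "diffusion_pytorch_model.bin"
```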
@@ -0,0 +1,19 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class OnnxRuntimeModel(metaclass=DummyObject):
+    _backends = ["onnx"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["onnx"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["onnx"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["onnx"])
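The dummy module mirrors the real `OnnxRuntimeModel` so the import itself never fails when the `onnx` extra is missing; the failure is deferred to first use, where `requires_backends` raises with an installation hint. A hedged sketch of that behavior (the exact message wording comes from `requires_backends` and may differ):

```py
from diffusers import OnnxRuntimeModel  # succeeds even without onnxruntime installed

try:
    OnnxRuntimeModel()  # the dummy __init__ calls requires_backends(self, ["onnx"])
except ImportError as err:
    print(err)  # explains that the onnx backend needs to be installed
```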
@@ -4,7 +4,7 @@
 from ..utils import DummyObject, requires_backends


-class ModelMixin(metaclass=DummyObject):
+class AutoencoderKL(metaclass=DummyObject):
     _backends = ["torch"]

     def __init__(self, *args, **kwargs):
@@ -19,7 +19,7 @@ class ModelMixin(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


-class AutoencoderKL(metaclass=DummyObject):
+class ModelMixin(metaclass=DummyObject):
     _backends = ["torch"]

     def __init__(self, *args, **kwargs):
@@ -152,7 +152,7 @@ def get_scheduler(*args, **kwargs):
     requires_backends(get_scheduler, ["torch"])


-class DiffusionPipeline(metaclass=DummyObject):
+class AudioPipelineOutput(metaclass=DummyObject):
     _backends = ["torch"]

     def __init__(self, *args, **kwargs):
@@ -212,6 +212,36 @@ class DDPMPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class DiffusionPipeline(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
+class ImagePipelineOutput(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class KarrasVePipeline(metaclass=DummyObject):
     _backends = ["torch"]
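With `DiffusionPipeline` and `ImagePipelineOutput` now covered by torch dummies, importing them without PyTorch installed fails at use time rather than import time. A hedged sketch of the resulting behaviour (the error text is illustrative):

```python
# Hedged sketch: exercising a torch dummy object directly. In an install without
# torch, `from diffusers import ImagePipelineOutput` resolves to this dummy class.
from diffusers.utils.dummy_pt_objects import ImagePipelineOutput

try:
    ImagePipelineOutput()
except ImportError as err:
    # The message names the missing "torch" backend; exact wording is illustrative.
    print(err)
```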
@@ -28,8 +28,8 @@ from urllib import request

 from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info

-from . import __version__
-from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+from .. import __version__
+from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging


 COMMUNITY_PIPELINES_URL = (
@@ -172,7 +172,7 @@ def find_pipeline_class(loaded_module):
     Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class
     inheriting from `DiffusionPipeline`.
     """
-    from .pipeline_utils import DiffusionPipeline
+    from ..pipelines import DiffusionPipeline

     cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
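`find_pipeline_class` now pulls `DiffusionPipeline` from the public `pipelines` package. A rough sketch of what the helper does, under the assumption that the loaded module defines exactly one pipeline subclass:

```python
# Hedged sketch of find_pipeline_class: scan a dynamically loaded module for the
# single class that subclasses DiffusionPipeline (but is not the base class itself).
import inspect

from diffusers import DiffusionPipeline


def find_pipeline_class(loaded_module):
    cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass))
    candidates = [
        cls
        for cls in cls_members.values()
        if issubclass(cls, DiffusionPipeline) and cls is not DiffusionPipeline
    ]
    if len(candidates) != 1:
        raise ValueError(f"Expected exactly one pipeline class, found {len(candidates)}")
    return candidates[0]
```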
@@ -22,9 +22,10 @@ from uuid import uuid4

 from huggingface_hub import HfFolder, whoami

-from . import __version__
-from .utils import ENV_VARS_TRUE_VALUES, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
-from .utils.import_utils import (
+from .. import __version__
+from .constants import HUGGINGFACE_CO_RESOLVE_ENDPOINT
+from .import_utils import (
+    ENV_VARS_TRUE_VALUES,
     _flax_version,
     _jax_version,
     _onnxruntime_version,
@@ -34,13 +35,14 @@ from .utils.import_utils import (
     is_onnx_available,
     is_torch_available,
 )
+from .logging import get_logger


 if is_modelcards_available():
     from modelcards import CardData, ModelCard


-logger = logging.get_logger(__name__)
+logger = get_logger(__name__)


 MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
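The hub utilities now import from sibling modules inside `utils` instead of reaching back up through the package root, and the logger is fetched directly. A minimal sketch of the new logger usage:

```python
# Hedged sketch of the new logger usage: get_logger comes straight from
# diffusers.utils.logging rather than via a module-level `logging` alias.
from diffusers.utils.logging import get_logger

logger = get_logger(__name__)
logger.info("hub utilities initialized")  # illustrative message
```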
@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union

 import torch

-from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers import DiffusionPipeline, ImagePipelineOutput


 class CustomLocalPipeline(DiffusionPipeline):
@@ -63,10 +63,10 @@ class CustomLocalPipeline(DiffusionPipeline):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

         Returns:
-            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
             `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
             generated images.
         """
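The docstring now points at the relocated `ImagePipelineOutput`. For reference, a hedged sketch of how the two return modes differ when calling such a pipeline; the repo id and local path below are placeholders, not values from this commit:

```python
# Hedged sketch: loading a model together with a local custom pipeline file and
# comparing the two return modes described in the updated docstring.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "your-org/your-model", custom_pipeline="./custom_local_pipeline"  # placeholders
)
output = pipe()                     # ImagePipelineOutput; access images via output.images
as_tuple = pipe(return_dict=False)  # plain tuple; the first element is the list of images
```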
@@ -19,7 +19,7 @@ import unittest
 import torch

 from diffusers import AutoencoderKL
-from diffusers.modeling_utils import ModelMixin
+from diffusers.models import ModelMixin
 from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
 from parameterized import parameterized
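`ModelMixin` is now imported from `diffusers.models` throughout the tests. A quick before/after sketch:

```python
# Hedged sketch of the import move in the tests.
from diffusers.models import ModelMixin  # new location used after this commit

# from diffusers.modeling_utils import ModelMixin  # old path used before this commit

# Class relationships are unchanged by the move, e.g.:
from diffusers import AutoencoderKL

assert issubclass(AutoencoderKL, ModelMixin)
```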
@@ -20,8 +20,13 @@ import unittest
 import numpy as np
 import torch

-from diffusers import UnCLIPImageVariationPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
-from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers import (
+    DiffusionPipeline,
+    UnCLIPImageVariationPipeline,
+    UnCLIPScheduler,
+    UNet2DConditionModel,
+    UNet2DModel,
+)
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import load_image, require_torch_gpu
@@ -21,7 +21,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import torch

-from diffusers.modeling_utils import ModelMixin
+from diffusers.models import ModelMixin
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
@@ -33,6 +33,7 @@ from diffusers import (
     DDIMScheduler,
     DDPMPipeline,
     DDPMScheduler,
+    DiffusionPipeline,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
@@ -45,7 +46,6 @@ from diffusers import (
     UNet2DModel,
     logging,
 )
-from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
@@ -704,7 +704,7 @@ class PipelineSlowTests(unittest.TestCase):

     def test_warning_unused_kwargs(self):
         model_id = "hf-internal-testing/unet-pipeline-dummy"
-        logger = logging.get_logger("diffusers.pipeline_utils")
+        logger = logging.get_logger("diffusers.pipelines")
         with tempfile.TemporaryDirectory() as tmpdirname:
             with CaptureLogger(logger) as cap_logger:
                 DiffusionPipeline.from_pretrained(
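The unused-kwargs warning is now emitted under the `diffusers.pipelines` logger, which is what the updated test captures. A hedged sketch of the capture pattern (the warning text is illustrative):

```python
# Hedged sketch of capturing messages from the renamed pipeline logger.
from diffusers import logging
from diffusers.utils.testing_utils import CaptureLogger

logger = logging.get_logger("diffusers.pipelines")
with CaptureLogger(logger) as cap_logger:
    logger.warning("kwargs were not used")  # illustrative message
assert "not used" in cap_logger.out
```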