from typing import Any, Callable, Dict, List, Optional, Union

import torch

import PIL.Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class StableDiffusionMegaPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation, image-to-image generation, and inpainting using Stable Diffusion, exposed
    as the `text2img`, `img2img`, and `inpaint` methods.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure"
                " to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the"
                " `scheduler/scheduler_config.json` file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @property
    def components(self) -> Dict[str, Any]:
        # Expose the registered modules so they can be handed to the
        # single-task pipelines below without reloading any weights.
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
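
    # A sketch of how `components` can be reused to build a single-task
    # pipeline that shares this pipeline's weights (the `mega` variable is
    # illustrative):
    #
    #     text2img_only = StableDiffusionPipeline(**mega.components)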

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
        go back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)
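
    # A minimal sketch of the memory/speed trade-off (the `pipe` variable is
    # illustrative):
    #
    #     pipe.enable_attention_slicing()   # lower peak memory, slightly slower
    #     image = pipe.text2img("a photo of a cat").images[0]
    #     pipe.disable_attention_slicing()  # back to single-step attention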

    @torch.no_grad()
    def inpaint(
        self,
        prompt: Union[str, List[str]],
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
        return StableDiffusionInpaintPipelineLegacy(**self.components)(
            prompt=prompt,
            init_image=init_image,
            mask_image=mask_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
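
    # A minimal inpainting sketch (file names are illustrative; by convention
    # the mask is white where the image should be repainted):
    #
    #     init = PIL.Image.open("dog.png").convert("RGB")
    #     mask = PIL.Image.open("dog_mask.png").convert("RGB")
    #     result = pipe.inpaint(
    #         "a cat sitting on a bench", init_image=init, mask_image=mask, strength=0.75
    #     )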

    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
        return StableDiffusionImg2ImgPipeline(**self.components)(
            prompt=prompt,
            init_image=init_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
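
    # A minimal image-to-image sketch (file name and prompt are illustrative;
    # lower `strength` keeps more of the original image):
    #
    #     init = PIL.Image.open("sketch.png").convert("RGB").resize((512, 512))
    #     image = pipe.img2img(
    #         "A fantasy landscape, trending on artstation", init_image=init, strength=0.6
    #     ).images[0]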

    @torch.no_grad()
    def text2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
        return StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
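

# A minimal text-to-image sketch (prompt and seed are illustrative; passing a
# seeded `generator` makes the sampling reproducible):
#
#     generator = torch.Generator(device="cuda").manual_seed(0)
#     image = pipe.text2img(
#         "An astronaut riding a horse", generator=generator
#     ).images[0]
#     image.save("astronaut.png")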