# 2022-10-21 09:43:19 -06:00
import inspect
import os
import random
import re
from dataclasses import dataclass
from typing import Callable , Dict , List , Optional , Union
import torch
from diffusers . configuration_utils import FrozenDict
from diffusers . models import AutoencoderKL , UNet2DConditionModel
from diffusers . pipeline_utils import DiffusionPipeline
from diffusers . pipelines . stable_diffusion . pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers . pipelines . stable_diffusion . safety_checker import StableDiffusionSafetyChecker
from diffusers . schedulers import DDIMScheduler , LMSDiscreteScheduler , PNDMScheduler
from diffusers . utils import deprecate , logging
from transformers import CLIPFeatureExtractor , CLIPTextModel , CLIPTokenizer
# Module logger, obtained through diffusers' logging wrapper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# A wildcard is written as ``__name__``; capture group 1 is the bare name.
# Names cannot contain underscores, and ``____`` matches an empty name.
_WILDCARD_PATTERN = r"__([^_]*)__"
global_re_wildcard = re.compile(_WILDCARD_PATTERN)
def get_filename(path: str):
    """Return the wildcard name for a wildcard file path.

    The name is the file's base name with a trailing ``.txt`` removed, e.g.
    ``"wildcards/animal.txt"`` -> ``"animal"``. Only a *trailing* ``.txt`` is
    stripped, so ``"my.txt_notes.txt"`` -> ``"my.txt_notes"``; a base name that
    does not end in ``.txt`` is returned unchanged.
    """
    # os.path.basename splits on the running platform's separators only: on
    # Windows both "\\" and "/" are handled, while on POSIX a Windows-style
    # path such as "C:\\w\\animal.txt" is not split.
    name = os.path.basename(path)
    suffix = ".txt"
    if name.endswith(suffix):
        name = name[: -len(suffix)]
    return name
def read_wildcard_values(path: str):
    """Read a wildcard file and return its lines as replacement options.

    The file is decoded as UTF-8 and split with ``str.splitlines``, so line
    terminators are dropped and any blank line inside the file becomes an
    empty-string option.
    """
    with open(path, "r", encoding="utf8") as wildcard_file:
        contents = wildcard_file.read()
    return contents.splitlines()
def grab_wildcard_values(
    wildcard_option_dict: Optional[Dict[str, List[str]]] = None,
    wildcard_files: Optional[List[str]] = None,
):
    """Merge inline wildcard options with the options read from wildcard files.

    Args:
        wildcard_option_dict: Mapping from wildcard name to candidate values.
            Not modified.
        wildcard_files: Paths of wildcard files. Each file contributes its
            lines under the name returned by ``get_filename`` (``animal.txt``
            -> ``animal``). Options are appended after any inline options for
            the same name.

    Returns:
        A new dict with newly created option lists. Mutating it does not affect
        the arguments or later calls.
    """
    # Copy the mapping *and* its lists: extending a caller-owned list in place
    # would make options accumulate on every call (the pipeline calls this
    # once per prompt sample).
    merged = {name: list(values) for name, values in (wildcard_option_dict or {}).items()}
    for wildcard_file in wildcard_files or []:
        name = get_filename(wildcard_file)
        values = read_wildcard_values(wildcard_file)
        merged.setdefault(name, []).extend(values)
    return merged
def replace_prompt_with_wildcards(
    prompt: str,
    wildcard_option_dict: Optional[Dict[str, List[str]]] = None,
    wildcard_files: Optional[List[str]] = None,
):
    """Return ``prompt`` with every ``__name__`` wildcard replaced by a random option.

    Args:
        prompt: Prompt that may contain wildcards such as ``__animal__``.
        wildcard_option_dict: Mapping from wildcard name to candidate values.
            Not modified.
        wildcard_files: Paths of wildcard files, one option per line. The file
            ``animal.txt`` supplies options for ``__animal__``.

    Returns:
        The prompt with each wildcard occurrence independently replaced by
        ``random.choice`` over that wildcard's options. Inserted text is used
        literally and is not itself scanned for wildcards.

    Raises:
        KeyError: If the prompt contains a wildcard that has no options.
        IndexError: If a wildcard's option list is empty (from ``random.choice``).
    """
    # Hand the merge helper private copies so it can never modify the caller's
    # dict or the lists inside it.
    options = grab_wildcard_values(
        {name: list(values) for name, values in (wildcard_option_dict or {}).items()},
        list(wildcard_files or []),
    )

    def _substitute(match):
        name = match.group(1)
        if name not in options:
            raise KeyError(f"No options provided for wildcard `{match.group(0)}` (expected key {name!r}).")
        return random.choice(options[name])

    # re.sub visits matches left to right in the *original* prompt. This means a
    # value inserted for one wildcard can never be mistaken for a later wildcard,
    # which a first-occurrence str.replace on the growing prompt could do.
    return global_re_wildcard.sub(_substitute, prompt)
@dataclass
class WildcardStableDiffusionOutput(StableDiffusionPipelineOutput):
    """Stable Diffusion pipeline output that also records the prompts that were used.

    All fields of ``StableDiffusionPipelineOutput`` are inherited unchanged.
    """

    # The wildcard-expanded prompts that were fed to the text encoder.
    prompts: List[str]
class WildcardStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation with wildcards using Stable Diffusion.

    A wildcard is a token of the form `__name__` inside a prompt. Before generation, every wildcard occurrence is
    replaced by a value drawn with `random.choice` from the options registered for `name`. Options come either inline
    via `wildcard_option_dict` or from a `name.txt` file listed in `wildcard_files`.

    Example Usage:
        pipe = WildcardStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="fp16",
            torch_dtype=torch.float16,
        )
        prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
        out = pipe(
            prompt,
            wildcard_option_dict={
                "clothing": ["hat", "shirt", "scarf", "beret"]
            },
            wildcard_files=["object.txt", "animal.txt"],
            num_prompt_samples=1
        )

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
            May be `None`, in which case a warning is logged and no NSFW check is run.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()

        # Outdated scheduler configs carry `steps_offset != 1`. Emit a deprecation notice and patch a copy of the
        # frozen config so the scheduler uses `steps_offset = 1`.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @staticmethod
    def _expand_prompts(
        source_prompts: List[str],
        wildcard_option_dict: Optional[Dict[str, List[str]]],
        wildcard_files: Optional[List[str]],
        num_prompt_samples: int,
    ) -> List[str]:
        """Return `num_prompt_samples` wildcard-substituted copies of each prompt, grouped by source prompt."""
        # Merge inline options and wildcard files once per call, on a copy. This keeps the caller's dict (and its
        # lists) unmodified and avoids re-reading the files, or re-appending their lines, for every sample.
        wildcard_options = grab_wildcard_values(
            {name: list(values) for name, values in (wildcard_option_dict or {}).items()},
            list(wildcard_files or []),
        )
        return [
            replace_prompt_with_wildcards(source, wildcard_options)
            for source in source_prompts
            for _ in range(num_prompt_samples)
        ]

    @staticmethod
    def _resolve_uncond_tokens(
        negative_prompt: Optional[Union[str, List[str]]],
        expanded_prompts: List[str],
        num_source_prompts: int,
        num_prompt_samples: int,
    ) -> List[str]:
        """Return exactly one negative prompt per wildcard-expanded prompt.

        Raises:
            TypeError: If `negative_prompt` is neither `None`, a `str`, nor a `list`.
            ValueError: If a list `negative_prompt` matches neither the number of input prompts nor the number of
                expanded prompts.
        """
        batch_size = len(expanded_prompts)
        if negative_prompt is None:
            return [""] * batch_size
        if isinstance(negative_prompt, str):
            # A single negative prompt applies to every expanded prompt.
            return [negative_prompt] * batch_size
        if not isinstance(negative_prompt, list):
            raise TypeError(
                f"`negative_prompt` should be of type `str` or `list`, but got {type(negative_prompt)}."
            )
        if len(negative_prompt) == batch_size:
            return negative_prompt
        if len(negative_prompt) == num_source_prompts:
            # One negative prompt per input prompt: repeat it for each wildcard sample, in the same order the
            # prompts were expanded (all samples of prompt 0, then prompt 1, ...).
            return [neg for neg in negative_prompt for _ in range(num_prompt_samples)]
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {expanded_prompts} has batch size {batch_size} ({num_source_prompts} input prompt(s) x"
            f" {num_prompt_samples} sample(s)). Please make sure that passed `negative_prompt` has one entry per"
            " input prompt or one entry per expanded prompt."
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        wildcard_option_dict: Optional[Dict[str, List[str]]] = None,
        wildcard_files: Optional[List[str]] = None,
        num_prompt_samples: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. Prompts may contain wildcards of the form
                `__name__`, which are replaced before generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image. Must be divisible by 8.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is not greater than `1`). A single string is used for every generated
                prompt. A list must have either one entry per input prompt (repeated for each wildcard sample) or
                one entry per wildcard-expanded prompt.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per wildcard-expanded prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic. Wildcard sampling uses Python's `random` module and is not controlled by it.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`. Must have shape
                `(number of expanded prompts * num_images_per_prompt, unet.in_channels, height // 8, width // 8)`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`WildcardStableDiffusionOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            wildcard_option_dict (`Dict[str, List[str]]`, *optional*):
                Dict with the wildcard name as key and a list of possible replacements as value. For example, for the
                prompt "A __animal__ sitting on a chair", values for "animal" can be given as
                `{"animal": ["dog", "cat", "fox"]}`. The dict is not modified.
            wildcard_files (`List[str]`, *optional*):
                List of filenames of txt files for wildcard replacements, one option per line. The file name without
                its `.txt` suffix is the wildcard name, e.g. `["animal.txt"]` supplies options for `__animal__`.
                Options from files are added to any given for the same name in `wildcard_option_dict`.
            num_prompt_samples (`int`, *optional*, defaults to 1):
                Number of times to sample wildcards for each prompt provided. Must be at least 1.

        Returns:
            [`WildcardStableDiffusionOutput`] or `tuple`:
            [`WildcardStableDiffusionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
            the first element is a list with the generated images, and the second element is a list of `bool`s
            denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content,
            according to the `safety_checker` (or `None` when no safety checker is set). The output object also
            carries `prompts`, the wildcard-expanded prompts used for generation.
        """
        # 1. Wildcard expansion: every input prompt is sampled `num_prompt_samples` times.
        if isinstance(prompt, str):
            source_prompts = [prompt]
        elif isinstance(prompt, list):
            source_prompts = prompt
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if num_prompt_samples is None or num_prompt_samples < 1:
            raise ValueError(f"`num_prompt_samples` has to be a positive integer but is {num_prompt_samples}.")

        prompt = self._expand_prompts(source_prompts, wildcard_option_dict, wildcard_files, num_prompt_samples)
        batch_size = len(prompt)

        # 2. Argument validation.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # 3. Prompt text embeddings.
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # Duplicate text embeddings for each generation per prompt, using an mps-friendly method.
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) of the Imagen paper
        # (https://arxiv.org/pdf/2205.11487.pdf). `guidance_scale = 1` corresponds to no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 4. Unconditional embeddings for classifier-free guidance.
        if do_classifier_free_guidance:
            uncond_tokens = self._resolve_uncond_tokens(
                negative_prompt, prompt, len(source_prompts), num_prompt_samples
            )

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # Duplicate unconditional embeddings for each generation per prompt, using an mps-friendly method.
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # Concatenate unconditional and text embeddings into one batch so a single UNet forward pass
            # serves both halves of classifier-free guidance.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # 5. Initial random noise, unless the user supplied it. Latents are generated on the target device
        # for 1-to-1 reproducibility with the CompVis implementation; this does not work on `mps`, which
        # falls back to sampling on the CPU.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # 6. Timesteps. Some schedulers (e.g. PNDM) keep timesteps as arrays; move them to the device once.
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * self.scheduler.init_noise_sigma

        # Not all schedulers share a `step` signature; only pass `eta` (DDIM's η, in [0, 1]) when accepted.
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 7. Denoising loop.
        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # Expand the latents if we are doing classifier-free guidance.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual.
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # Perform guidance.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Decode latents into images in [0, 1].
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        # Always cast to float32: negligible overhead and compatible with bfloat16.
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Optional NSFW check.
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return WildcardStableDiffusionOutput(images=image, nsfw_content_detected=has_nsfw_concept, prompts=prompt)