2023-01-29 23:51:06 -07:00
from collections import deque
2022-09-03 03:08:45 -06:00
import torch
2022-09-28 01:49:07 -06:00
import inspect
2022-09-03 03:08:45 -06:00
import k_diffusion . sampling
2023-08-08 09:35:31 -06:00
from modules import devices , sd_samplers_common , sd_samplers_extra , sd_samplers_cfg_denoiser
2022-09-03 03:08:45 -06:00
2023-08-05 20:37:25 -06:00
from modules . processing import StableDiffusionProcessing
2023-01-29 23:51:06 -07:00
from modules . shared import opts , state
2022-09-03 03:08:45 -06:00
import modules . shared as shared
2022-09-03 08:21:15 -06:00
# Registry of the k-diffusion based samplers offered by this module.
#
# Each entry is a 4-tuple:
#   (label, sampling function, aliases, options)
# where the sampling function is either the name of an attribute of
# ``k_diffusion.sampling`` or a callable, and ``options`` is a dict of
# per-sampler flags (e.g. 'scheduler', 'discard_next_to_last_sigma',
# 'brownian_noise') read by KDiffusionSampler via ``config.options``.
#
# NOTE(review): 'DPM2' omits "second_order" while 'DPM2 Karras' sets it, and
# 'DPM2 Karras' sets "uses_ensd" while 'DPM2' does not. How these flags are
# consumed is not visible here, so they are left unchanged - confirm intent.
samplers_k_diffusion = [
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
    ('Euler', 'sample_euler', ['k_euler'], {}),
    ('LMS', 'sample_lms', ['k_lms'], {}),
    ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
    # This entry used to carry the alias 'k_dpmpp_2m_sde_ka', which collided with
    # the Karras variant below; it now has its own alias. The old alias still
    # exists on 'DPM++ 2M SDE Karras'.
    ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {"brownian_noise": True}),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
    ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
    ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
    ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]
2023-07-17 22:32:01 -06:00
2022-09-03 08:21:15 -06:00
def _sampler_func_available(func_ref):
    """Tell whether a registry entry's sampling function can be resolved.

    A callable is usable as-is; a string must name an attribute of
    ``k_diffusion.sampling``.
    """
    if callable(func_ref):
        return True
    return hasattr(k_diffusion.sampling, func_ref)


# One SamplerData per registry entry whose sampling function is available.
# The constructor lambda binds the entry's function through a default argument
# so that each lambda keeps its own value rather than the last loop value.
samplers_data_k_diffusion = [
    sd_samplers_common.SamplerData(
        sampler_label,
        lambda model, funcname=func_ref: KDiffusionSampler(funcname, model),
        sampler_aliases,
        sampler_options,
    )
    for sampler_label, func_ref, sampler_aliases, sampler_options in samplers_k_diffusion
    if _sampler_func_available(func_ref)
]
2022-09-26 02:56:47 -06:00
# Names of the extra keyword parameters forwarded to specific k-diffusion
# sampling functions, keyed by sampling function name. Every key gets its own
# list object.
sampler_extra_params = {
    sampling_func_name: ['s_churn', 's_tmin', 's_tmax', 's_noise']
    for sampling_func_name in ('sample_euler', 'sample_heun', 'sample_dpm_2')
}
2022-09-03 03:08:45 -06:00
2023-05-22 09:26:28 -06:00
# Lookup of sampler data objects by their ``name`` attribute.
k_diffusion_samplers_map = {sampler_data.name: sampler_data for sampler_data in samplers_data_k_diffusion}

# Schedule type -> function producing the sigma schedule.
# 'Automatic' maps to None; the other entries resolve to
# ``k_diffusion.sampling.get_sigmas_<schedule type>``.
k_diffusion_scheduler = {
    'Automatic': None,
    **{
        schedule_type: getattr(k_diffusion.sampling, f'get_sigmas_{schedule_type}')
        for schedule_type in ('karras', 'exponential', 'polyexponential')
    },
}
2022-10-22 11:48:13 -06:00
2022-09-16 00:47:03 -06:00
class TorchHijack:
    """Proxy for the ``torch`` module that hands out pre-generated noise.

    ``KDiffusionSampler.initialize`` assigns an instance of this class to
    ``k_diffusion.sampling.torch``. ``randn_like`` is served from the queue of
    noises given to the constructor; every other attribute is forwarded to the
    real ``torch`` module.
    """

    def __init__(self, sampler_noises):
        """Store ``sampler_noises`` (an iterable of tensors) in consumption order."""
        # Using a deque to efficiently receive the sampler_noises in the same
        # order as the previous index-based implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        """Forward attributes not found on this object to ``torch``.

        Python calls ``__getattr__`` only after normal lookup has failed, so
        ``randn_like`` (a regular method of this class) never reaches this
        hook and needs no special case here.

        Raises:
            AttributeError: if ``torch`` has no attribute named ``item``.
        """
        try:
            return getattr(torch, item)
        except AttributeError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from None

    def randn_like(self, x):
        """Return the next queued noise if its shape matches ``x``, else fresh noise.

        A queued noise is consumed on every call while the queue is non-empty,
        including when its shape does not match and it is therefore discarded.
        The fallback is ``devices.randn_like(x)``.
        """
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise

        return devices.randn_like(x)
2022-11-25 19:12:23 -07:00
2022-09-13 12:49:58 -06:00
2022-09-03 03:08:45 -06:00
class KDiffusionSampler:
    """Runs one k-diffusion sampling function against a Stable Diffusion model.

    The model is wrapped in a k-diffusion denoiser and then in a CFGDenoiser.
    ``sample`` and ``sample_img2img`` build the sigma schedule, collect the
    keyword arguments the chosen sampling function accepts, and invoke it.
    """

    def __init__(self, funcname, sd_model):
        """Wrap ``sd_model`` and resolve the sampling function.

        Args:
            funcname: a callable, or the name of a function in
                ``k_diffusion.sampling``.
            sd_model: the model to sample from; its ``parameterization``
                selects the denoiser wrapper class.
        """
        if sd_model.parameterization == "v":
            denoiser = k_diffusion.external.CompVisVDenoiser
        else:
            denoiser = k_diffusion.external.CompVisDenoiser

        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
        self.funcname = funcname
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
        self.extra_params = sampler_extra_params.get(funcname, [])
        self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap)
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None

        # NOTE: These are also defined in the StableDiffusionProcessing class.
        # They should have been here to begin with but we're going to
        # leave that class __init__ signature alone.
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.conditioning_key = sd_model.model.conditioning_key

    def _config_options(self):
        """Return ``self.config.options``, or an empty dict while ``config`` is unset."""
        return self.config.options if self.config is not None else {}

    def _model_sigma_range(self):
        """Return the first and last entries of the wrapped model's sigma table as floats."""
        model_sigmas = self.model_wrap.sigmas
        return model_sigmas[0].item(), model_sigmas[-1].item()

    def _extra_args(self, p, conditioning, unconditional_conditioning, image_conditioning):
        """Build the ``extra_args`` dict passed through the sampling function to the denoiser."""
        return {
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond,
        }

    def callback_state(self, d):
        """Per-step callback handed to the sampling function.

        Records the step's denoised latent, updates the shared progress state,
        and raises ``InterruptedException`` once the step exceeds ``stop_at``.
        """
        step = d['i']
        latent = d["denoised"]
        if opts.live_preview_content == "Combined":
            sd_samplers_common.store_latent(latent)
        self.last_latent = latent

        if self.stop_at is not None and step > self.stop_at:
            raise sd_samplers_common.InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Reset the shared step counters and run ``func``.

        Returns ``func()``'s result, or ``self.last_latent`` when sampling ends
        with a ``RecursionError`` or an ``InterruptedException``.
        """
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except sd_samplers_common.InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        """Return ``p.steps``."""
        return p.steps

    def initialize(self, p: "StableDiffusionProcessing"):
        """Prepare per-run state and collect keyword arguments for the sampling function.

        Copies mask and cfg settings from ``p`` onto the CFG denoiser, installs
        a ``TorchHijack`` as ``k_diffusion.sampling.torch``, and records
        non-default eta and sigma settings in ``p.extra_generation_params``.

        Returns:
            dict of keyword arguments to pass to ``self.func``.
        """
        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else opts.eta_ancestral
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        # The signature does not change during this call, so inspect it once.
        func_parameters = inspect.signature(self.func).parameters

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in func_parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        if 'eta' in func_parameters:
            if self.eta != 1.0:
                p.extra_generation_params["Eta"] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if self.extra_params:
            # Values from opts take precedence; p only supplies the fallback
            # when opts lacks the attribute.
            sigma_settings = (
                ('s_churn', 'Sigma churn', getattr(opts, 's_churn', p.s_churn)),
                ('s_tmin', 'Sigma tmin', getattr(opts, 's_tmin', p.s_tmin)),
                ('s_tmax', 'Sigma tmax', getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax),  # 0 = inf
                ('s_noise', 'Sigma noise', getattr(opts, 's_noise', p.s_noise)),
            )
            for attr_name, infotext_label, value in sigma_settings:
                # Only values that differ from this sampler's defaults are
                # forwarded, written back to p, and recorded.
                if value != getattr(self, attr_name):
                    extra_params_kwargs[attr_name] = value
                    setattr(p, attr_name, value)
                    p.extra_generation_params[infotext_label] = value

        return extra_params_kwargs

    def get_sigmas(self, p, steps):
        """Build the sigma schedule for ``steps`` steps.

        The source of the schedule is chosen in this order:
        ``p.sampler_noise_scheduler_override``, a non-"Automatic"
        ``opts.k_sched_type``, the sampler config's 'scheduler' option
        ('karras' or 'exponential'), and finally the wrapped model's own
        ``get_sigmas``. When the next-to-last sigma is to be discarded, one
        extra step is requested and the penultimate sigma is dropped.
        """
        config_options = self._config_options()

        discard_next_to_last_sigma = config_options.get('discard_next_to_last_sigma', False)
        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
            discard_next_to_last_sigma = True
            p.extra_generation_params["Discard penultimate sigma"] = True

        steps += 1 if discard_next_to_last_sigma else 0

        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        elif opts.k_sched_type != "Automatic":
            m_sigma_min, m_sigma_max = self._model_sigma_range()
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
            sigmas_kwargs = {
                'sigma_min': sigma_min,
                'sigma_max': sigma_max,
            }

            sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
            p.extra_generation_params["Schedule type"] = opts.k_sched_type

            # A user value of 0, or one equal to the model's own bound, means "not overridden".
            if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
                sigmas_kwargs['sigma_min'] = opts.sigma_min
                p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
            if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
                sigmas_kwargs['sigma_max'] = opts.sigma_max
                p.extra_generation_params["Schedule max sigma"] = opts.sigma_max

            default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.

            # rho is never forwarded to the exponential schedule.
            if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
                sigmas_kwargs['rho'] = opts.rho
                p.extra_generation_params["Schedule rho"] = opts.rho

            sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
        elif config_options.get('scheduler', None) == 'karras':
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else self._model_sigma_range()

            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
        elif config_options.get('scheduler', None) == 'exponential':
            m_sigma_min, m_sigma_max = self._model_sigma_range()
            sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
        else:
            sigmas = self.model_wrap.get_sigmas(steps)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        # Seeds belonging to the current iteration's batch.
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Run img2img sampling starting from latent ``x`` with ``noise`` added.

        Uses the tail of the sigma schedule selected by
        ``sd_samplers_common.setup_img2img_steps``. Returns the value produced
        by ``launch_sampling``.
        """
        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)

        sigmas = self.get_sigmas(p, steps)

        sigma_sched = sigmas[steps - t_enc - 1:]
        xi = x + noise * sigma_sched[0]

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

        if 'sigma_min' in parameters:
            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
            extra_params_kwargs['sigma_min'] = sigma_sched[-2]
        if 'sigma_max' in parameters:
            extra_params_kwargs['sigma_max'] = sigma_sched[0]
        if 'n' in parameters:
            extra_params_kwargs['n'] = len(sigma_sched) - 1
        if 'sigma_sched' in parameters:
            extra_params_kwargs['sigma_sched'] = sigma_sched
        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigma_sched

        # None-safe, matching get_sigmas: config is assigned after construction.
        if self._config_options().get('brownian_noise', False):
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.model_wrap_cfg.init_latent = x
        self.last_latent = x
        extra_args = self._extra_args(p, conditioning, unconditional_conditioning, image_conditioning)

        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Run sampling starting from ``x`` scaled by the first sigma.

        ``steps`` falls back to ``p.steps`` when falsy. Returns the value
        produced by ``launch_sampling``.
        """
        steps = steps or p.steps

        sigmas = self.get_sigmas(p, steps)

        x = x * sigmas[0]

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

        if 'sigma_min' in parameters:
            # Functions taking sigma_min get the model's sigma bounds (and
            # optionally a step count) instead of an explicit schedule.
            sigma_min, sigma_max = self._model_sigma_range()
            extra_params_kwargs['sigma_min'] = sigma_min
            extra_params_kwargs['sigma_max'] = sigma_max
            if 'n' in parameters:
                extra_params_kwargs['n'] = steps
        else:
            extra_params_kwargs['sigmas'] = sigmas

        # None-safe, matching get_sigmas: config is assigned after construction.
        if self._config_options().get('brownian_noise', False):
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.last_latent = x
        samples = self.launch_sampling(steps, lambda: self.func(
            self.model_wrap_cfg,
            x,
            extra_args=self._extra_args(p, conditioning, unconditional_conditioning, image_conditioning),
            disable=False,
            callback=self.callback_state,
            **extra_params_kwargs,
        ))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples
2022-09-03 03:08:45 -06:00