emergency fix for #1199
This commit is contained in:
parent
15f333a266
commit
2ab64ec81a
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
import inspect
|
||||
|
||||
import k_diffusion.sampling
|
||||
import ldm.models.diffusion.ddim
|
||||
|
@ -278,9 +279,9 @@ class KDiffusionSampler:
|
|||
k_diffusion.sampling.torch = TorchHijack(self)
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for val in self.extra_params:
|
||||
if hasattr(p,val):
|
||||
extra_params_kwargs[val] = getattr(p,val)
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||
|
||||
|
@ -300,9 +301,9 @@ class KDiffusionSampler:
|
|||
k_diffusion.sampling.torch = TorchHijack(self)
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for val in self.extra_params:
|
||||
if hasattr(p,val):
|
||||
extra_params_kwargs[val] = getattr(p,val)
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue