Fix timestep dtype in legacy inpaint (#2120)
* Fix timestep dtype in legacy inpaint This matches the structure in the text2img, img2img, and inpaint ONNX pipelines * Fix style in dtype patch
This commit is contained in:
parent
a87e87fcbe
commit
7547f9b475
|
@ -10,7 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
|||
from ...configuration_utils import FrozenDict
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import deprecate, logging
|
||||
from ..onnx_utils import OnnxRuntimeModel
|
||||
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
|
||||
|
@ -391,6 +391,10 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:].numpy()
|
||||
timestep_dtype = next(
|
||||
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
|
@ -398,9 +402,10 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=prompt_embeds
|
||||
)[0]
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
|
||||
0
|
||||
]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
|
|
Loading…
Reference in New Issue