From b35bac4d3b1af7e2389809f96e8ada11da6cc503 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Oct 2022 22:14:58 +0200 Subject: [PATCH] [Support PyTorch 1.8] Remove inference mode (#707) --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 3eb8828c..773a7d4b 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -79,7 +79,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel): return images, has_nsfw_concepts - @torch.inference_mode() + @torch.no_grad() def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output)