[Support PyTorch 1.8] Remove inference mode (#707)

This commit is contained in:
Patrick von Platen 2022-10-03 22:14:58 +02:00 committed by GitHub
parent 688031c592
commit b35bac4d3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -79,7 +79,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
return images, has_nsfw_concepts return images, has_nsfw_concepts
@torch.inference_mode() @torch.no_grad()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output) image_embeds = self.visual_projection(pooled_output)