[Support PyTorch 1.8] Remove inference mode (#707)
This commit is contained in:
parent
688031c592
commit
b35bac4d3b
|
@ -79,7 +79,7 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
|
|||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
|
||||
pooled_output = self.vision_model(clip_input)[1] # pooled_output
|
||||
image_embeds = self.visual_projection(pooled_output)
|
||||
|
|
Loading…
Reference in New Issue