cast() to fix type mismatch warnings

This commit is contained in:
Damian Stewart 2023-04-14 20:45:15 +02:00
parent 9b663cd23e
commit 5bf5d4026d
1 changed files with 3 additions and 3 deletions

View File

@ -3,7 +3,7 @@ import logging
import os.path
from dataclasses import dataclass
import random
from typing import Generator, Callable, Any
from typing import Generator, Callable, Any, cast
import torch
from PIL import Image, ImageDraw, ImageFont
@ -218,8 +218,8 @@ class SampleGenerator:
for cfg in self.cfgs:
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
prompt_embeds = FloatTensor(compel(prompts))
negative_prompt_embeds: FloatTensor = FloatTensor(compel(negative_prompts))
prompt_embeds = cast(compel(prompts), FloatTensor)
negative_prompt_embeds = cast(compel(negative_prompts), FloatTensor)
images = pipe(prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
num_inference_steps=self.num_inference_steps,