cast() to fix type mismatch warnings
This commit is contained in:
parent
9b663cd23e
commit
5bf5d4026d
|
@ -3,7 +3,7 @@ import logging
|
|||
import os.path
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
from typing import Generator, Callable, Any
|
||||
from typing import Generator, Callable, Any, cast
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
@ -218,8 +218,8 @@ class SampleGenerator:
|
|||
for cfg in self.cfgs:
|
||||
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
|
||||
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
|
||||
prompt_embeds = FloatTensor(compel(prompts))
|
||||
negative_prompt_embeds: FloatTensor = FloatTensor(compel(negative_prompts))
|
||||
prompt_embeds = cast(compel(prompts), FloatTensor)
|
||||
negative_prompt_embeds = cast(compel(negative_prompts), FloatTensor)
|
||||
images = pipe(prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
|
|
Loading…
Reference in New Issue