Fix minor errors

Topic: clean_rewrite
This commit is contained in:
Hayk Martiros 2022-12-26 17:39:55 -08:00
parent 6a5f572374
commit d820e1fecf
2 changed files with 15 additions and 15 deletions

View File

@ -8,10 +8,10 @@ License: Apache 2.0
# ruff: noqa
# mypy: ignore-errors
import re
from typing import List, Optional, Union
import logging
import re
import typing as T
import torch
from diffusers import StableDiffusionPipeline
@ -126,7 +126,7 @@ def parse_prompt_attention(text):
return res
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
@ -195,7 +195,7 @@ def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
no_boseos_middle: T.Optional[bool] = True,
) -> torch.FloatTensor:
"""
When the length of tokens is a multiple of the capacity of the text encoder,
@ -235,12 +235,12 @@ def get_unweighted_text_embeddings(
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
prompt: T.Union[str, T.List[str]],
uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
max_embeddings_multiples: T.Optional[int] = 3,
no_boseos_middle: T.Optional[bool] = False,
skip_parsing: T.Optional[bool] = False,
skip_weighting: T.Optional[bool] = False,
**kwargs,
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
r"""
@ -251,9 +251,9 @@ def get_weighted_text_embeddings(
Args:
pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
prompt (`str` or `T.List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
uncond_prompt (`str` or `T.List[str]`):
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):

View File

@ -28,5 +28,5 @@ class PrintExifTest(TestCase):
print_exif(image=str(image_path))
# Check that a couple of values are printed
self.assertTrue("NUM_FREQUENCIES: 512" in stdout.getvalue())
self.assertTrue("SAMPLE_RATE: 44100" in stdout.getvalue())
self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())