Fix minor errors

Topic: clean_rewrite
This commit is contained in:
Hayk Martiros 2022-12-26 17:39:55 -08:00
parent 6a5f572374
commit d820e1fecf
2 changed files with 15 additions and 15 deletions

View File

@@ -8,10 +8,10 @@ License: Apache 2.0
# ruff: noqa # ruff: noqa
# mypy: ignore-errors # mypy: ignore-errors
import re
from typing import List, Optional, Union
import logging import logging
import re
import typing as T
import torch import torch
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
@@ -126,7 +126,7 @@ def parse_prompt_attention(text):
return res return res
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
r""" r"""
Tokenize a list of prompts and return its tokens with weights of each token. Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included. No padding, starting or ending token is included.
@@ -195,7 +195,7 @@ def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,
text_input: torch.Tensor, text_input: torch.Tensor,
chunk_length: int, chunk_length: int,
no_boseos_middle: Optional[bool] = True, no_boseos_middle: T.Optional[bool] = True,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
When the length of tokens is a multiple of the capacity of the text encoder, When the length of tokens is a multiple of the capacity of the text encoder,
@@ -235,12 +235,12 @@ def get_unweighted_text_embeddings(
def get_weighted_text_embeddings( def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,
prompt: Union[str, List[str]], prompt: T.Union[str, T.List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None, uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3, max_embeddings_multiples: T.Optional[int] = 3,
no_boseos_middle: Optional[bool] = False, no_boseos_middle: T.Optional[bool] = False,
skip_parsing: Optional[bool] = False, skip_parsing: T.Optional[bool] = False,
skip_weighting: Optional[bool] = False, skip_weighting: T.Optional[bool] = False,
**kwargs, **kwargs,
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]: ) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
r""" r"""
@@ -251,9 +251,9 @@ def get_weighted_text_embeddings(
Args: Args:
pipe (`StableDiffusionPipeline`): pipe (`StableDiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder. Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`): prompt (`str` or `T.List[str]`):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`): uncond_prompt (`str` or `T.List[str]`):
The unconditional prompt or prompts for guide the image generation. If unconditional prompt The unconditional prompt or prompts for guide the image generation. If unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated. is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`): max_embeddings_multiples (`int`, *optional*, defaults to `3`):

View File

@@ -28,5 +28,5 @@ class PrintExifTest(TestCase):
print_exif(image=str(image_path)) print_exif(image=str(image_path))
# Check that a couple of values are printed # Check that a couple of values are printed
self.assertTrue("NUM_FREQUENCIES: 512" in stdout.getvalue()) self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
self.assertTrue("SAMPLE_RATE: 44100" in stdout.getvalue()) self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())