diff --git a/riffusion/external/prompt_weighting.py b/riffusion/external/prompt_weighting.py
index 5c71486..160b7a9 100644
--- a/riffusion/external/prompt_weighting.py
+++ b/riffusion/external/prompt_weighting.py
@@ -8,10 +8,10 @@ License: Apache 2.0
 # ruff: noqa
 # mypy: ignore-errors
 
-import re
-from typing import List, Optional, Union
-
 import logging
+import re
+import typing as T
+
 import torch
 from diffusers import StableDiffusionPipeline
 
@@ -126,7 +126,7 @@ def parse_prompt_attention(text):
     return res
 
 
-def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
+def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
     r"""
     Tokenize a list of prompts and return its tokens with weights of each token.
     No padding, starting or ending token is included.
@@ -195,7 +195,7 @@ def get_unweighted_text_embeddings(
     pipe: StableDiffusionPipeline,
     text_input: torch.Tensor,
     chunk_length: int,
-    no_boseos_middle: Optional[bool] = True,
+    no_boseos_middle: T.Optional[bool] = True,
 ) -> torch.FloatTensor:
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -235,12 +235,12 @@ def get_unweighted_text_embeddings(
 
 def get_weighted_text_embeddings(
     pipe: StableDiffusionPipeline,
-    prompt: Union[str, List[str]],
-    uncond_prompt: Optional[Union[str, List[str]]] = None,
-    max_embeddings_multiples: Optional[int] = 3,
-    no_boseos_middle: Optional[bool] = False,
-    skip_parsing: Optional[bool] = False,
-    skip_weighting: Optional[bool] = False,
+    prompt: T.Union[str, T.List[str]],
+    uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
+    max_embeddings_multiples: T.Optional[int] = 3,
+    no_boseos_middle: T.Optional[bool] = False,
+    skip_parsing: T.Optional[bool] = False,
+    skip_weighting: T.Optional[bool] = False,
     **kwargs,
 ) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
     r"""
@@ -251,9 +251,9 @@ def get_weighted_text_embeddings(
     Args:
         pipe (`StableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
-        prompt (`str` or `List[str]`):
+        prompt (`str` or `T.List[str]`):
            The prompt or prompts to guide the image generation.
-        uncond_prompt (`str` or `List[str]`):
+        uncond_prompt (`str` or `T.List[str]`):
            The unconditional prompt or prompts for guide the image generation. If unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
diff --git a/test/print_exif_test.py b/test/print_exif_test.py
index 82b6d94..ad32b00 100644
--- a/test/print_exif_test.py
+++ b/test/print_exif_test.py
@@ -28,5 +28,5 @@ class PrintExifTest(TestCase):
         print_exif(image=str(image_path))
 
         # Check that a couple of values are printed
-        self.assertTrue("NUM_FREQUENCIES: 512" in stdout.getvalue())
-        self.assertTrue("SAMPLE_RATE: 44100" in stdout.getvalue())
+        self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
+        self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())
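For context, a minimal usage sketch of the module touched by this patch (the sketch itself is not part of the diff). It assumes a loaded StableDiffusionPipeline; the checkpoint name and the example prompts are placeholders, and the printed shape is only illustrative.

# Usage sketch only -- not part of the patch. Checkpoint name and prompts are placeholders.
from diffusers import StableDiffusionPipeline

from riffusion.external.prompt_weighting import get_weighted_text_embeddings

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Parentheses up-weight a span and square brackets down-weight it
# (by roughly a factor of 1.1 per nesting level, per parse_prompt_attention).
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
    pipe=pipe,
    prompt="a (jazzy) saxophone solo, [muffled]",
    uncond_prompt="",
)
print(text_embeddings.shape)  # e.g. torch.Size([1, 77, 768]) for SD 1.x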