parent
6a5f572374
commit
d820e1fecf
|
@ -8,10 +8,10 @@ License: Apache 2.0
|
|||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import logging
|
||||
import re
|
||||
import typing as T
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
@ -126,7 +126,7 @@ def parse_prompt_attention(text):
|
|||
return res
|
||||
|
||||
|
||||
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
||||
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
|
||||
r"""
|
||||
Tokenize a list of prompts and return its tokens with weights of each token.
|
||||
No padding, starting or ending token is included.
|
||||
|
@ -195,7 +195,7 @@ def get_unweighted_text_embeddings(
|
|||
pipe: StableDiffusionPipeline,
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
no_boseos_middle: T.Optional[bool] = True,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||
|
@ -235,12 +235,12 @@ def get_unweighted_text_embeddings(
|
|||
|
||||
def get_weighted_text_embeddings(
|
||||
pipe: StableDiffusionPipeline,
|
||||
prompt: Union[str, List[str]],
|
||||
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
no_boseos_middle: Optional[bool] = False,
|
||||
skip_parsing: Optional[bool] = False,
|
||||
skip_weighting: Optional[bool] = False,
|
||||
prompt: T.Union[str, T.List[str]],
|
||||
uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
|
||||
max_embeddings_multiples: T.Optional[int] = 3,
|
||||
no_boseos_middle: T.Optional[bool] = False,
|
||||
skip_parsing: T.Optional[bool] = False,
|
||||
skip_weighting: T.Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
|
||||
r"""
|
||||
|
@ -251,9 +251,9 @@ def get_weighted_text_embeddings(
|
|||
Args:
|
||||
pipe (`StableDiffusionPipeline`):
|
||||
Pipe to provide access to the tokenizer and the text encoder.
|
||||
prompt (`str` or `List[str]`):
|
||||
prompt (`str` or `T.List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
uncond_prompt (`str` or `List[str]`):
|
||||
uncond_prompt (`str` or `T.List[str]`):
|
||||
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
||||
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
|
|
|
@ -28,5 +28,5 @@ class PrintExifTest(TestCase):
|
|||
print_exif(image=str(image_path))
|
||||
|
||||
# Check that a couple of values are printed
|
||||
self.assertTrue("NUM_FREQUENCIES: 512" in stdout.getvalue())
|
||||
self.assertTrue("SAMPLE_RATE: 44100" in stdout.getvalue())
|
||||
self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
|
||||
self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())
|
||||
|
|
Loading…
Reference in New Issue