parent
6a5f572374
commit
d820e1fecf
|
@ -8,10 +8,10 @@ License: Apache 2.0
|
||||||
# ruff: noqa
|
# ruff: noqa
|
||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import typing as T
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
|
@ -126,7 +126,7 @@ def parse_prompt_attention(text):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
|
||||||
r"""
|
r"""
|
||||||
Tokenize a list of prompts and return its tokens with weights of each token.
|
Tokenize a list of prompts and return its tokens with weights of each token.
|
||||||
No padding, starting or ending token is included.
|
No padding, starting or ending token is included.
|
||||||
|
@ -195,7 +195,7 @@ def get_unweighted_text_embeddings(
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
text_input: torch.Tensor,
|
text_input: torch.Tensor,
|
||||||
chunk_length: int,
|
chunk_length: int,
|
||||||
no_boseos_middle: Optional[bool] = True,
|
no_boseos_middle: T.Optional[bool] = True,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
"""
|
"""
|
||||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||||
|
@ -235,12 +235,12 @@ def get_unweighted_text_embeddings(
|
||||||
|
|
||||||
def get_weighted_text_embeddings(
|
def get_weighted_text_embeddings(
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
prompt: Union[str, List[str]],
|
prompt: T.Union[str, T.List[str]],
|
||||||
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
|
||||||
max_embeddings_multiples: Optional[int] = 3,
|
max_embeddings_multiples: T.Optional[int] = 3,
|
||||||
no_boseos_middle: Optional[bool] = False,
|
no_boseos_middle: T.Optional[bool] = False,
|
||||||
skip_parsing: Optional[bool] = False,
|
skip_parsing: T.Optional[bool] = False,
|
||||||
skip_weighting: Optional[bool] = False,
|
skip_weighting: T.Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
|
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
|
||||||
r"""
|
r"""
|
||||||
|
@ -251,9 +251,9 @@ def get_weighted_text_embeddings(
|
||||||
Args:
|
Args:
|
||||||
pipe (`StableDiffusionPipeline`):
|
pipe (`StableDiffusionPipeline`):
|
||||||
Pipe to provide access to the tokenizer and the text encoder.
|
Pipe to provide access to the tokenizer and the text encoder.
|
||||||
prompt (`str` or `List[str]`):
|
prompt (`str` or `T.List[str]`):
|
||||||
The prompt or prompts to guide the image generation.
|
The prompt or prompts to guide the image generation.
|
||||||
uncond_prompt (`str` or `List[str]`):
|
uncond_prompt (`str` or `T.List[str]`):
|
||||||
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
||||||
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
||||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||||
|
|
|
@ -28,5 +28,5 @@ class PrintExifTest(TestCase):
|
||||||
print_exif(image=str(image_path))
|
print_exif(image=str(image_path))
|
||||||
|
|
||||||
# Check that a couple of values are printed
|
# Check that a couple of values are printed
|
||||||
self.assertTrue("NUM_FREQUENCIES: 512" in stdout.getvalue())
|
self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
|
||||||
self.assertTrue("SAMPLE_RATE: 44100" in stdout.getvalue())
|
self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())
|
||||||
|
|
Loading…
Reference in New Issue