[Docs] Weight prompting using compel (#2574)
* add docs
* correct
* finish
* Apply suggestions from code review (Co-authored-by: Will Berman <wlbberman@gmail.com>, YiYi Xu <yixu310@gmail.com>)
* update deps table
* Apply suggestions from code review (Co-authored-by: Pedro Cuenca <pedro@huggingface.co>)

Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
f0b661b8fb
commit
22a31760c4
@@ -48,6 +48,8 @@
     title: How to contribute a Pipeline
   - local: using-diffusers/using_safetensors
     title: Using safetensors
+  - local: using-diffusers/weighted_prompts
+    title: Weighting Prompts
   title: Pipelines for Inference
 - sections:
   - local: using-diffusers/rl
@@ -36,6 +36,7 @@ Unless otherwise mentioned, these are techniques that work with existing models
 8. [DreamBooth](#dreambooth)
 9. [Textual Inversion](#textual-inversion)
 10. [ControlNet](#controlnet)
+11. [Prompt Weighting](#prompt-weighting)

 ## Instruct Pix2Pix

@@ -158,3 +159,9 @@ depth maps, and semantic segmentations.

 See [here](../api/pipelines/stable_diffusion/controlnet) for more information on how to use it.
+
+## Prompt Weighting
+
+Prompt weighting is a simple technique that puts more attention weight on certain parts of the text
+input.
+
+For a more detailed explanation and examples, see [here](../using-diffusers/weighted_prompts).
@@ -0,0 +1,98 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Weighting prompts
Text-guided diffusion models generate images based on a given text prompt. The text prompt
can include multiple concepts that the model should generate, and it is often desirable to weight
certain parts of the prompt more strongly than others.

Diffusion models condition their cross attention layers on contextualized text embeddings (see the [Stable Diffusion Guide](../stable-diffusion) for more information).
Thus, a simple way to emphasize (or de-emphasize) certain parts of the prompt is to increase or reduce the scale of the text embedding vectors that correspond to the relevant part of the prompt.
This is called "prompt weighting" and has been a highly requested feature in the community (see this [issue](https://github.com/huggingface/diffusers/issues/2431)).
## How to do prompt-weighting in Diffusers

We believe the role of `diffusers` is to be a toolbox that provides essential features enabling other projects, such as [InvokeAI](https://github.com/invoke-ai/InvokeAI) or [diffuzers](https://github.com/abhishekkrthakur/diffuzers), to build powerful UIs. In order to support arbitrary methods of manipulating prompts, `diffusers` exposes a [`prompt_embeds`](https://huggingface.co/docs/diffusers/v0.14.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) argument in many pipelines, such as [`StableDiffusionPipeline`], which lets you pass the "prompt-weighted"/scaled text embeddings to the pipeline directly.
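To make this concrete, here is a minimal sketch of what preparing `prompt_embeds` by hand could look like, using the pipeline's own tokenizer and text encoder. The token index used for "ball" and the 1.5 scale factor are illustrative assumptions, not a recommended recipe:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

prompt = "a red cat playing with a ball"

# Tokenize and encode the prompt ourselves instead of passing the string to the pipeline.
text_inputs = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids)[0]  # shape: [1, 77, 768]

# Hypothetical up-weighting: scale the embedding of the token for "ball".
# Index 7 is an assumption for this particular prompt; inspect
# pipe.tokenizer.convert_ids_to_tokens(text_inputs.input_ids[0]) to locate it.
prompt_embeds[:, 7] = prompt_embeds[:, 7] * 1.5

image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=20).images[0]
```

Getting the token indices and scales right by hand quickly becomes fiddly, which is why the rest of this guide relies on `compel` to build the weighted embeddings.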
The [compel library](https://github.com/damian0815/compel) provides an easy way to emphasize or de-emphasize portions of the prompt for you. We strongly recommend it instead of preparing the embeddings yourself.

Let's look at a simple example. Imagine you want to generate an image of `"a red cat playing with a ball"` as
follows:
```py
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

prompt = "a red cat playing with a ball"

generator = torch.Generator(device="cpu").manual_seed(33)

image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image
```
This gives you:

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_0.png)

As you can see, there is no "ball" in the image. Let's emphasize this part!

To do this, install the `compel` library:

```
pip install compel
```

and then create a `Compel` object:
```py
from compel import Compel

compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
```
Now we emphasize the part "ball" with the `"++"` syntax:

```py
prompt = "a red cat playing with a ball++"
```

Instead of passing this prompt to the pipeline directly, we process it with `compel_proc`:

```py
prompt_embeds = compel_proc(prompt)
```

Now we can pass `prompt_embeds` directly to the pipeline:
```py
generator = torch.Generator(device="cpu").manual_seed(33)

image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
image
```
We now get the following image, which does have a "ball"!

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_1.png)

Similarly, you can de-emphasize parts of the prompt by adding the `--` suffix to words. Feel free to give it a try!
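For example, reusing the `pipe` and `compel_proc` objects from above, a quick sketch of toning down the "ball":

```py
prompt = "a red cat playing with a ball--"

prompt_embeds = compel_proc(prompt)

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
image
```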
If your favorite pipeline does not have a `prompt_embeds` input, please open an issue; the
diffusers team tries to be as responsive as possible.

Also, please check out the documentation of the [compel](https://github.com/damian0815/compel) library for
more information.
setup.py
@@ -80,6 +80,7 @@ from setuptools import find_packages, setup
 _deps = [
     "Pillow",  # keep the PIL.Image.Resampling deprecation away
     "accelerate>=0.11.0",
+    "compel==0.1.8",
     "black~=23.1",
     "datasets",
     "filelock",
@@ -182,6 +183,7 @@ extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder")
 extras["docs"] = deps_list("hf-doc-builder")
 extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
 extras["test"] = deps_list(
+    "compel",
     "datasets",
     "Jinja2",
     "k-diffusion",
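Since `compel` is now part of the `test` extras, a development install pulls it in automatically; assuming an editable checkout of the repository, something like:

```
pip install -e ".[test]"
```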
@@ -4,6 +4,7 @@
 deps = {
     "Pillow": "Pillow",
     "accelerate": "accelerate>=0.11.0",
+    "compel": "compel==0.1.8",
     "black": "black~=23.1",
     "datasets": "datasets",
     "filelock": "filelock",
@@ -232,6 +232,14 @@ except importlib_metadata.PackageNotFoundError:
     _tensorboard_available = False


+_compel_available = importlib.util.find_spec("compel")
+try:
+    _compel_version = importlib_metadata.version("compel")
+    logger.debug(f"Successfully imported compel version {_compel_version}")
+except importlib_metadata.PackageNotFoundError:
+    _compel_available = False
+
+
 def is_torch_available():
     return _torch_available

@@ -296,6 +304,10 @@ def is_tensorboard_available():
     return _tensorboard_available


+def is_compel_available():
+    return _compel_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the

@@ -368,6 +380,12 @@ TENSORBOARD_IMPORT_ERROR = """
 install tensorboard`
 """

+
+# docstyle-ignore
+COMPEL_IMPORT_ERROR = """
+{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
@@ -382,6 +400,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
         ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
         ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
+        ("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
     ]
 )
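For context, a short, hypothetical sketch of how the new availability flag is meant to be consumed downstream (the guarded import below is illustrative and not part of this commit):

```py
from diffusers.utils.import_utils import is_compel_available

if is_compel_available():
    from compel import Compel
else:
    print("compel is not installed; run `pip install compel` to use the prompt-weighting helpers")
```

The `BACKENDS_MAPPING` entry is what lets the library surface `COMPEL_IMPORT_ERROR`, with its pip install hint, whenever compel is required but missing.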
@@ -16,7 +16,7 @@ import PIL.ImageOps
 import requests
 from packaging import version

-from .import_utils import is_flax_available, is_onnx_available, is_torch_available
+from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available
 from .logging import get_logger


@@ -175,6 +175,14 @@ def require_flax(test_case):
     return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


+def require_compel(test_case):
+    """
+    Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
+    the library is not installed.
+    """
+    return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
+
+
 def require_onnxruntime(test_case):
     """
     Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
@@ -49,11 +49,12 @@ from diffusers import (
     StableDiffusionPipeline,
     UNet2DConditionModel,
     UNet2DModel,
+    UniPCMultistepScheduler,
     logging,
 )
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu


 torch.backends.cuda.matmul.allow_tf32 = False
@@ -1058,6 +1059,37 @@ class PipelineSlowTests(unittest.TestCase):

         assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"

+    @require_compel
+    def test_weighted_prompts_compel(self):
+        from compel import Compel
+
+        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        pipe.enable_model_cpu_offload()
+        pipe.enable_attention_slicing()
+
+        compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
+
+        prompt = "a red cat playing with a ball{}"
+
+        prompts = [prompt.format(s) for s in ["", "++", "--"]]
+
+        prompt_embeds = compel(prompts)
+
+        generator = [torch.Generator(device="cpu").manual_seed(33) for _ in range(prompt_embeds.shape[0])]
+
+        images = pipe(
+            prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20, output_type="numpy"
+        ).images
+
+        for i, image in enumerate(images):
+            expected_image = load_numpy(
+                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+                f"/compel/forest_{i}.npy"
+            )
+
+            assert np.abs(image - expected_image).max() < 1e-3
+

 @nightly
 @require_torch_gpu