[Docs] Weight prompting using compel (#2574)
* add docs
* correct
* finish
* Apply suggestions from code review

Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* update deps table
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent f0b661b8fb
commit 22a31760c4
@@ -48,6 +48,8 @@
    title: How to contribute a Pipeline
  - local: using-diffusers/using_safetensors
    title: Using safetensors
  - local: using-diffusers/weighted_prompts
    title: Weighting Prompts
  title: Pipelines for Inference
- sections:
  - local: using-diffusers/rl

@@ -36,6 +36,7 @@ Unless otherwise mentioned, these are techniques that work with existing models
8. [DreamBooth](#dreambooth)
9. [Textual Inversion](#textual-inversion)
10. [ControlNet](#controlnet)
11. [Prompt Weighting](#prompt-weighting)

## Instruct Pix2Pix

@@ -158,3 +159,9 @@ depth maps, and semantic segmentations.

See [here](../api/pipelines/stable_diffusion/controlnet) for more information on how to use it.

## Prompt Weighting

Prompt weighting is a simple technique that puts more attention weight on certain parts of the text
input.

For a more detailed explanation and examples, see [here](../using-diffusers/weighted_prompts).

@@ -0,0 +1,98 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Weighting prompts

Text-guided diffusion models generate images based on a given text prompt. The text prompt
can include multiple concepts that the model should generate, and it's often desirable to weight
certain parts of the prompt more or less heavily than others.

Diffusion models work by conditioning the cross-attention layers of the diffusion model with contextualized text embeddings (see the [Stable Diffusion guide](../stable-diffusion) for more information).
Thus, a simple way to emphasize (or de-emphasize) certain parts of the prompt is to increase or reduce the scale of the text embedding vector that corresponds to the relevant part of the prompt.
This is called "prompt-weighting" and has been a highly requested feature in the community (see issue [here](https://github.com/huggingface/diffusers/issues/2431)).

## How to do prompt-weighting in Diffusers

We believe the role of `diffusers` is to be a toolbox that provides essential features that enable other projects, such as [InvokeAI](https://github.com/invoke-ai/InvokeAI) or [diffuzers](https://github.com/abhishekkrthakur/diffuzers), to build powerful UIs. In order to support arbitrary methods to manipulate prompts, `diffusers` exposes a [`prompt_embeds`](https://huggingface.co/docs/diffusers/v0.14.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) argument in many pipelines such as [`StableDiffusionPipeline`], allowing you to pass the "prompt-weighted"/scaled text embeddings directly to the pipeline.

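To make this concrete, the snippet below is a minimal sketch of what preparing such embeddings yourself could look like: it encodes the prompt with the pipeline's own tokenizer and text encoder, up-weights the embedding vector of one token, and passes the result via `prompt_embeds`. The token position and scaling factor are purely illustrative assumptions, not part of the `diffusers` API, and this is not the recommended workflow (see the `compel` example below).

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

prompt = "a red cat playing with a ball"

# Encode the prompt the same way the pipeline does internally.
text_inputs = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to(pipe.device))[0]

# Up-weight the embedding vector of the "ball" token. Position 7 is an
# illustrative guess; in practice you would look up which position(s) the
# tokenizer assigned to the word you want to emphasize.
prompt_embeds[:, 7, :] *= 1.5

# Pass the scaled embeddings instead of the plain text prompt.
image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=20).images[0]
```

Handling tokenization, locating the right token positions, and rescaling robustly is exactly what dedicated libraries automate for you.
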
The [compel library](https://github.com/damian0815/compel) provides an easy way to emphasize or de-emphasize portions of the prompt for you. We strongly recommend using it instead of preparing the embeddings yourself.

Let's look at a simple example. Imagine you want to generate an image of `"a red cat playing with a ball"` as
follows:

```py
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

prompt = "a red cat playing with a ball"

generator = torch.Generator(device="cpu").manual_seed(33)

image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image
```

This gives you:

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_0.png)

As you can see, there is no "ball" in the image. Let's emphasize this part!

For this, we first need to install the `compel` library:

```
pip install compel
```

and then create a `Compel` object:

```py
from compel import Compel

compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
```

Now we emphasize the part "ball" with the `++` syntax:

```py
prompt = "a red cat playing with a ball++"
```

and instead of passing this to the pipeline directly, we have to process it using `compel_proc`:

```py
prompt_embeds = compel_proc(prompt)
```

Now we can pass `prompt_embeds` directly to the pipeline:

```py
generator = torch.Generator(device="cpu").manual_seed(33)

image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
image
```

We now get the following image, which has a "ball" in it!

![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_1.png)

Similarly, you can de-emphasize parts of the sentence by using the `--` suffix for words; feel free to give it a try!

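For example, reusing the `pipe` and `compel_proc` objects from above, de-emphasizing the ball could look like this (a minimal sketch; the resulting image is not shown here):

```py
prompt = "a red cat playing with a ball--"
prompt_embeds = compel_proc(prompt)

generator = torch.Generator(device="cpu").manual_seed(33)
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
image
```
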
If your favorite pipeline does not have a `prompt_embeds` input, please make sure to open an issue; the
diffusers team tries to be as responsive as possible.

Also, please check out the documentation of the [compel](https://github.com/damian0815/compel) library for
more information.
setup.py
@@ -80,6 +80,7 @@ from setuptools import find_packages, setup
_deps = [
    "Pillow",  # keep the PIL.Image.Resampling deprecation away
    "accelerate>=0.11.0",
    "compel==0.1.8",
    "black~=23.1",
    "datasets",
    "filelock",

@@ -182,6 +183,7 @@ extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
extras["test"] = deps_list(
    "compel",
    "datasets",
    "Jinja2",
    "k-diffusion",

@@ -4,6 +4,7 @@
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "compel": "compel==0.1.8",
    "black": "black~=23.1",
    "datasets": "datasets",
    "filelock": "filelock",

@@ -232,6 +232,14 @@ except importlib_metadata.PackageNotFoundError:
    _tensorboard_available = False


_compel_available = importlib.util.find_spec("compel")
try:
    _compel_version = importlib_metadata.version("compel")
    logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
    _compel_available = False


def is_torch_available():
    return _torch_available

@@ -296,6 +304,10 @@ def is_tensorboard_available():
    return _tensorboard_available


def is_compel_available():
    return _compel_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the

@@ -368,6 +380,12 @@ TENSORBOARD_IMPORT_ERROR = """
install tensorboard`
"""


# docstyle-ignore
COMPEL_IMPORT_ERROR = """
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
"""

BACKENDS_MAPPING = OrderedDict(
    [
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),

@@ -382,6 +400,7 @@ BACKENDS_MAPPING = OrderedDict(
        ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
        ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
        ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
        ("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
    ]
)

@@ -16,7 +16,7 @@ import PIL.ImageOps
import requests
from packaging import version

-from .import_utils import is_flax_available, is_onnx_available, is_torch_available
+from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available
from .logging import get_logger

@@ -175,6 +175,14 @@ def require_flax(test_case):
    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


def require_compel(test_case):
    """
    Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
    the library is not installed.
    """
    return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)


def require_onnxruntime(test_case):
    """
    Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.

@@ -49,11 +49,12 @@ from diffusers import (
    StableDiffusionPipeline,
    UNet2DConditionModel,
    UNet2DModel,
    UniPCMultistepScheduler,
    logging,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False

@@ -1058,6 +1059,37 @@ class PipelineSlowTests(unittest.TestCase):

        assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"

    @require_compel
    def test_weighted_prompts_compel(self):
        from compel import Compel

        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_attention_slicing()

        compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

        prompt = "a red cat playing with a ball{}"

        prompts = [prompt.format(s) for s in ["", "++", "--"]]

        prompt_embeds = compel(prompts)

        generator = [torch.Generator(device="cpu").manual_seed(33) for _ in range(prompt_embeds.shape[0])]

        images = pipe(
            prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20, output_type="numpy"
        ).images

        for i, image in enumerate(images):
            expected_image = load_numpy(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
                f"/compel/forest_{i}.npy"
            )

            assert np.abs(image - expected_image).max() < 1e-3


@nightly
@require_torch_gpu