Support masks

Hayk Martiros 2022-11-26 06:48:52 +00:00
parent 51034da7bb
commit 6e69c94277
7 changed files with 68 additions and 10 deletions


@@ -28,7 +28,7 @@ Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-infe
 {
   alpha: 0.75,
   num_inference_steps: 50,
-  seed_image_id: 0,
+  seed_image_id: "og_beat",
   start: {
     prompt: "church bells on sunday",
@@ -53,4 +53,3 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
   audio: "< base64 encoded MP3 clip >",
 }
 ```
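For illustration, a request exercising the new mask support might look like the sketch below. The `mask_image_id` field comes from this commit's datatypes change; the endpoint URL, port, and the `og_beat_mask` ID are assumptions for the example, not documented API.

```python
import base64

import requests

# Hypothetical request payload; "mask_image_id" is the new optional field
payload = {
    "alpha": 0.75,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "mask_image_id": "og_beat_mask",  # assumed mask PNG in the seed images dir
    "start": {"prompt": "church bells on sunday", "seed": 42},
    "end": {"prompt": "jazz with piano", "seed": 123},
}

# Assumed local server address; adjust host/port to your deployment
response = requests.post("http://127.0.0.1:3013/run_inference/", json=payload)
audio_b64 = response.json()["audio"].split(",")[-1]
with open("clip.mp3", "wb") as f:
    f.write(base64.b64decode(audio_b64))
```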


@@ -1,4 +1,5 @@
 black
+ipdb
 isort
 mypy
 pylint


@@ -3,6 +3,7 @@ Data model for the riffusion API.
 """
 from dataclasses import dataclass
+import typing as T
 
 @dataclass
@@ -46,8 +47,10 @@ class InferenceInput:
     num_inference_steps: int = 50
 
     # Which seed image to use
-    # TODO(hayk): Convert this to a string ID and add a seed image + mask API.
-    seed_image_id: int = 0
+    seed_image_id: str = "og_beat"
+
+    # ID of mask image to use
+    mask_image_id: T.Optional[str] = None
 
 @dataclass
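A minimal sketch of constructing the updated dataclass, assuming the module path `riffusion.datatypes` and that `PromptInput` takes a prompt and a seed as in the README example:

```python
from riffusion.datatypes import InferenceInput, PromptInput

inputs = InferenceInput(
    start=PromptInput(prompt="church bells on sunday", seed=42),
    end=PromptInput(prompt="church bells on sunday", seed=123),
    alpha=0.75,
    seed_image_id="og_beat",       # now a string ID rather than an int
    mask_image_id="og_beat_mask",  # hypothetical mask ID; None disables masking
)
```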


@@ -78,9 +78,17 @@ class RiffusionPipeline(DiffusionPipeline):
         self,
         inputs: InferenceInput,
         init_image: PIL.Image.Image,
+        mask_image: T.Optional[PIL.Image.Image] = None,
     ) -> PIL.Image.Image:
         """
         Runs inference using interpolation with both img2img and text conditioning.
+
+        Args:
+            inputs: Parameter dataclass
+            init_image: Image used for conditioning
+            mask_image: White pixels in the mask will be replaced by noise and therefore
+                repainted, while black pixels will be preserved. It will be converted to
+                a single channel (luminance) before use.
         """
         alpha = inputs.alpha
         start = inputs.start
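As the docstring describes, the mask follows the usual inpainting convention: white regions are re-noised and repainted, black regions are preserved. A hedged sketch of building such a mask with PIL (the 512×512 size matching the seed spectrogram images is an assumption here):

```python
import PIL.Image
import PIL.ImageDraw

# Black everywhere = preserve; paint the right half white = repaint it
mask = PIL.Image.new("L", (512, 512), color=0)
draw = PIL.ImageDraw.Draw(mask)
draw.rectangle((256, 0, 512, 512), fill=255)
mask.save("half_mask.png")
```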
@@ -96,7 +104,7 @@ class RiffusionPipeline(DiffusionPipeline):
         text_embedding = torch.lerp(embed_start, embed_end, alpha)
 
         # Image latents
-        init_image = preprocess(init_image)
+        init_image = preprocess_image(init_image)
         init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
         init_latent_dist = self.vae.encode(init_image_torch).latent_dist
 
         # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
@@ -105,9 +113,18 @@ class RiffusionPipeline(DiffusionPipeline):
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
 
+        # Prepare mask latent
+        if mask_image:
+            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+            mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
+            mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
+        else:
+            mask = None
+
         outputs = self.interpolate_img2img(
             text_embeddings=text_embedding,
             init_latents=init_latents,
+            mask=mask,
             generator_a=generator_start,
             generator_b=generator_end,
             interpolate_alpha=alpha,
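The `vae_scale_factor` maps the pixel-space mask down to the latent grid. A quick check of the arithmetic, assuming the standard Stable Diffusion VAE config where `block_out_channels` has four entries:

```python
# Assumed SD 1.x VAE config values; only the length matters here
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8  # so a 512x512 mask becomes a 64x64 latent mask
```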
@@ -124,9 +141,10 @@ class RiffusionPipeline(DiffusionPipeline):
         self,
         text_embeddings: torch.FloatTensor,
         init_latents: torch.FloatTensor,
-        generator_a: T.Optional[torch.Generator],
-        generator_b: T.Optional[torch.Generator],
+        generator_a: torch.Generator,
+        generator_b: torch.Generator,
         interpolate_alpha: float,
+        mask: T.Optional[torch.FloatTensor] = None,
         strength_a: float = 0.8,
         strength_b: float = 0.8,
         num_inference_steps: T.Optional[int] = 50,
@@ -137,6 +155,9 @@ class RiffusionPipeline(DiffusionPipeline):
         output_type: T.Optional[str] = "pil",
         **kwargs,
     ):
+        """
+        TODO
+        """
         batch_size = text_embeddings.shape[0]
 
         # set timesteps
@@ -209,6 +230,7 @@ class RiffusionPipeline(DiffusionPipeline):
             init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
         )
         noise = slerp(interpolate_alpha, noise_a, noise_b)
+        init_latents_orig = init_latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -220,7 +242,7 @@ class RiffusionPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        latents = init_latents
+        latents = init_latents.clone()
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -250,6 +272,11 @@ class RiffusionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
+            # If masking, re-noise the original latents to the current timestep
+            # and keep them in the preserved (mask = 1) region
+            if mask is not None:
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
         latents = 1.0 / 0.18215 * latents
         image = self.vae.decode(latents).sample
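The new block inside the denoising loop is the standard latent-inpainting blend: the preserved region is re-noised from the original latents at the current timestep, so only the repaint region evolves. A toy sketch of the arithmetic with stand-in tensors (after `preprocess_mask` inverts the image, 1 means preserve and 0 means repaint):

```python
import torch

init_latents_proper = torch.zeros(1, 4, 64, 64)  # stands in for the re-noised original
latents = torch.ones(1, 4, 64, 64)               # stands in for the denoised latents
mask = torch.zeros(1, 4, 64, 64)
mask[..., :32] = 1.0                             # preserve the left half

blended = init_latents_proper * mask + latents * (1 - mask)
assert torch.equal(blended[..., :32], init_latents_proper[..., :32])  # preserved
assert torch.equal(blended[..., 32:], latents[..., 32:])              # repainted
```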
@@ -262,7 +289,7 @@ class RiffusionPipeline(DiffusionPipeline):
         return dict(images=image, latents=latents, nsfw_content_detected=False)
 
 
-def preprocess(image):
+def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
     """
     Preprocess an image for the model.
     """
@@ -275,6 +302,25 @@ def preprocess(image):
     return 2.0 * image - 1.0
 
 
+def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
+    """
+    Preprocess a mask for the model.
+    """
+    # Convert to a single channel (luminance)
+    mask = mask.convert("L")
+
+    # Shrink to an integer multiple of 32, then downsample to the latent resolution
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))
+    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
+
+    # Scale to [0, 1] and tile across the 4 latent channels
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))
+
+    # Add a batch dimension (this transpose is an identity permutation)
+    mask = mask[None].transpose(0, 1, 2, 3)
+
+    # Invert: repaint white, keep black
+    mask = 1 - mask
+
+    return torch.from_numpy(mask)
+
+
 def slerp(t, v0, v1, dot_threshold=0.9995):
     """
     Helper function to spherically interpolate two arrays v0 and v1.
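A quick shape check of `preprocess_mask` (a sketch, assuming the function is importable from the pipeline module): a 512×512 mask with `scale_factor=8` should come out as a `(1, 4, 64, 64)` tensor matching the latent layout, with values inverted so white becomes 0.

```python
import PIL.Image

from riffusion.riffusion_pipeline import preprocess_mask  # assumed module path

mask_img = PIL.Image.new("L", (512, 512), color=255)  # all white = repaint everything
mask = preprocess_mask(mask_img, scale_factor=8)
assert mask.shape == (1, 4, 64, 64)
assert mask.max().item() == 0.0  # white inverts to 0 after the 1 - mask step
```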


@@ -113,8 +113,17 @@ def run_inference():
         return f"Invalid seed image: {inputs.seed_image_id}", 400
     init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
 
+    # Load the mask image by ID
+    if inputs.mask_image_id:
+        mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
+        if not mask_image_path.is_file():
+            return f"Invalid mask image: {inputs.mask_image_id}", 400
+        mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
+    else:
+        mask_image = None
+
     # Execute the model to get the spectrogram image
-    image = MODEL.riffuse(inputs, init_image=init_image)
+    image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
 
     # Reconstruct audio from the image
     wav_bytes = wav_bytes_from_spectrogram_image(image)
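Because masks are resolved by ID in the same directory as seed images, using the feature requires a mask PNG in that directory first. A hedged sketch (the relative directory path and filename are assumptions):

```python
import PIL.Image

# Write an all-white mask (repaint everything) next to the seed images;
# "seed_images" is an assumed relative path for SEED_IMAGES_DIR
mask = PIL.Image.new("RGB", (512, 512), color=(255, 255, 255))
mask.save("seed_images/og_beat_mask.png")
```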

Binary image file added (7.1 KiB), not shown.

Binary image file changed (108 KiB before, 108 KiB after), not shown.