Support masks
parent 51034da7bb
commit 6e69c94277
@@ -28,7 +28,7 @@ Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-infere
 {
     alpha: 0.75,
     num_inference_steps: 50,
-    seed_image_id: 0,
+    seed_image_id: "og_beat",

     start: {
         prompt: "church bells on sunday",
@@ -53,4 +53,3 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
     audio: "< base64 encoded MP3 clip >",
 }
 ```
-
@@ -1,4 +1,5 @@
 black
+ipdb
 isort
 mypy
 pylint
@@ -3,6 +3,7 @@ Data model for the riffusion API.
 """
 from dataclasses import dataclass
+import typing as T


 @dataclass
@@ -46,8 +47,10 @@ class InferenceInput:
     num_inference_steps: int = 50

     # Which seed image to use
-    # TODO(hayk): Convert this to a string ID and add a seed image + mask API.
-    seed_image_id: int = 0
+    seed_image_id: str = "og_beat"
+
+    # ID of mask image to use
+    mask_image_id: T.Optional[str] = None


 @dataclass
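For illustration, a minimal sketch of how the new optional field behaves. The stand-in dataclass below keeps only the fields this hunk touches (the real `InferenceInput` also carries `alpha` and the start/end prompt inputs), and `"my_mask"` is a hypothetical ID:

```python
import typing as T
from dataclasses import dataclass


# Stand-in mirroring just the fields this commit changes; not the full class.
@dataclass
class InferenceInput:
    num_inference_steps: int = 50
    seed_image_id: str = "og_beat"         # was: seed_image_id: int = 0
    mask_image_id: T.Optional[str] = None  # new: no mask by default


no_mask = InferenceInput()
masked = InferenceInput(mask_image_id="my_mask")  # hypothetical mask ID
assert no_mask.mask_image_id is None and masked.mask_image_id == "my_mask"
```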
@@ -78,9 +78,17 @@ class RiffusionPipeline(DiffusionPipeline):
         self,
         inputs: InferenceInput,
         init_image: PIL.Image.Image,
+        mask_image: T.Optional[PIL.Image.Image] = None,
     ) -> PIL.Image.Image:
         """
         Runs inference using interpolation with both img2img and text conditioning.
+
+        Args:
+            inputs: Parameter dataclass
+            init_image: Image used for conditioning
+            mask_image: White pixels in the mask will be replaced by noise and therefore
+                repainted, while black pixels will be preserved. It will be converted to
+                a single channel (luminance) before use.
         """
         alpha = inputs.alpha
         start = inputs.start
@@ -96,7 +104,7 @@ class RiffusionPipeline(DiffusionPipeline):
         text_embedding = torch.lerp(embed_start, embed_end, alpha)

         # Image latents
-        init_image = preprocess(init_image)
+        init_image = preprocess_image(init_image)
         init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
         init_latent_dist = self.vae.encode(init_image_torch).latent_dist
         # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
@@ -105,9 +113,18 @@ class RiffusionPipeline(DiffusionPipeline):
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents

+        # Prepare the mask latent, matching the VAE's spatial downsampling
+        if mask_image is not None:
+            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+            mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
+            mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
+        else:
+            mask = None
+
         outputs = self.interpolate_img2img(
             text_embeddings=text_embedding,
             init_latents=init_latents,
+            mask=mask,
             generator_a=generator_start,
             generator_b=generator_end,
             interpolate_alpha=alpha,
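For context, the scale factor recovers the VAE's total spatial downsampling from its config. A small sketch, assuming the standard Stable Diffusion v1 VAE channel list (an assumption about the checkpoint, not stated in this diff):

```python
# How vae_scale_factor falls out of the VAE config; block_out_channels is the
# assumed standard Stable Diffusion v1 VAE value, not shown in this diff.
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8 -> a 512x512 mask is resized to 64x64 latents
```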
@@ -124,9 +141,10 @@ class RiffusionPipeline(DiffusionPipeline):
         self,
         text_embeddings: torch.FloatTensor,
         init_latents: torch.FloatTensor,
-        generator_a: T.Optional[torch.Generator],
-        generator_b: T.Optional[torch.Generator],
+        generator_a: torch.Generator,
+        generator_b: torch.Generator,
         interpolate_alpha: float,
+        mask: T.Optional[torch.FloatTensor] = None,
         strength_a: float = 0.8,
         strength_b: float = 0.8,
         num_inference_steps: T.Optional[int] = 50,
@@ -137,6 +155,9 @@ class RiffusionPipeline(DiffusionPipeline):
         output_type: T.Optional[str] = "pil",
         **kwargs,
     ):
+        """
+        TODO
+        """
         batch_size = text_embeddings.shape[0]

         # set timesteps
@@ -209,6 +230,7 @@ class RiffusionPipeline(DiffusionPipeline):
             init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
         )
         noise = slerp(interpolate_alpha, noise_a, noise_b)
+        init_latents_orig = init_latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -220,7 +242,7 @@ class RiffusionPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

-        latents = init_latents
+        latents = init_latents.clone()

         t_start = max(num_inference_steps - init_timestep + offset, 0)
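These two changes go together: `init_latents_orig` keeps a handle on the clean latents for the per-step blend below, and the `.clone()` makes the working copy independent of them. A toy sketch of why the copy matters (plain torch, no scheduler):

```python
import torch

# Why the defensive copy matters: in-place updates to the working copy must
# not leak back into the saved originals used for the masked blend.
init_latents = torch.zeros(4)
init_latents_orig = init_latents    # reference to the clean latents
latents = init_latents.clone()      # independent working copy
latents += 1                        # in-place update during denoising
assert torch.equal(init_latents_orig, torch.zeros(4))  # originals untouched
```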
@@ -250,6 +272,11 @@ class RiffusionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

+            if mask is not None:
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                # import ipdb; ipdb.set_trace()
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
         latents = 1.0 / 0.18215 * latents
         image = self.vae.decode(latents).sample

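This is the standard latent-inpainting trick: after each denoising step, the original latents are re-noised to the current timestep and kept wherever the mask is 1, so only the zero regions of the mask (white in the input image, after the inversion in `preprocess_mask`) get repainted. A toy illustration of the blend with dummy tensors:

```python
import torch

# Per-step blend: mask == 1 keeps the re-noised original latents,
# mask == 0 keeps the freshly denoised latents.
init_latents_proper = torch.zeros(1, 4, 64, 64)  # stands in for re-noised originals
latents = torch.ones(1, 4, 64, 64)               # stands in for denoised latents
mask = torch.zeros(1, 4, 64, 64)
mask[..., :32] = 1.0                             # preserve the left half

latents = (init_latents_proper * mask) + (latents * (1 - mask))
assert latents[..., :32].eq(0).all() and latents[..., 32:].eq(1).all()
```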
@@ -262,7 +289,7 @@ class RiffusionPipeline(DiffusionPipeline):
         return dict(images=image, latents=latents, nsfw_content_detected=False)


-def preprocess(image):
+def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
     """
     Preprocess an image for the model.
     """
@@ -275,6 +302,25 @@ def preprocess(image):
     return 2.0 * image - 1.0


+def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
+    """
+    Preprocess a mask for the model.
+    """
+    mask = mask.convert("L")
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # round down to a multiple of 32
+    mask = mask.resize(
+        (w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST
+    )
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))  # tile across the 4 latent channels
+    mask = mask[None]  # add a batch dimension -> (1, 4, h, w)
+    mask = 1 - mask  # repaint white, keep black
+    mask = torch.from_numpy(mask)
+
+    return mask
+
+
 def slerp(t, v0, v1, dot_threshold=0.9995):
     """
     Helper function to spherically interpolate two arrays v0 and v1.
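A quick shape check of `preprocess_mask` as defined above, assuming a 512x512 input mask and the scale factor of 8 derived earlier:

```python
import PIL.Image
import torch

# An all-white 512x512 mask: white means "repaint", so after the 1 - mask
# inversion every latent element should be 0.
mask_img = PIL.Image.new("L", (512, 512), color=255)
mask = preprocess_mask(mask_img, scale_factor=8)  # the function defined above
assert mask.shape == (1, 4, 64, 64)
assert torch.all(mask == 0)
```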
@@ -113,8 +113,17 @@ def run_inference():
         return f"Invalid seed image: {inputs.seed_image_id}", 400
     init_image = PIL.Image.open(str(init_image_path)).convert("RGB")

+    # Load the mask image by ID
+    if inputs.mask_image_id:
+        mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
+        if not mask_image_path.is_file():
+            return f"Invalid mask image: {inputs.mask_image_id}", 400
+        mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
+    else:
+        mask_image = None
+
     # Execute the model to get the spectrogram image
-    image = MODEL.riffuse(inputs, init_image=init_image)
+    image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)

     # Reconstruct audio from the image
     wav_bytes = wav_bytes_from_spectrogram_image(image)
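End to end, a client opts in by adding `mask_image_id` to the request body. A hypothetical call against a locally running server; the `/run_inference/` path and port are assumptions, and the payload is abridged to the fields visible in this diff and the README example:

```python
import json
import urllib.request

# Hypothetical request; endpoint path and port are assumptions. Only
# mask_image_id is new relative to the README's example input.
payload = {
    "alpha": 0.75,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "mask_image_id": "my_mask",  # hypothetical ID of a PNG in SEED_IMAGES_DIR
    "start": {"prompt": "church bells on sunday"},
}
req = urllib.request.Request(
    "http://127.0.0.1:3013/run_inference/",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["audio"][:64])  # base64-encoded MP3 clip
```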
(Two binary image files changed and are not shown: one new image added, 7.1 KiB; one existing image modified, 108 KiB before and after.)