159 lines
4.8 KiB
Python
159 lines
4.8 KiB
Python
"""
|
|
Prediction interface for Cog ⚙️
|
|
https://github.com/replicate/cog/blob/main/docs/python.md
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import shutil
|
|
import typing as T
|
|
|
|
import numpy as np
|
|
import PIL
|
|
import torch
|
|
from cog import BaseModel, BasePredictor, Input, Path
|
|
|
|
from riffusion.datatypes import InferenceInput, PromptInput
|
|
from riffusion.riffusion_pipeline import RiffusionPipeline
|
|
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
|
from riffusion.spectrogram_params import SpectrogramParams
|
|
|
|
MODEL_ID = "riffusion/riffusion-model-v1"
|
|
MODEL_CACHE = "riffusion-cache"
|
|
|
|
# Where built-in seed images are stored
|
|
SEED_IMAGES_DIR = Path("./seed_images")
|
|
SEED_IMAGES = [val.split(".")[0] for val in os.listdir(SEED_IMAGES_DIR) if "png" in val]
|
|
SEED_IMAGES.sort()
|
|
|
|
|
|
class Output(BaseModel):
|
|
"""
|
|
Output class for riffusion predictions
|
|
"""
|
|
|
|
audio: Path
|
|
spectrogram: Path
|
|
error: T.Optional[str] = None
|
|
|
|
|
|
class RiffusionPredictor(BasePredictor):
|
|
"""
|
|
Implementation of cog predictor object s.t. we can run riffusion predictions w/cog.
|
|
|
|
See README & https://github.com/replicate/cog for details
|
|
"""
|
|
|
|
def setup(self, local_files_only=True):
|
|
"""
|
|
Loads the model onto GPU from local cache.
|
|
"""
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
self.model = RiffusionPipeline.load_checkpoint(
|
|
checkpoint=MODEL_ID,
|
|
use_traced_unet=True,
|
|
device=self.device,
|
|
local_files_only=local_files_only,
|
|
cache_dir=MODEL_CACHE,
|
|
)
|
|
|
|
def predict(
|
|
self,
|
|
prompt_a: str = Input(description="The prompt for your audio", default="funky synth solo"),
|
|
denoising: float = Input(
|
|
description="How much to transform input spectrogram",
|
|
default=0.75,
|
|
ge=0,
|
|
le=1,
|
|
),
|
|
prompt_b: str = Input(
|
|
description="The second prompt to interpolate with the first,"
|
|
"leave blank if no interpolation",
|
|
default=None,
|
|
),
|
|
alpha: float = Input(
|
|
description="Interpolation alpha if using two prompts."
|
|
"A value of 0 uses prompt_a fully, a value of 1 uses prompt_b fully",
|
|
default=0.5,
|
|
ge=0,
|
|
le=1,
|
|
),
|
|
num_inference_steps: int = Input(
|
|
description="Number of steps to run the diffusion model", default=50, ge=1
|
|
),
|
|
seed_image_id: str = Input(
|
|
description="Seed spectrogram to use", default="vibes", choices=SEED_IMAGES
|
|
),
|
|
) -> Output:
|
|
"""
|
|
Runs riffusion inference.
|
|
"""
|
|
# Load the seed image by ID
|
|
init_image_path = Path(SEED_IMAGES_DIR, f"{seed_image_id}.png")
|
|
if not init_image_path.is_file():
|
|
return Output(error=f"Invalid seed image: {seed_image_id}")
|
|
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
|
|
|
# fake max ints
|
|
seed_a = np.random.randint(0, 2147483647)
|
|
seed_b = np.random.randint(0, 2147483647)
|
|
|
|
start = PromptInput(prompt=prompt_a, seed=seed_a, denoising=denoising)
|
|
if not prompt_b: # no transition
|
|
prompt_b = prompt_a
|
|
alpha = 0
|
|
end = PromptInput(prompt=prompt_b, seed=seed_b, denoising=denoising)
|
|
riffusion_input = InferenceInput(
|
|
start=start,
|
|
end=end,
|
|
alpha=alpha,
|
|
num_inference_steps=num_inference_steps,
|
|
seed_image_id=seed_image_id,
|
|
)
|
|
|
|
# Execute the model to get the spectrogram image
|
|
image = self.model.riffuse(riffusion_input, init_image=init_image, mask_image=None)
|
|
|
|
# Reconstruct audio from the image
|
|
params = SpectrogramParams()
|
|
converter = SpectrogramImageConverter(params=params, device=self.device)
|
|
segment = converter.audio_from_spectrogram_image(image)
|
|
|
|
if not os.path.exists("out/"):
|
|
os.mkdir("out")
|
|
|
|
out_img_path = "out/spectrogram.jpg"
|
|
image.save("out/spectrogram.jpg", exif=image.getexif())
|
|
|
|
out_wav_path = "out/gen_sound.wav"
|
|
segment.export(out_wav_path, format="wav")
|
|
|
|
return Output(audio=Path(out_wav_path), spectrogram=Path(out_img_path))
|
|
|
|
|
|
# TODO(hayk): Can we get rid of the below functions and incorporate into
|
|
# RiffusionPipeline.load_checkpoint?
|
|
|
|
|
|
def download_weights():
|
|
"""
|
|
Clears local cache & downloads riffusion weights
|
|
"""
|
|
if os.path.exists(MODEL_CACHE):
|
|
shutil.rmtree(MODEL_CACHE)
|
|
os.makedirs(MODEL_CACHE)
|
|
|
|
pred = RiffusionPredictor()
|
|
pred.setup(local_files_only=False)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--download_weights", action="store_true", help="Download and cache weights"
|
|
)
|
|
args = parser.parse_args()
|
|
if args.download_weights:
|
|
download_weights()
|