Cog integration

Support hosting on Replicate using the cog interface.

Continuation of: https://github.com/riffusion/riffusion/pull/26

Topic: cog_integration
This commit is contained in:
Hayk Martiros 2022-12-29 12:21:07 -08:00
parent 4bec0a40e0
commit 45d55e986c
4 changed files with 229 additions and 2 deletions

38
cog.yaml Normal file
View File

@ -0,0 +1,38 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
build:
# set to true if your model requires a GPU
gpu: true
# a list of ubuntu apt packages to install
system_packages:
- "ffmpeg"
- "libsndfile1"
# python version in the form '3.8' or '3.8.12'
python_version: "3.9"
# a list of packages in the format <package-name>==<version>
python_packages:
- "accelerate==0.15.0"
- "argh==0.26.2"
- "dacite==1.6.0"
- "diffusers==0.10.2"
- "flask_cors==3.0.10"
- "flask==1.1.2"
- "numpy==1.19.4"
- "pillow==8.2.0"
- "pydub==0.25.1"
- "scipy==1.6.3"
- "torch==1.13.0"
- "torchaudio==0.13.0"
- "transformers==4.25.1"
# commands run after the environment is setup
# run:
# - "echo env is ready!"
# - "echo another command if needed"
# predict.py defines how predictions are run on your model
predict: "integrations/cog_riffusion.py:RiffusionPredictor"

View File

@ -0,0 +1,181 @@
"""
Prediction interface for Cog
https://github.com/replicate/cog/blob/main/docs/python.md
"""
import argparse
import os
import shutil
import typing as T
import numpy as np
import PIL
import torch
from cog import BaseModel, BasePredictor, Input, Path
from huggingface_hub import hf_hub_download
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
MODEL_ID = "riffusion/riffusion-model-v1"
MODEL_CACHE = "riffusion-cache"
UNET_CACHE = "unet-cache"
# Where built-in seed images are stored
SEED_IMAGES_DIR = Path("./seed_images")
SEED_IMAGES = [val.split(".")[0] for val in os.listdir(SEED_IMAGES_DIR) if "png" in val]
SEED_IMAGES.sort()
class Output(BaseModel):
"""
Output class for riffusion predictions
"""
audio: Path
spectrogram: Path
error: T.Optional[str] = None
class RiffusionPredictor(BasePredictor):
"""
Implementation of cog predictor object s.t. we can run riffusion predictions w/cog.
See README & https://github.com/replicate/cog for details
"""
def setup(self):
"""
Loads the model onto GPU from local cache.
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = RiffusionPipeline.load_checkpoint(
checkpoint=MODEL_ID,
use_traced_unet=True,
device=self.device,
local_files_only=True,
)
def predict(
self,
prompt_a: str = Input(description="The prompt for your audio", default="funky synth solo"),
denoising: float = Input(
description="How much to transform input spectrogram",
default=0.75,
ge=0,
le=1,
),
prompt_b: str = Input(
description="The second prompt to interpolate with the first,"
"leave blank if no interpolation",
default=None,
),
alpha: float = Input(
description="Interpolation alpha if using two prompts."
"A value of 0 uses prompt_a fully, a value of 1 uses prompt_b fully",
default=0.5,
ge=0,
le=1,
),
num_inference_steps: int = Input(
description="Number of steps to run the diffusion model", default=50, ge=1
),
seed_image_id: str = Input(
description="Seed spectrogram to use", default="vibes", choices=SEED_IMAGES
),
) -> Output:
"""
Runs riffusion inference.
"""
# Load the seed image by ID
init_image_path = Path(SEED_IMAGES_DIR, f"{seed_image_id}.png")
if not init_image_path.is_file():
return Output(error=f"Invalid seed image: {seed_image_id}")
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# fake max ints
seed_a = np.random.randint(0, 2147483647)
seed_b = np.random.randint(0, 2147483647)
start = PromptInput(prompt=prompt_a, seed=seed_a, denoising=denoising)
if not prompt_b: # no transition
prompt_b = prompt_a
alpha = 0
end = PromptInput(prompt=prompt_b, seed=seed_b, denoising=denoising)
riffusion_input = InferenceInput(
start=start,
end=end,
alpha=alpha,
num_inference_steps=num_inference_steps,
seed_image_id=seed_image_id,
)
# Execute the model to get the spectrogram image
image = self.model.riffuse(riffusion_input, init_image=init_image, mask_image=None)
# Reconstruct audio from the image
params = SpectrogramParams()
converter = SpectrogramImageConverter(params=params, device=self.device)
segment = converter.audio_from_spectrogram_image(image)
if not os.path.exists("out/"):
os.mkdir("out")
out_img_path = "out/spectrogram.jpg"
image.save("out/spectrogram.jpg", exif=image.getexif())
out_wav_path = "out/gen_sound.wav"
segment.export(out_wav_path, format="wav")
return Output(audio=Path(out_wav_path), spectrogram=Path(out_img_path))
# TODO(hayk): Can we get rid of the below functions and incorporate into
# RiffusionPipeline.load_checkpoint?
def download_weights(checkpoint: str):
"""
Clears local cache & downloads riffusion weights
"""
for folder in [MODEL_CACHE, UNET_CACHE]:
if os.path.exists(folder):
shutil.rmtree(folder)
os.makedirs(folder)
model, unet_file = _load_model(checkpoint, local_only=False)
return model, unet_file
def _load_model(checkpoint: str, local_only=False):
model = RiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
cache_dir=MODEL_CACHE,
local_files_only=local_only,
)
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1",
filename="unet_traced.pt",
subfolder="unet_traced",
cache_dir=UNET_CACHE,
local_files_only=local_only,
)
return model, unet_file
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--download_weights", action="store_true", help="Download and cache weights"
)
args = parser.parse_args()
if args.download_weights:
download_weights(MODEL_ID)

View File

@ -58,6 +58,10 @@ python_version = "3.10"
module = "argh.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "cog.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "diffusers.*"
ignore_missing_imports = true

View File

@ -68,6 +68,8 @@ class RiffusionPipeline(DiffusionPipeline):
channels_last: bool = False,
dtype: torch.dtype = torch.float16,
device: str = "cuda",
local_files_only: bool = False,
low_cpu_mem_usage: bool = False,
) -> RiffusionPipeline:
"""
Load the riffusion model pipeline.
@ -77,6 +79,8 @@ class RiffusionPipeline(DiffusionPipeline):
use_traced_unet: Whether to use the traced unet for speedups
device: Device to load the model on
channels_last: Whether to use channels_last memory format
local_files_only: Don't download, only use local files
low_cpu_mem_usage: Attempt to use less memory on CPU
"""
device = torch_util.check_device(device)
@ -91,8 +95,8 @@ class RiffusionPipeline(DiffusionPipeline):
# Disable the NSFW filter, causes incorrect false positives
# TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
safety_checker=lambda images, **kwargs: (images, False),
# Optionally attempt to use less memory
low_cpu_mem_usage=False,
low_cpu_mem_usage=low_cpu_mem_usage,
local_files_only=local_files_only,
).to(device)
if channels_last: