Cog integration
Support hosting on Replicate using the Cog interface. Continuation of: https://github.com/riffusion/riffusion/pull/26. Topic: cog_integration
This commit is contained in:
parent 4bec0a40e0
commit 45d55e986c
cog.yaml
@@ -0,0 +1,38 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  # set to true if your model requires a GPU
  gpu: true

  # a list of ubuntu apt packages to install
  system_packages:
    - "ffmpeg"
    - "libsndfile1"

  # python version in the form '3.8' or '3.8.12'
  python_version: "3.9"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "accelerate==0.15.0"
    - "argh==0.26.2"
    - "dacite==1.6.0"
    - "diffusers==0.10.2"
    - "flask_cors==3.0.10"
    - "flask==1.1.2"
    - "numpy==1.19.4"
    - "pillow==8.2.0"
    - "pydub==0.25.1"
    - "scipy==1.6.3"
    - "torch==1.13.0"
    - "torchaudio==0.13.0"
    - "transformers==4.25.1"

  # commands run after the environment is setup
  # run:
  #   - "echo env is ready!"
  #   - "echo another command if needed"

# predict.py defines how predictions are run on your model
predict: "integrations/cog_riffusion.py:RiffusionPredictor"
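With this config, Replicate (or a local Cog install) builds a container around the predictor defined below. As a minimal sketch of exercising it locally — assuming the image was built with `cog build`, is serving on Cog's default port 5000, and exposes Cog's standard `/predictions` HTTP endpoint — a prediction request could look like:

# Minimal sketch: POST a prediction to a locally running Cog container.
# Assumes the container was started on port 5000; input field names match
# the predict() signature below, values are its defaults.
import json
import urllib.request

payload = {
    "input": {
        "prompt_a": "funky synth solo",
        "denoising": 0.75,
        "num_inference_steps": 50,
        "seed_image_id": "vibes",
    }
}
req = urllib.request.Request(
    "http://localhost:5000/predictions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

# "output" mirrors the Output model below (audio and spectrogram),
# typically serialized as data URIs.
print(result["status"], list(result.get("output", {}).keys()))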
integrations/cog_riffusion.py
@@ -0,0 +1,181 @@
"""
Prediction interface for Cog ⚙️
https://github.com/replicate/cog/blob/main/docs/python.md
"""

import argparse
import os
import shutil
import typing as T

import numpy as np
import PIL
import torch
from cog import BaseModel, BasePredictor, Input, Path
from huggingface_hub import hf_hub_download

from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

MODEL_ID = "riffusion/riffusion-model-v1"
MODEL_CACHE = "riffusion-cache"
UNET_CACHE = "unet-cache"

# Where built-in seed images are stored
SEED_IMAGES_DIR = Path("./seed_images")
SEED_IMAGES = [val.split(".")[0] for val in os.listdir(SEED_IMAGES_DIR) if "png" in val]
SEED_IMAGES.sort()


class Output(BaseModel):
    """
    Output class for riffusion predictions
    """

    audio: Path
    spectrogram: Path
    error: T.Optional[str] = None


class RiffusionPredictor(BasePredictor):
    """
    Implementation of cog predictor object s.t. we can run riffusion predictions w/cog.

    See README & https://github.com/replicate/cog for details
    """

    def setup(self):
        """
        Loads the model onto GPU from local cache.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = RiffusionPipeline.load_checkpoint(
            checkpoint=MODEL_ID,
            use_traced_unet=True,
            device=self.device,
            local_files_only=True,
        )

    def predict(
        self,
        prompt_a: str = Input(description="The prompt for your audio", default="funky synth solo"),
        denoising: float = Input(
            description="How much to transform input spectrogram",
            default=0.75,
            ge=0,
            le=1,
        ),
        prompt_b: str = Input(
            description="The second prompt to interpolate with the first,"
            " leave blank if no interpolation",
            default=None,
        ),
        alpha: float = Input(
            description="Interpolation alpha if using two prompts."
            " A value of 0 uses prompt_a fully, a value of 1 uses prompt_b fully",
            default=0.5,
            ge=0,
            le=1,
        ),
        num_inference_steps: int = Input(
            description="Number of steps to run the diffusion model", default=50, ge=1
        ),
        seed_image_id: str = Input(
            description="Seed spectrogram to use", default="vibes", choices=SEED_IMAGES
        ),
    ) -> Output:
        """
        Runs riffusion inference.
        """
        # Load the seed image by ID
        init_image_path = Path(SEED_IMAGES_DIR, f"{seed_image_id}.png")
        if not init_image_path.is_file():
            return Output(error=f"Invalid seed image: {seed_image_id}")
        init_image = PIL.Image.open(str(init_image_path)).convert("RGB")

        # Random seeds in the int32 range
        seed_a = np.random.randint(0, 2147483647)
        seed_b = np.random.randint(0, 2147483647)

        start = PromptInput(prompt=prompt_a, seed=seed_a, denoising=denoising)
        if not prompt_b:  # no transition
            prompt_b = prompt_a
            alpha = 0
        end = PromptInput(prompt=prompt_b, seed=seed_b, denoising=denoising)
        riffusion_input = InferenceInput(
            start=start,
            end=end,
            alpha=alpha,
            num_inference_steps=num_inference_steps,
            seed_image_id=seed_image_id,
        )

        # Execute the model to get the spectrogram image
        image = self.model.riffuse(riffusion_input, init_image=init_image, mask_image=None)

        # Reconstruct audio from the image
        params = SpectrogramParams()
        converter = SpectrogramImageConverter(params=params, device=self.device)
        segment = converter.audio_from_spectrogram_image(image)

        if not os.path.exists("out/"):
            os.mkdir("out")

        out_img_path = "out/spectrogram.jpg"
        image.save(out_img_path, exif=image.getexif())

        out_wav_path = "out/gen_sound.wav"
        segment.export(out_wav_path, format="wav")

        return Output(audio=Path(out_wav_path), spectrogram=Path(out_img_path))


# TODO(hayk): Can we get rid of the below functions and incorporate into
# RiffusionPipeline.load_checkpoint?


def download_weights(checkpoint: str):
    """
    Clears local cache & downloads riffusion weights
    """
    for folder in [MODEL_CACHE, UNET_CACHE]:
        if os.path.exists(folder):
            shutil.rmtree(folder)
        os.makedirs(folder)

    model, unet_file = _load_model(checkpoint, local_only=False)
    return model, unet_file


def _load_model(checkpoint: str, local_only=False):
    model = RiffusionPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=torch.float16,
        # Disable the NSFW filter, causes incorrect false positives
        safety_checker=lambda images, **kwargs: (images, False),
        cache_dir=MODEL_CACHE,
        local_files_only=local_only,
    )

    unet_file = hf_hub_download(
        "riffusion/riffusion-model-v1",
        filename="unet_traced.pt",
        subfolder="unet_traced",
        cache_dir=UNET_CACHE,
        local_files_only=local_only,
    )
    return model, unet_file


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--download_weights", action="store_true", help="Download and cache weights"
    )
    args = parser.parse_args()
    if args.download_weights:
        download_weights(MODEL_ID)
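For debugging outside the Cog runtime, the predictor can also be driven directly. A minimal sketch, assuming it is run from the repo root with the riffusion weights already in the local cache (setup loads with local_files_only=True, so run with --download_weights first):

# Sketch: invoke the predictor directly, bypassing the Cog HTTP server.
# Assumes weights are already cached locally via download_weights().
from integrations.cog_riffusion import RiffusionPredictor

predictor = RiffusionPredictor()
predictor.setup()  # loads the pipeline onto GPU, or CPU if CUDA is unavailable
output = predictor.predict(
    prompt_a="funky synth solo",
    denoising=0.75,
    prompt_b=None,  # no interpolation; alpha is forced to 0 internally
    alpha=0.5,
    num_inference_steps=50,
    seed_image_id="vibes",
)
print(output.audio, output.spectrogram)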
pyproject.toml
@@ -58,6 +58,10 @@ python_version = "3.10"
 module = "argh.*"
 ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "cog.*"
+ignore_missing_imports = true
 
 [[tool.mypy.overrides]]
 module = "diffusers.*"
 ignore_missing_imports = true
riffusion/riffusion_pipeline.py
@@ -68,6 +68,8 @@ class RiffusionPipeline(DiffusionPipeline):
         channels_last: bool = False,
         dtype: torch.dtype = torch.float16,
         device: str = "cuda",
+        local_files_only: bool = False,
+        low_cpu_mem_usage: bool = False,
     ) -> RiffusionPipeline:
         """
         Load the riffusion model pipeline.
@@ -77,6 +79,8 @@ class RiffusionPipeline(DiffusionPipeline):
             use_traced_unet: Whether to use the traced unet for speedups
             device: Device to load the model on
             channels_last: Whether to use channels_last memory format
+            local_files_only: Don't download, only use local files
+            low_cpu_mem_usage: Attempt to use less memory on CPU
         """
         device = torch_util.check_device(device)
@@ -91,8 +95,8 @@ class RiffusionPipeline(DiffusionPipeline):
             # Disable the NSFW filter, causes incorrect false positives
             # TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
             safety_checker=lambda images, **kwargs: (images, False),
             # Optionally attempt to use less memory
-            low_cpu_mem_usage=False,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            local_files_only=local_files_only,
         ).to(device)
 
         if channels_last:
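The two new flags thread straight through to the diffusers loader. A hedged sketch of calling the updated load_checkpoint (argument names as in this diff; values illustrative):

# Sketch: load entirely from the local HuggingFace cache, with reduced CPU
# memory during initialization. The last two keyword arguments are the
# parameters added in this commit; the rest already existed.
from riffusion.riffusion_pipeline import RiffusionPipeline

pipeline = RiffusionPipeline.load_checkpoint(
    checkpoint="riffusion/riffusion-model-v1",
    use_traced_unet=True,
    device="cuda",
    local_files_only=True,   # never hit the network; fail if not cached
    low_cpu_mem_usage=True,  # ask diffusers to materialize weights lazily
)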