diff --git a/integrations/README.md b/integrations/README.md
new file mode 100644
index 0000000..7193e54
--- /dev/null
+++ b/integrations/README.md
@@ -0,0 +1,3 @@
+# Integrations
+
+This package contains integrations of Riffusion into third-party apps and deployments.
diff --git a/integrations/__init__.py b/integrations/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/integrations/baseten.py b/integrations/baseten.py
new file mode 100644
index 0000000..18af805
--- /dev/null
+++ b/integrations/baseten.py
@@ -0,0 +1,84 @@
+"""
+This file can be used to build a Truss for deployment with Baseten.
+If used, it should be renamed to model.py and placed alongside the other
+files from /riffusion in the standard /model directory of the Truss.
+
+For more on the Truss file format, see https://truss.baseten.co/
+"""
+
+import typing as T
+
+import torch
+import dacite
+
+from huggingface_hub import snapshot_download
+
+from riffusion.riffusion_pipeline import RiffusionPipeline
+from riffusion.server import compute_request
+from riffusion.datatypes import InferenceInput
+
+
+class Model:
+    """
+    Baseten Truss model class for riffusion.
+
+    See: https://truss.baseten.co/reference/structure#model.py
+    """
+
+    def __init__(self, **kwargs) -> None:
+        self._data_dir = kwargs["data_dir"]
+        self._config = kwargs["config"]
+        self._pipeline = None
+        self._vae = None
+
+        self.checkpoint_name = "riffusion/riffusion-model-v1"
+
+        # Download the entire seed image folder from the Hugging Face hub
+        self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
+
+    def load(self):
+        """
+        Load the model. Guaranteed to be called before `predict`.
+        """
+        self._pipeline = RiffusionPipeline.load_checkpoint(
+            checkpoint=self.checkpoint_name,
+            use_traced_unet=True,
+            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        )
+
+    def preprocess(self, request: T.Dict) -> T.Dict:
+        """
+        Incorporate pre-processing required by the model if desired here.
+
+        These might be feature transformations that are tightly coupled to the model.
+        """
+        return request
+
+    def predict(self, request: T.Dict) -> T.Union[str, T.Tuple[str, int]]:
+        """
+        This is the main function that is called.
+        """
+        assert self._pipeline is not None, "Model pipeline not loaded"
+
+        try:
+            inputs = dacite.from_dict(InferenceInput, request)
+        except dacite.exceptions.WrongTypeError as exception:
+            return str(exception), 400
+        except dacite.exceptions.MissingValueError as exception:
+            return str(exception), 400
+
+        # NOTE: Autocast is disabled to speed up inference; with it enabled, inference took ~10s on a T4
+        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
+            response = compute_request(
+                inputs=inputs,
+                pipeline=self._pipeline,
+                seed_images_dir=self._seed_images_dir,
+            )
+
+        return response
+
+    def postprocess(self, request: T.Dict) -> T.Dict:
+        """
+        Incorporate post-processing required by the model if desired here.
+        """
+        return request
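
A note on the `with` statement in `predict`: the pre-refactor version (in the deleted `riffusion/baseten.py` below) wrote `with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):`, which evaluates the `and` expression first and therefore enters only one of the two context managers; the comma form above enters both. A self-contained demonstration of the difference, in plain Python:

```python
import contextlib

@contextlib.contextmanager
def ctx(name: str):
    print(f"enter {name}")
    yield
    print(f"exit {name}")

# `ctx("a") and ctx("b")` evaluates to the ctx("b") object, so only "b" is entered.
with ctx("a") and ctx("b"):
    pass  # prints: enter b / exit b

# The comma form enters (and exits) both contexts.
with ctx("a"), ctx("b"):
    pass  # prints: enter a / enter b / exit b / exit a
```
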
diff --git a/riffusion/baseten.py b/riffusion/baseten.py
deleted file mode 100644
index db99dc8..0000000
--- a/riffusion/baseten.py
+++ /dev/null
@@ -1,175 +0,0 @@
-"""
-This file can be used to build a Truss for deployment with Baseten.
-If used, it should be renamed to model.py and placed alongside the other
-files from /riffusion in the standard /model directory of the Truss.
-
-For more on the Truss file format, see https://truss.baseten.co/
-"""
-
-import base64
-import dataclasses
-import json
-import io
-from pathlib import Path
-from typing import Dict, List
-
-import PIL
-import torch
-import dacite
-
-from huggingface_hub import hf_hub_download, snapshot_download
-
-from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
-from .datatypes import InferenceInput, InferenceOutput
-from .riffusion_pipeline import RiffusionPipeline
-
-
-class Model:
-    def __init__(self, **kwargs) -> None:
-        self._data_dir = kwargs["data_dir"]
-        self._config = kwargs["config"]
-        self._model = None
-        self._vae = None
-
-        # Download entire seed image folder from huggingface hub
-        self._seed_images_dir = snapshot_download(
-            "riffusion/riffusion-model-v1", allow_patterns="*.png"
-        )
-
-    def load(self):
-        # Load Riffusion model here and assign to self._model.
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        if torch.cuda.is_available() == False:
-            # Use only if you don't have a GPU with fp16 support
-            self._model = RiffusionPipeline.from_pretrained(
-                "riffusion/riffusion-model-v1",
-                safety_checker=lambda images, **kwargs: (images, False),
-            ).to(device)
-        else:
-            # Model loading the model with fp16. This will fail if ran without a GPU with fp16 support
-            pipe = RiffusionPipeline.from_pretrained(
-                "riffusion/riffusion-model-v1",
-                revision="fp16",
-                torch_dtype=torch.float16,
-                # Disable the NSFW filter, causes incorrect false positives
-                safety_checker=lambda images, **kwargs: (images, False),
-            ).to(device)
-
-            # Deliberately not implementing channels_Last as it resulted in slower inference pipeline
-            # pipe.unet.to(memory_format=torch.channels_last)
-
-            @dataclasses.dataclass
-            class UNet2DConditionOutput:
-                sample: torch.FloatTensor
-
-            # Use traced unet from hf hub
-            unet_file = hf_hub_download(
-                "riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
-            )
-            unet_traced = torch.jit.load(unet_file)
-
-            class TracedUNet(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.in_channels = pipe.unet.in_channels
-                    self.device = pipe.unet.device
-
-                def forward(self, latent_model_input, t, encoder_hidden_states):
-                    sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
-                    return UNet2DConditionOutput(sample=sample)
-
-            pipe.unet = TracedUNet()
-
-            self._model = pipe
-
-    def preprocess(self, request: Dict) -> Dict:
-        """
-        Incorporate pre-processing required by the model if desired here.
-
-        These might be feature transformations that are tightly coupled to the model.
-        """
-        return request
-
-    def postprocess(self, request: Dict) -> Dict:
-        """
-        Incorporate post-processing required by the model if desired here.
-        """
-        return request
-
-    def predict(self, request: Dict) -> Dict[str, List]:
-        """
-        This is the main function that is called.
- """ - # Example request: - # {"alpha":0.25,"num_inference_steps":50,"seed_image_id":"og_beat","mask_image_id":None,"start":{"prompt":"lo-fi beat for the holidays","seed":906295,"denoising":0.75,"guidance":7},"end":{"prompt":"lo-fi beat for the holidays","seed":906296,"denoising":0.75,"guidance":7}} - - # Parse an InferenceInput dataclass from the payload - try: - inputs = dacite.from_dict(InferenceInput, request) - except dacite.exceptions.WrongTypeError as exception: - # logging.info(json_data) - return str(exception), 400 - except dacite.exceptions.MissingValueError as exception: - # logging.info(json_data) - return str(exception), 400 - - # NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4 - with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False): - response = self.compute(inputs) - - return response - - def compute(self, inputs: InferenceInput) -> str: - """ - Does all the heavy lifting of the request. - """ - # Load the seed image by ID - init_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.seed_image_id}.png") - - if not init_image_path.is_file(): - return f"Invalid seed image: {inputs.seed_image_id}", 400 - init_image = PIL.Image.open(str(init_image_path)).convert("RGB") - - # Load the mask image by ID - if inputs.mask_image_id: - mask_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.mask_image_id}.png") - if not mask_image_path.is_file(): - return f"Invalid mask image: {inputs.mask_image_id}", 400 - mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB") - else: - mask_image = None - - # Execute the model to get the spectrogram image - image = self._model.riffuse(inputs, init_image=init_image, mask_image=mask_image) - - # Reconstruct audio from the image - wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image) - mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes) - - # Compute the output as base64 encoded strings - image_bytes = self.image_bytes_from_image(image, mode="JPEG") - - # Assemble the output dataclass - output = InferenceOutput( - image="data:image/jpeg;base64," + self.base64_encode(image_bytes), - audio="data:audio/mpeg;base64," + self.base64_encode(mp3_bytes), - duration_s=duration_s, - ) - - return json.dumps(dataclasses.asdict(output)) - - def image_bytes_from_image(self, image: PIL.Image, mode: str = "PNG") -> io.BytesIO: - """ - Convert a PIL image into bytes of the given image format. - """ - image_bytes = io.BytesIO() - image.save(image_bytes, mode) - image_bytes.seek(0) - return image_bytes - - def base64_encode(self, buffer: io.BytesIO) -> str: - """ - Encode the given buffer as base64. - """ - return base64.encodebytes(buffer.getvalue()).decode("ascii") diff --git a/riffusion/datatypes.py b/riffusion/datatypes.py index 6a8cdb8..4f10557 100644 --- a/riffusion/datatypes.py +++ b/riffusion/datatypes.py @@ -1,6 +1,7 @@ """ Data model for the riffusion API. """ +from __future__ import annotations from dataclasses import dataclass import typing as T @@ -58,6 +59,7 @@ class InferenceOutput: """ Response from the model inference server. """ + # base64 encoded spectrogram image as a JPEG image: str diff --git a/riffusion/server.py b/riffusion/server.py index 8a8bc5c..88d0d3d 100644 --- a/riffusion/server.py +++ b/riffusion/server.py @@ -1,8 +1,7 @@ """ -Inference server for the riffusion project. +Flask server that serves the riffusion model as an API. 
""" -import base64 import dataclasses import logging import io @@ -16,15 +15,13 @@ import flask from flask_cors import CORS import PIL -import torch -from huggingface_hub import hf_hub_download - -from .audio import wav_bytes_from_spectrogram_image -from .audio import mp3_bytes_from_wav_bytes -from .datatypes import InferenceInput -from .datatypes import InferenceOutput -from .riffusion_pipeline import RiffusionPipeline +from riffusion.datatypes import InferenceInput +from riffusion.datatypes import InferenceOutput +from riffusion.riffusion_pipeline import RiffusionPipeline +from riffusion.spectrogram_image_converter import SpectrogramImageConverter +from riffusion.spectrogram_params import SpectrogramParams +from riffusion.util import base64_util # Flask app with CORS app = flask.Flask(__name__) @@ -35,7 +32,7 @@ logging.basicConfig(level=logging.INFO) logging.getLogger().addHandler(logging.FileHandler("server.log")) # Global variable for the model pipeline -MODEL = None +PIPELINE: T.Optional[RiffusionPipeline] = None # Where built-in seed images are stored SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images") @@ -45,8 +42,9 @@ def run_app( *, checkpoint: str = "riffusion/riffusion-model-v1", no_traced_unet: bool = False, + device: str = "cuda", host: str = "127.0.0.1", - port: int = 3000, + port: int = 3013, debug: bool = False, ssl_certificate: T.Optional[str] = None, ssl_key: T.Optional[str] = None, @@ -55,8 +53,12 @@ def run_app( Run a flask API that serves the given riffusion model checkpoint. """ # Initialize the model - global MODEL - MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet) + global PIPELINE + PIPELINE = RiffusionPipeline.load_checkpoint( + checkpoint=checkpoint, + use_traced_unet=not no_traced_unet, + device=device, + ) args = dict( debug=debug, @@ -69,51 +71,7 @@ def run_app( assert ssl_key is not None args["ssl_context"] = (ssl_certificate, ssl_key) - app.run(**args) - - -def load_model(checkpoint: str, traced_unet: bool = True): - """ - Load the riffusion model pipeline. 
- """ - assert torch.cuda.is_available() - - model = RiffusionPipeline.from_pretrained( - checkpoint, - revision="main", - torch_dtype=torch.float16, - # Disable the NSFW filter, causes incorrect false positives - safety_checker=lambda images, **kwargs: (images, False), - ).to("cuda") - - # Set the traced unet if desired - if checkpoint == "riffusion/riffusion-model-v1" and traced_unet: - @dataclasses.dataclass - class UNet2DConditionOutput: - sample: torch.FloatTensor - - # Using traced unet from hf hub - unet_file = hf_hub_download( - checkpoint, filename="unet_traced.pt", subfolder="unet_traced" - ) - unet_traced = torch.jit.load(unet_file) - - class TracedUNet(torch.nn.Module): - def __init__(self): - super().__init__() - self.in_channels = model.unet.in_channels - self.device = model.unet.device - self.dtype = torch.float16 - - def forward(self, latent_model_input, t, encoder_hidden_states): - sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0] - return UNet2DConditionOutput(sample=sample) - - model.unet = TracedUNet() - - model = model.to("cuda") - - return model + app.run(**args) # type: ignore @app.route("/run_inference/", methods=["POST"]) @@ -145,7 +103,11 @@ def run_inference(): logging.info(json_data) return str(exception), 400 - response = compute(inputs) + response = compute_request( + inputs=inputs, + seed_images_dir=SEED_IMAGES_DIR, + pipeline=PIPELINE, + ) # Log the total time logging.info(f"Request took {time.time() - start_time:.2f} s") @@ -153,60 +115,73 @@ def run_inference(): return response -def compute(inputs: InferenceInput) -> str: +def compute_request( + inputs: InferenceInput, + pipeline: RiffusionPipeline, + seed_images_dir: str, +) -> T.Union[str, T.Tuple[str, int]]: """ Does all the heavy lifting of the request. 
@@ -153,60 +115,73 @@
-def compute(inputs: InferenceInput) -> str:
+def compute_request(
+    inputs: InferenceInput,
+    pipeline: RiffusionPipeline,
+    seed_images_dir: str,
+) -> T.Union[str, T.Tuple[str, int]]:
     """
     Does all the heavy lifting of the request.
+
+    Args:
+        inputs: The input dataclass
+        pipeline: The riffusion model pipeline
+        seed_images_dir: The directory where seed images are stored
     """
     # Load the seed image by ID
-    init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
+    init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")
+
     if not init_image_path.is_file():
         return f"Invalid seed image: {inputs.seed_image_id}", 400
     init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
 
     # Load the mask image by ID
+    mask_image: T.Optional[PIL.Image.Image] = None
     if inputs.mask_image_id:
-        mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
+        mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
         if not mask_image_path.is_file():
             return f"Invalid mask image: {inputs.mask_image_id}", 400
         mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
-    else:
-        mask_image = None
 
     # Execute the model to get the spectrogram image
-    image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
+    image = pipeline.riffuse(
+        inputs,
+        init_image=init_image,
+        mask_image=mask_image,
+    )
+
+    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
+    params = SpectrogramParams(
+        min_frequency=0,
+        max_frequency=10000,
+    )
 
     # Reconstruct audio from the image
-    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
-    mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
+    # TODO(hayk): It may help performance to cache this object
+    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
+    segment = converter.audio_from_spectrogram_image(
+        image,
+        apply_filters=True,
+    )
 
-    # Compute the output as base64 encoded strings
-    image_bytes = image_bytes_from_image(image, mode="JPEG")
+    # Export audio to MP3 bytes
+    mp3_bytes = io.BytesIO()
+    segment.export(mp3_bytes, format="mp3")
+    mp3_bytes.seek(0)
+
+    # Export image to JPEG bytes
+    image_bytes = io.BytesIO()
+    image.save(image_bytes, exif=image.getexif(), format="JPEG")
+    image_bytes.seek(0)
 
     # Assemble the output dataclass
     output = InferenceOutput(
-        image="data:image/jpeg;base64," + base64_encode(image_bytes),
-        audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
-        duration_s=duration_s,
+        image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
+        audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
+        duration_s=segment.duration_seconds,
    )
 
-    return flask.jsonify(dataclasses.asdict(output))
-
-
-def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
-    """
-    Convert a PIL image into bytes of the given image format.
-    """
-    image_bytes = io.BytesIO()
-    image.save(image_bytes, mode)
-    image_bytes.seek(0)
-    return image_bytes
-
-
-def base64_encode(buffer: io.BytesIO) -> str:
-    """
-    Encode the given buffer as base64.
-    """
-    return base64.encodebytes(buffer.getvalue()).decode("ascii")
+    return json.dumps(dataclasses.asdict(output))
 
 
 if __name__ == "__main__":
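
On the second TODO above: the `SpectrogramImageConverter` is rebuilt on every request even though its parameters are fixed here, so one plausible shape for the cache is memoization keyed on the device string. A sketch only; `converter_for_device` is a hypothetical helper, not part of this PR:

```python
import functools

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

@functools.lru_cache(maxsize=None)
def converter_for_device(device: str) -> SpectrogramImageConverter:
    # Hypothetical cache; params match those in compute_request above.
    params = SpectrogramParams(min_frequency=0, max_frequency=10000)
    return SpectrogramImageConverter(params=params, device=device)
```

`compute_request` could then call `converter_for_device(str(pipeline.device))` instead of constructing a new converter per request.
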