diff --git a/riffusion/__init__.py b/riffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/riffusion/server.py b/riffusion/server.py new file mode 100644 index 0000000..ff7a65c --- /dev/null +++ b/riffusion/server.py @@ -0,0 +1,153 @@ +""" +Inference server for the riffusion project. +""" + +import base64 +import dataclasses +import logging +import io +import json +from pathlib import Path +import time + +import dacite +import flask +from flask_cors import CORS +import PIL +import torch + +from .audio import wav_bytes_from_spectrogram_image +from .audio import mp3_bytes_from_wav_bytes +from .datatypes import InferenceInput +from .datatypes import InferenceOutput +from .riffusion_pipeline import RiffusionPipeline + +# Flask app with CORS +app = flask.Flask(__name__) +CORS(app) + +# Log at the INFO level to both stdout and disk +logging.basicConfig(level=logging.INFO) +logging.getLogger().addHandler(logging.FileHandler("server.log")) + +# Global variable for the model pipeline +MODEL = None + +# Where built-in seed images are stored +SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images") + + +def run_app( + *, + checkpoint: str, + host: str = "127.0.0.1", + port: int = 3000, + debug: bool = False, +): + """ + Run a flask API that serves the given riffusion model checkpoint. + """ + # Initialize the model + global MODEL + MODEL = load_model(checkpoint=checkpoint) + + app.run( + debug=debug, + threaded=False, + host=host, + port=port, + ) + + +def load_model(checkpoint: str): + """ + Load the riffusion model pipeline. + """ + assert torch.cuda.is_available() + + model = RiffusionPipeline.from_pretrained( + checkpoint, + revision="fp16", + torch_dtype=torch.float16, + # Disable the NSFW filter, causes incorrect false positives + safety_checker=lambda images, **kwargs: (images, False), + ) + + model = model.to("cuda") + + return model + + +@app.route("/run_inference/", methods=["POST"]) +def run_inference(): + """ + Execute the riffusion model as an API. + + Inputs: + Serialized JSON of the InferenceInput dataclass + + Returns: + Serialized JSON of the InferenceOutput dataclass + """ + start_time = time.time() + + # Parse the payload as JSON + json_data = json.loads(flask.request.data) + + # Log the request + logging.info(json_data) + + # Parse an InferenceInput dataclass from the payload + try: + inputs = dacite.from_dict(InferenceInput, json_data) + except dacite.exceptions.WrongTypeError as exception: + logging.info(json_data) + return str(exception), 400 + except dacite.exceptions.MissingValueError as exception: + logging.info(json_data) + return str(exception), 400 + + # Load the seed image by ID + init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png") + if not init_image_path.is_file: + return f"Invalid seed image: {inputs.seed_image_id}", 400 + init_image = PIL.Image.open(str(init_image_path)).convert("RGB") + + # Execute the model to get the spectrogram image + image = MODEL.riffuse(inputs, init_image=init_image) + + # Reconstruct audio from the image + wav_bytes = wav_bytes_from_spectrogram_image(image) + mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes) + + # Compute the output as base64 encoded strings + image_bytes = image_bytes_from_image(image, mode="PNG") + output = InferenceOutput(image=base64_encode(image_bytes), audio=base64_encode(mp3_bytes)) + + # Log the total time + logging.info(f"Request took {time.time() - start_time:.2f} s") + + return flask.jsonify(dataclasses.asdict(output)) + + +def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO: + """ + Convert a PIL image into bytes of the given image format. + """ + image_bytes = io.BytesIO() + image.save(image_bytes, mode) + image_bytes.seek(0) + return image_bytes + + +def base64_encode(buffer: io.BytesIO) -> str: + """ + Encode the given buffer as base64. + """ + return base64.encodebytes(buffer.getvalue()).decode("ascii") + + +if __name__ == "__main__": + import argh + + argh.dispatch_command(run_app)