191 lines
5.1 KiB
Python
191 lines
5.1 KiB
Python
"""
|
|
Flask server that serves the riffusion model as an API.
|
|
"""
|
|
|
|
import dataclasses
|
|
import logging
|
|
import io
|
|
import json
|
|
from pathlib import Path
|
|
import time
|
|
import typing as T
|
|
|
|
import dacite
|
|
import flask
|
|
|
|
from flask_cors import CORS
|
|
import PIL
|
|
|
|
from riffusion.datatypes import InferenceInput
|
|
from riffusion.datatypes import InferenceOutput
|
|
from riffusion.riffusion_pipeline import RiffusionPipeline
|
|
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
|
from riffusion.spectrogram_params import SpectrogramParams
|
|
from riffusion.util import base64_util
|
|
|
|
# The Flask application object, wrapped with flask_cors so cross-origin requests are allowed
app = flask.Flask(__name__)
CORS(app)

# Configure the root logger at INFO level. basicConfig installs the default
# stream handler (stderr unless configured otherwise); the extra FileHandler
# additionally writes every record to server.log in the working directory.
logging.basicConfig(level=logging.INFO)
logging.getLogger().addHandler(logging.FileHandler("server.log"))

# The model pipeline shared by all requests. It stays None until run_app() loads it.
PIPELINE: T.Optional[RiffusionPipeline] = None

# Built-in seed images live in "seed_images", two directory levels above this file
SEED_IMAGES_DIR = Path(__file__).resolve().parent.parent / "seed_images"
|
|
|
|
|
|
def run_app(
    *,
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
    host: str = "127.0.0.1",
    port: int = 3013,
    debug: bool = False,
    ssl_certificate: T.Optional[str] = None,
    ssl_key: T.Optional[str] = None,
):
    """
    Run a flask API that serves the given riffusion model checkpoint.

    Args:
        checkpoint: Checkpoint identifier passed to RiffusionPipeline.load_checkpoint
        no_traced_unet: If True, the pipeline is loaded with use_traced_unet=False
        device: Device string passed to RiffusionPipeline.load_checkpoint
        host: Interface the flask server binds to
        port: Port the flask server listens on
        debug: Forwarded to flask's debug mode
        ssl_certificate: Optional path to an SSL certificate; enables an SSL context
        ssl_key: Path to the SSL key; required whenever ssl_certificate is given

    Raises:
        ValueError: If ssl_certificate is given without ssl_key
    """
    global PIPELINE

    # Validate the SSL arguments with a real exception rather than an assert
    # (asserts are stripped under `python -O`), and do it before the checkpoint
    # is loaded so a bad invocation fails immediately.
    if ssl_certificate and ssl_key is None:
        raise ValueError("ssl_key must be provided when ssl_certificate is set")

    # Initialize the model
    PIPELINE = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )

    # NOTE(review): threaded=False presumably keeps requests from running
    # concurrently against the single global pipeline -- confirm before changing.
    args: T.Dict[str, T.Any] = dict(
        debug=debug,
        threaded=False,
        host=host,
        port=port,
    )

    if ssl_certificate:
        args["ssl_context"] = (ssl_certificate, ssl_key)

    app.run(**args)
|
|
|
|
|
|
@app.route("/run_inference/", methods=["POST"])
def run_inference():
    """
    Execute the riffusion model as an API.

    Inputs:
        Serialized JSON of the InferenceInput dataclass

    Returns:
        Serialized JSON of the InferenceOutput dataclass on success.
        A (message, 400) tuple if the payload is not valid JSON, is not a JSON
        object, or cannot be parsed into an InferenceInput.
        A (message, 503) tuple if the model pipeline has not been loaded.
    """
    start_time = time.time()

    # Parse the payload as JSON. A malformed body is a client error, so answer
    # with a 400 instead of letting the exception surface as a 500.
    # (json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.)
    try:
        json_data = json.loads(flask.request.data)
    except ValueError as exception:
        logging.info("Rejected request with invalid JSON: %s", exception)
        return f"Invalid JSON payload: {exception}", 400

    # Log the request
    logging.info(json_data)

    # dacite expects a mapping; valid JSON can also be a list, string, number, ...
    if not isinstance(json_data, dict):
        return "Request payload must be a JSON object", 400

    # Parse an InferenceInput dataclass from the payload
    try:
        inputs = dacite.from_dict(InferenceInput, json_data)
    except (
        dacite.exceptions.WrongTypeError,
        dacite.exceptions.MissingValueError,
    ) as exception:
        # The payload itself was already logged above; record why it was rejected
        logging.info("Rejected request: %s", exception)
        return str(exception), 400

    # PIPELINE is only populated by run_app(); fail clearly if it is missing
    if PIPELINE is None:
        logging.error("Inference requested before the model pipeline was loaded")
        return "Model pipeline is not loaded", 503

    response = compute_request(
        inputs=inputs,
        seed_images_dir=SEED_IMAGES_DIR,
        pipeline=PIPELINE,
    )

    # Log the total time
    logging.info(f"Request took {time.time() - start_time:.2f} s")

    return response
|
|
|
|
|
|
def _seed_image_path(
    seed_images_dir: T.Union[str, Path],
    image_id: T.Any,
) -> T.Optional[Path]:
    """
    Return the path of "<image_id>.png" inside seed_images_dir, or None if the
    ID is unsafe or no such file exists.
    """
    relative = Path(f"{image_id}.png")

    # Image IDs come straight from the request payload. An anchored ID (absolute
    # path or drive) would replace seed_images_dir when joined, and ".." segments
    # would climb out of it, so both are rejected before touching the filesystem.
    if relative.anchor or ".." in relative.parts:
        return None

    candidate = Path(seed_images_dir, relative)
    return candidate if candidate.is_file() else None


def _load_rgb_image(path: Path) -> "PIL.Image.Image":
    """
    Open the image at the given path and return an RGB copy of it.
    """
    # A bare `import PIL` does not load the Image submodule, so import it explicitly
    import PIL.Image

    # The context manager closes the underlying file once the copy has been made
    with PIL.Image.open(str(path)) as opened:
        return opened.convert("RGB")


def compute_request(
    inputs: InferenceInput,
    pipeline: RiffusionPipeline,
    seed_images_dir: T.Union[str, Path],
) -> T.Union[str, T.Tuple[str, int]]:
    """
    Does all the heavy lifting of the request.

    Args:
        inputs: The input dataclass
        pipeline: The riffusion model pipeline
        seed_images_dir: The directory where seed images are stored

    Returns:
        Serialized JSON of the InferenceOutput dataclass on success, or a
        (message, 400) tuple if the seed or mask image ID does not name a PNG
        file inside seed_images_dir.
    """
    # Load the seed image by ID
    init_image_path = _seed_image_path(seed_images_dir, inputs.seed_image_id)
    if init_image_path is None:
        return f"Invalid seed image: {inputs.seed_image_id}", 400
    init_image = _load_rgb_image(init_image_path)

    # Load the mask image by ID (optional)
    mask_image: T.Optional[PIL.Image.Image] = None
    if inputs.mask_image_id:
        mask_image_path = _seed_image_path(seed_images_dir, inputs.mask_image_id)
        if mask_image_path is None:
            return f"Invalid mask image: {inputs.mask_image_id}", 400
        mask_image = _load_rgb_image(mask_image_path)

    # Execute the model to get the spectrogram image
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=mask_image,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct audio from the image
    # TODO(hayk): It may help performance to cache this object
    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
    segment = converter.audio_from_spectrogram_image(
        image,
        apply_filters=True,
    )

    # Export audio to MP3 bytes, rewinding so the encoder reads from the start
    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)

    # Export image to JPEG bytes, carrying over the image's EXIF data
    image_bytes = io.BytesIO()
    image.save(image_bytes, exif=image.getexif(), format="JPEG")
    image_bytes.seek(0)

    # Assemble the output dataclass with both payloads as base64 data URIs
    output = InferenceOutput(
        image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
        audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
        duration_s=segment.duration_seconds,
    )

    return json.dumps(dataclasses.asdict(output))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand run_app to argh, which dispatches it as the
    # command (argh is imported here so it is only needed when run as a script).
    import argh

    argh.dispatch_command(run_app)
|