diff --git a/riffusion/datatypes.py b/riffusion/datatypes.py index b24a903..6a8cdb8 100644 --- a/riffusion/datatypes.py +++ b/riffusion/datatypes.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import typing as T -@dataclass +@dataclass(frozen=True) class PromptInput: """ Parameters for one end of interpolation. @@ -25,7 +25,7 @@ class PromptInput: guidance: float = 7.0 -@dataclass +@dataclass(frozen=True) class InferenceInput: """ Parameters for a single run of the riffusion model, interpolating between @@ -53,7 +53,7 @@ class InferenceInput: mask_image_id: T.Optional[str] = None -@dataclass +@dataclass(frozen=True) class InferenceOutput: """ Response from the model inference server. diff --git a/riffusion/server.py b/riffusion/server.py index 4abcfad..ff7d026 100644 --- a/riffusion/server.py +++ b/riffusion/server.py @@ -4,6 +4,7 @@ Inference server for the riffusion project. import base64 import dataclasses +import functools import logging import io import json @@ -13,6 +14,7 @@ import typing as T import dacite import flask + from flask_cors import CORS import PIL import torch @@ -116,16 +118,29 @@ def run_inference(): logging.info(json_data) return str(exception), 400 + response = compute(inputs) + + # Log the total time + logging.info(f"Request took {time.time() - start_time:.2f} s") + + return response + +# TODO(hayk): Enable cache here. +# @functools.lru_cache() +def compute(inputs: InferenceInput) -> str: + """ + Does all the heavy lifting of the request. + """ # Load the seed image by ID init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png") - if not init_image_path.is_file: + if not init_image_path.is_file(): return f"Invalid seed image: {inputs.seed_image_id}", 400 init_image = PIL.Image.open(str(init_image_path)).convert("RGB") # Load the mask image by ID if inputs.mask_image_id: mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png") - if not mask_image_path.is_file: + if not mask_image_path.is_file(): return f"Invalid mask image: {inputs.mask_image_id}", 400 mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB") else: @@ -148,12 +163,10 @@ def run_inference(): duration_s=duration_s, ) - # Log the total time - logging.info(f"Request took {time.time() - start_time:.2f} s") - return flask.jsonify(dataclasses.asdict(output)) + def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO: """ Convert a PIL image into bytes of the given image format.