Add support for cache (still disabled)

This commit is contained in:
Hayk Martiros 2022-11-28 00:06:12 +00:00
parent 511defae99
commit ae34fb388e
2 changed files with 21 additions and 8 deletions

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass
import typing as T
@dataclass
@dataclass(frozen=True)
class PromptInput:
"""
Parameters for one end of interpolation.
@ -25,7 +25,7 @@ class PromptInput:
guidance: float = 7.0
@dataclass
@dataclass(frozen=True)
class InferenceInput:
"""
Parameters for a single run of the riffusion model, interpolating between
@ -53,7 +53,7 @@ class InferenceInput:
mask_image_id: T.Optional[str] = None
@dataclass
@dataclass(frozen=True)
class InferenceOutput:
"""
Response from the model inference server.

View File

@ -4,6 +4,7 @@ Inference server for the riffusion project.
import base64
import dataclasses
import functools
import logging
import io
import json
@ -13,6 +14,7 @@ import typing as T
import dacite
import flask
from flask_cors import CORS
import PIL
import torch
@ -116,16 +118,29 @@ def run_inference():
logging.info(json_data)
return str(exception), 400
response = compute(inputs)
# Log the total time
logging.info(f"Request took {time.time() - start_time:.2f} s")
return response
# TODO(hayk): Enable cache here.
# @functools.lru_cache()
def compute(inputs: InferenceInput) -> str:
"""
Does all the heavy lifting of the request.
"""
# Load the seed image by ID
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
if not init_image_path.is_file:
if not init_image_path.is_file():
return f"Invalid seed image: {inputs.seed_image_id}", 400
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# Load the mask image by ID
if inputs.mask_image_id:
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
if not mask_image_path.is_file:
if not mask_image_path.is_file():
return f"Invalid mask image: {inputs.mask_image_id}", 400
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
else:
@ -148,12 +163,10 @@ def run_inference():
duration_s=duration_s,
)
# Log the total time
logging.info(f"Request took {time.time() - start_time:.2f} s")
return flask.jsonify(dataclasses.asdict(output))
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
"""
Convert a PIL image into bytes of the given image format.