Add support for cache (still disabled)

This commit is contained in:
Hayk Martiros 2022-11-28 00:06:12 +00:00
parent 511defae99
commit ae34fb388e
2 changed files with 21 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass
import typing as T import typing as T
@dataclass @dataclass(frozen=True)
class PromptInput: class PromptInput:
""" """
Parameters for one end of interpolation. Parameters for one end of interpolation.
@@ -25,7 +25,7 @@ class PromptInput:
guidance: float = 7.0 guidance: float = 7.0
@dataclass @dataclass(frozen=True)
class InferenceInput: class InferenceInput:
""" """
Parameters for a single run of the riffusion model, interpolating between Parameters for a single run of the riffusion model, interpolating between
@@ -53,7 +53,7 @@ class InferenceInput:
mask_image_id: T.Optional[str] = None mask_image_id: T.Optional[str] = None
@dataclass @dataclass(frozen=True)
class InferenceOutput: class InferenceOutput:
""" """
Response from the model inference server. Response from the model inference server.

View File

@@ -4,6 +4,7 @@ Inference server for the riffusion project.
import base64 import base64
import dataclasses import dataclasses
import functools
import logging import logging
import io import io
import json import json
@@ -13,6 +14,7 @@ import typing as T
import dacite import dacite
import flask import flask
from flask_cors import CORS from flask_cors import CORS
import PIL import PIL
import torch import torch
@@ -116,16 +118,29 @@ def run_inference():
logging.info(json_data) logging.info(json_data)
return str(exception), 400 return str(exception), 400
response = compute(inputs)
# Log the total time
logging.info(f"Request took {time.time() - start_time:.2f} s")
return response
# TODO(hayk): Enable cache here.
# @functools.lru_cache()
def compute(inputs: InferenceInput) -> str:
"""
Does all the heavy lifting of the request.
"""
# Load the seed image by ID # Load the seed image by ID
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png") init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
if not init_image_path.is_file: if not init_image_path.is_file():
return f"Invalid seed image: {inputs.seed_image_id}", 400 return f"Invalid seed image: {inputs.seed_image_id}", 400
init_image = PIL.Image.open(str(init_image_path)).convert("RGB") init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# Load the mask image by ID # Load the mask image by ID
if inputs.mask_image_id: if inputs.mask_image_id:
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png") mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
if not mask_image_path.is_file: if not mask_image_path.is_file():
return f"Invalid mask image: {inputs.mask_image_id}", 400 return f"Invalid mask image: {inputs.mask_image_id}", 400
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB") mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
else: else:
@@ -148,12 +163,10 @@ def run_inference():
duration_s=duration_s, duration_s=duration_s,
) )
# Log the total time
logging.info(f"Request took {time.time() - start_time:.2f} s")
return flask.jsonify(dataclasses.asdict(output)) return flask.jsonify(dataclasses.asdict(output))
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO: def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
""" """
Convert a PIL image into bytes of the given image format. Convert a PIL image into bytes of the given image format.