Add support for cache (still disabled)
This commit is contained in:
parent
511defae99
commit
ae34fb388e
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
|||
import typing as T
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class PromptInput:
|
||||
"""
|
||||
Parameters for one end of interpolation.
|
||||
|
@ -25,7 +25,7 @@ class PromptInput:
|
|||
guidance: float = 7.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class InferenceInput:
|
||||
"""
|
||||
Parameters for a single run of the riffusion model, interpolating between
|
||||
|
@ -53,7 +53,7 @@ class InferenceInput:
|
|||
mask_image_id: T.Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class InferenceOutput:
|
||||
"""
|
||||
Response from the model inference server.
|
||||
|
|
|
@ -4,6 +4,7 @@ Inference server for the riffusion project.
|
|||
|
||||
import base64
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import io
|
||||
import json
|
||||
|
@ -13,6 +14,7 @@ import typing as T
|
|||
|
||||
import dacite
|
||||
import flask
|
||||
|
||||
from flask_cors import CORS
|
||||
import PIL
|
||||
import torch
|
||||
|
@ -116,16 +118,29 @@ def run_inference():
|
|||
logging.info(json_data)
|
||||
return str(exception), 400
|
||||
|
||||
response = compute(inputs)
|
||||
|
||||
# Log the total time
|
||||
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
||||
|
||||
return response
|
||||
|
||||
# TODO(hayk): Enable cache here.
|
||||
# @functools.lru_cache()
|
||||
def compute(inputs: InferenceInput) -> str:
|
||||
"""
|
||||
Does all the heavy lifting of the request.
|
||||
"""
|
||||
# Load the seed image by ID
|
||||
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
|
||||
if not init_image_path.is_file:
|
||||
if not init_image_path.is_file():
|
||||
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# Load the mask image by ID
|
||||
if inputs.mask_image_id:
|
||||
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
|
||||
if not mask_image_path.is_file:
|
||||
if not mask_image_path.is_file():
|
||||
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
||||
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
||||
else:
|
||||
|
@ -148,12 +163,10 @@ def run_inference():
|
|||
duration_s=duration_s,
|
||||
)
|
||||
|
||||
# Log the total time
|
||||
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
||||
|
||||
return flask.jsonify(dataclasses.asdict(output))
|
||||
|
||||
|
||||
|
||||
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||
"""
|
||||
Convert a PIL image into bytes of the given image format.
|
||||
|
|
Loading…
Reference in New Issue