Add support for cache (still disabled)
This commit is contained in:
parent
511defae99
commit
ae34fb388e
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
||||||
import typing as T
|
import typing as T
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class PromptInput:
|
class PromptInput:
|
||||||
"""
|
"""
|
||||||
Parameters for one end of interpolation.
|
Parameters for one end of interpolation.
|
||||||
|
@ -25,7 +25,7 @@ class PromptInput:
|
||||||
guidance: float = 7.0
|
guidance: float = 7.0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class InferenceInput:
|
class InferenceInput:
|
||||||
"""
|
"""
|
||||||
Parameters for a single run of the riffusion model, interpolating between
|
Parameters for a single run of the riffusion model, interpolating between
|
||||||
|
@ -53,7 +53,7 @@ class InferenceInput:
|
||||||
mask_image_id: T.Optional[str] = None
|
mask_image_id: T.Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class InferenceOutput:
|
class InferenceOutput:
|
||||||
"""
|
"""
|
||||||
Response from the model inference server.
|
Response from the model inference server.
|
||||||
|
|
|
@ -4,6 +4,7 @@ Inference server for the riffusion project.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
@ -13,6 +14,7 @@ import typing as T
|
||||||
|
|
||||||
import dacite
|
import dacite
|
||||||
import flask
|
import flask
|
||||||
|
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
|
@ -116,16 +118,29 @@ def run_inference():
|
||||||
logging.info(json_data)
|
logging.info(json_data)
|
||||||
return str(exception), 400
|
return str(exception), 400
|
||||||
|
|
||||||
|
response = compute(inputs)
|
||||||
|
|
||||||
|
# Log the total time
|
||||||
|
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
# TODO(hayk): Enable cache here.
|
||||||
|
# @functools.lru_cache()
|
||||||
|
def compute(inputs: InferenceInput) -> str:
|
||||||
|
"""
|
||||||
|
Does all the heavy lifting of the request.
|
||||||
|
"""
|
||||||
# Load the seed image by ID
|
# Load the seed image by ID
|
||||||
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
|
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
|
||||||
if not init_image_path.is_file:
|
if not init_image_path.is_file():
|
||||||
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
||||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||||
|
|
||||||
# Load the mask image by ID
|
# Load the mask image by ID
|
||||||
if inputs.mask_image_id:
|
if inputs.mask_image_id:
|
||||||
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
|
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
|
||||||
if not mask_image_path.is_file:
|
if not mask_image_path.is_file():
|
||||||
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
||||||
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
||||||
else:
|
else:
|
||||||
|
@ -148,12 +163,10 @@ def run_inference():
|
||||||
duration_s=duration_s,
|
duration_s=duration_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log the total time
|
|
||||||
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
|
||||||
|
|
||||||
return flask.jsonify(dataclasses.asdict(output))
|
return flask.jsonify(dataclasses.asdict(output))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Convert a PIL image into bytes of the given image format.
|
Convert a PIL image into bytes of the given image format.
|
||||||
|
|
Loading…
Reference in New Issue