84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
"""
|
|
This file can be used to build a Truss for deployment with Baseten.
|
|
If used, it should be renamed to model.py and placed alongside the other
|
|
files from /riffusion in the standard /model directory of the Truss.
|
|
|
|
For more on the Truss file format, see https://truss.baseten.co/
|
|
"""
|
|
|
|
import typing as T
|
|
|
|
import dacite
|
|
import torch
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from riffusion.datatypes import InferenceInput
|
|
from riffusion.riffusion_pipeline import RiffusionPipeline
|
|
from riffusion.server import compute_request
|
|
|
|
|
|
class Model:
    """
    Baseten Truss model class for riffusion.

    The class follows the Truss lifecycle: it is constructed with keyword
    arguments, `load` builds the pipeline, and each request then flows through
    `preprocess`, `predict` and `postprocess`.

    See: https://truss.baseten.co/reference/structure#model.py
    """

    def __init__(self, **kwargs) -> None:
        """
        Store the supplied settings and fetch the seed images.

        Args:
            **kwargs: Must contain the keys ``data_dir`` and ``config``; a
                missing key raises ``KeyError``.
        """
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._pipeline = None
        self._vae = None

        self.checkpoint_name = "riffusion/riffusion-model-v1"

        # Download only the PNG files (the seed images) of the checkpoint repo
        # from the Hugging Face Hub and remember where they were placed.
        self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")

    def load(self):
        """
        Load the model. Guaranteed to be called before `predict`.

        The pipeline is placed on CUDA when it is available, otherwise on CPU.
        """
        self._pipeline = RiffusionPipeline.load_checkpoint(
            checkpoint=self.checkpoint_name,
            use_traced_unet=True,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def preprocess(self, request: T.Dict) -> T.Dict:
        """
        Incorporate pre-processing required by the model if desired here.

        These might be feature transformations that are tightly coupled to the model.
        Currently the request is passed through unchanged.
        """
        return request

    def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
        """
        This is the main function that is called.

        Args:
            request: Raw request dictionary; it is converted to an
                `InferenceInput` with `dacite.from_dict`.

        Returns:
            Whatever `compute_request` returns for the parsed input. If the
            request has a wrongly typed or missing field, the tuple
            ``(error message, 400)`` is returned instead.

        Raises:
            AssertionError: If `load` has not populated the pipeline yet.
        """
        assert self._pipeline is not None, "Model pipeline not loaded"

        try:
            inputs = dacite.from_dict(InferenceInput, request)
        except (
            dacite.exceptions.WrongTypeError,
            dacite.exceptions.MissingValueError,
        ) as exception:
            # Both validation failures are reported to the caller the same way.
            return str(exception), 400

        # NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
        # The context managers must be separated by a comma: `a() and b()`
        # evaluates to `b()` alone, which would silently skip inference_mode.
        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
            response = compute_request(
                inputs=inputs,
                pipeline=self._pipeline,
                seed_images_dir=self._seed_images_dir,
            )

        return response

    def postprocess(self, request: T.Dict) -> T.Dict:
        """
        Incorporate post-processing required by the model if desired here.

        Currently the value is passed through unchanged.
        """
        return request
|