Make traced unet optional and support custom checkpoints

This commit is contained in:
Hayk Martiros 2022-12-23 05:08:37 +00:00
parent 40e1e51c6a
commit 8349ccff59
2 changed files with 24 additions and 23 deletions

View File

@ -8,6 +8,7 @@ numpy
 pillow
 pydub
 scipy
+soundfile
 torch
 torchaudio
 transformers

View File

@ -4,7 +4,6 @@ Inference server for the riffusion project.
 import base64
 import dataclasses
-import functools
 import logging
 import io
 import json
@ -45,6 +44,7 @@ SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
 def run_app(
     *,
     checkpoint: str = "riffusion/riffusion-model-v1",
+    no_traced_unet: bool = False,
     host: str = "127.0.0.1",
     port: int = 3000,
     debug: bool = False,
@ -56,7 +56,7 @@ def run_app(
     """
     # Initialize the model
     global MODEL
-    MODEL = load_model(checkpoint=checkpoint)
+    MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet)

     args = dict(
         debug=debug,
@ -72,7 +72,7 @@ def run_app(
     app.run(**args)

-def load_model(checkpoint: str):
+def load_model(checkpoint: str, traced_unet: bool = True):
     """
     Load the riffusion model pipeline.
     """
@ -86,13 +86,15 @@ def load_model(checkpoint: str):
         safety_checker=lambda images, **kwargs: (images, False),
     ).to("cuda")

+    # Set the traced unet if desired
+    if checkpoint == "riffusion/riffusion-model-v1" and traced_unet:

     @dataclasses.dataclass
     class UNet2DConditionOutput:
         sample: torch.FloatTensor

     # Using traced unet from hf hub
     unet_file = hf_hub_download(
-        "riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
+        checkpoint, filename="unet_traced.pt", subfolder="unet_traced"
     )
     unet_traced = torch.jit.load(unet_file)
@ -151,8 +153,6 @@ def run_inference():
     return response

-# TODO(hayk): Enable cache here.
-# @functools.lru_cache()
 def compute(inputs: InferenceInput) -> str:
""" """
Does all the heavy lifting of the request. Does all the heavy lifting of the request.