Make the traced UNet optional and support custom checkpoints
This commit is contained in:
parent
40e1e51c6a
commit
8349ccff59
|
@ -8,6 +8,7 @@ numpy
|
||||||
pillow
|
pillow
|
||||||
pydub
|
pydub
|
||||||
scipy
|
scipy
|
||||||
|
soundfile
|
||||||
torch
|
torch
|
||||||
torchaudio
|
torchaudio
|
||||||
transformers
|
transformers
|
||||||
|
|
|
@ -4,7 +4,6 @@ Inference server for the riffusion project.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import functools
|
|
||||||
import logging
|
import logging
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
@ -45,6 +44,7 @@ SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
|
||||||
def run_app(
|
def run_app(
|
||||||
*,
|
*,
|
||||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
checkpoint: str = "riffusion/riffusion-model-v1",
|
||||||
|
no_traced_unet: bool = False,
|
||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 3000,
|
port: int = 3000,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
|
@ -56,7 +56,7 @@ def run_app(
|
||||||
"""
|
"""
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
global MODEL
|
global MODEL
|
||||||
MODEL = load_model(checkpoint=checkpoint)
|
MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet)
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
debug=debug,
|
debug=debug,
|
||||||
|
@ -72,7 +72,7 @@ def run_app(
|
||||||
app.run(**args)
|
app.run(**args)
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint: str):
|
def load_model(checkpoint: str, traced_unet: bool = True):
|
||||||
"""
|
"""
|
||||||
Load the riffusion model pipeline.
|
Load the riffusion model pipeline.
|
||||||
"""
|
"""
|
||||||
|
@ -86,13 +86,15 @@ def load_model(checkpoint: str):
|
||||||
safety_checker=lambda images, **kwargs: (images, False),
|
safety_checker=lambda images, **kwargs: (images, False),
|
||||||
).to("cuda")
|
).to("cuda")
|
||||||
|
|
||||||
|
# Set the traced unet if desired
|
||||||
|
if checkpoint == "riffusion/riffusion-model-v1" and traced_unet:
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class UNet2DConditionOutput:
|
class UNet2DConditionOutput:
|
||||||
sample: torch.FloatTensor
|
sample: torch.FloatTensor
|
||||||
|
|
||||||
# Using traced unet from hf hub
|
# Using traced unet from hf hub
|
||||||
unet_file = hf_hub_download(
|
unet_file = hf_hub_download(
|
||||||
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
|
checkpoint, filename="unet_traced.pt", subfolder="unet_traced"
|
||||||
)
|
)
|
||||||
unet_traced = torch.jit.load(unet_file)
|
unet_traced = torch.jit.load(unet_file)
|
||||||
|
|
||||||
|
@ -151,8 +153,6 @@ def run_inference():
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# TODO(hayk): Enable cache here.
|
|
||||||
# @functools.lru_cache()
|
|
||||||
def compute(inputs: InferenceInput) -> str:
|
def compute(inputs: InferenceInput) -> str:
|
||||||
"""
|
"""
|
||||||
Does all the heavy lifting of the request.
|
Does all the heavy lifting of the request.
|
||||||
|
|
Loading…
Reference in New Issue