From 8349ccff5957f42d8ae7838b6d8218e3060ad1ee Mon Sep 17 00:00:00 2001 From: Hayk Martiros Date: Fri, 23 Dec 2022 05:08:37 +0000 Subject: [PATCH] Make traced unet optional and support custom checkpoints --- requirements.txt | 1 + riffusion/server.py | 46 ++++++++++++++++++++++----------------------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/requirements.txt b/requirements.txt index 835861d..57edf91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ numpy pillow pydub scipy +soundfile torch torchaudio transformers diff --git a/riffusion/server.py b/riffusion/server.py index c931f6e..8a8bc5c 100644 --- a/riffusion/server.py +++ b/riffusion/server.py @@ -4,7 +4,6 @@ Inference server for the riffusion project. import base64 import dataclasses -import functools import logging import io import json @@ -45,6 +44,7 @@ SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images") def run_app( *, checkpoint: str = "riffusion/riffusion-model-v1", + no_traced_unet: bool = False, host: str = "127.0.0.1", port: int = 3000, debug: bool = False, @@ -56,7 +56,7 @@ def run_app( """ # Initialize the model global MODEL - MODEL = load_model(checkpoint=checkpoint) + MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet) args = dict( debug=debug, @@ -72,7 +72,7 @@ def run_app( app.run(**args) -def load_model(checkpoint: str): +def load_model(checkpoint: str, traced_unet: bool = True): """ Load the riffusion model pipeline. """ @@ -86,28 +86,30 @@ def load_model(checkpoint: str): safety_checker=lambda images, **kwargs: (images, False), ).to("cuda") - @dataclasses.dataclass - class UNet2DConditionOutput: - sample: torch.FloatTensor + # Set the traced unet if desired + if checkpoint == "riffusion/riffusion-model-v1" and traced_unet: + @dataclasses.dataclass + class UNet2DConditionOutput: + sample: torch.FloatTensor - # Using traced unet from hf hub - unet_file = hf_hub_download( - "riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced" - ) - unet_traced = torch.jit.load(unet_file) + # Using traced unet from hf hub + unet_file = hf_hub_download( + checkpoint, filename="unet_traced.pt", subfolder="unet_traced" + ) + unet_traced = torch.jit.load(unet_file) - class TracedUNet(torch.nn.Module): - def __init__(self): - super().__init__() - self.in_channels = model.unet.in_channels - self.device = model.unet.device - self.dtype = torch.float16 + class TracedUNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.in_channels = model.unet.in_channels + self.device = model.unet.device + self.dtype = torch.float16 - def forward(self, latent_model_input, t, encoder_hidden_states): - sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0] - return UNet2DConditionOutput(sample=sample) + def forward(self, latent_model_input, t, encoder_hidden_states): + sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0] + return UNet2DConditionOutput(sample=sample) - model.unet = TracedUNet() + model.unet = TracedUNet() model = model.to("cuda") @@ -151,8 +153,6 @@ def run_inference(): return response -# TODO(hayk): Enable cache here. -# @functools.lru_cache() def compute(inputs: InferenceInput) -> str: """ Does all the heavy lifting of the request.