traced unet and disabled autocast, adding baseten.py
This commit is contained in:
parent
d601c5e2bc
commit
afa49af2b0
@@ -0,0 +1,176 @@
"""
|
||||
This file can be used to build a Truss for deployment with Baseten.
|
||||
If used, it should be renamed to model.py and placed alongside the other
|
||||
files from /riffusion in the standard /model directory of the Truss.
|
||||
|
||||
For more on the Truss file format, see https://truss.baseten.co/
|
||||
"""
|
||||
|

import base64
import dataclasses
import io
import json
from pathlib import Path
from typing import Dict, List

import dacite
import PIL
import torch
from huggingface_hub import hf_hub_download, snapshot_download

from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput, InferenceOutput
from .riffusion_pipeline import RiffusionPipeline


class Model:
    def __init__(self, **kwargs) -> None:
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._model = None
        self._vae = None

        # Download the entire seed image folder from the Hugging Face hub
        self._seed_images_dir = snapshot_download("riffusion/riffusion-model-v1", allow_patterns="*.png")
        print(self._seed_images_dir)
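        # NOTE: snapshot_download returns the path of the local cache folder
        # the matching files were downloaded into; the repo's PNGs live under
        # seed_images/ within it (see compute below).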

    def load(self):
        # Load the Riffusion model here and assign it to self._model.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if not torch.cuda.is_available():
            # Use only if you don't have a GPU with fp16 support
            self._model = RiffusionPipeline.from_pretrained(
                "riffusion/riffusion-model-v1",
                safety_checker=lambda images, **kwargs: (images, False),
            ).to(device)
        else:
            # Load the model with fp16. This will fail without a GPU that supports fp16.
            pipe = RiffusionPipeline.from_pretrained(
                "riffusion/riffusion-model-v1",
                revision="fp16",
                torch_dtype=torch.float16,
                # Disable the NSFW filter; it causes incorrect false positives
                safety_checker=lambda images, **kwargs: (images, False),
            ).to(device)

            # Deliberately not using channels_last, as it resulted in a slower
            # inference pipeline:
            # pipe.unet.to(memory_format=torch.channels_last)

            @dataclasses.dataclass
            class UNet2DConditionOutput:
                sample: torch.FloatTensor

            # Use the traced UNet from the Hugging Face hub
            unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
            unet_traced = torch.jit.load(unet_file)

            class TracedUNet(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.in_channels = pipe.unet.in_channels
                    self.device = pipe.unet.device

                def forward(self, latent_model_input, t, encoder_hidden_states):
                    sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
                    return UNet2DConditionOutput(sample=sample)

            pipe.unet = TracedUNet()

            self._model = pipe
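        # For reference, a file like unet_traced.pt can be produced with
        # torch.jit.trace, roughly as follows (a sketch only, not part of this
        # Truss; the latent/text-embedding shapes are assumptions for the
        # standard Stable Diffusion v1 UNet):
        #
        #   import functools
        #   unet = pipe.unet
        #   unet.eval()
        #   unet.forward = functools.partial(unet.forward, return_dict=False)
        #   sample = torch.randn(2, unet.in_channels, 64, 64, dtype=torch.float16, device="cuda")
        #   timestep = torch.rand(1, dtype=torch.float16, device="cuda") * 999
        #   states = torch.randn(2, 77, 768, dtype=torch.float16, device="cuda")
        #   traced = torch.jit.trace(unet, (sample, timestep, states))
        #   traced.save("unet_traced.pt")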

    def preprocess(self, request: Dict) -> Dict:
        """
        Incorporate pre-processing required by the model if desired here.

        These might be feature transformations that are tightly coupled to the model.
        """
        return request

    def postprocess(self, request: Dict) -> Dict:
        """
        Incorporate post-processing required by the model if desired here.
        """
        return request
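
    # (Truss invokes these hooks around predict: preprocess(request) feeds
    # predict, and predict's result is passed through postprocess. Both are
    # identity functions here.)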

    def predict(self, request: Dict) -> Dict[str, List]:
        """
        This is the main function that is called.
        """
        # Example request:
        # {
        #     "alpha": 0.25,
        #     "num_inference_steps": 50,
        #     "seed_image_id": "og_beat",
        #     "mask_image_id": None,
        #     "start": {"prompt": "lo-fi beat for the holidays", "seed": 906295, "denoising": 0.75, "guidance": 7},
        #     "end": {"prompt": "lo-fi beat for the holidays", "seed": 906296, "denoising": 0.75, "guidance": 7},
        # }

        # Parse an InferenceInput dataclass from the payload
        try:
            inputs = dacite.from_dict(InferenceInput, request)
        except dacite.exceptions.WrongTypeError as exception:
            # logging.info(json_data)
            return str(exception), 400
        except dacite.exceptions.MissingValueError as exception:
            # logging.info(json_data)
            return str(exception), 400

        # NOTE: Autocast disabled to speed up inference; previous inference time was 10s on a T4
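        # (torch.inference_mode disables autograd tracking for the block;
        # autocast is kept off because the weights are already fp16, so the
        # per-op casting autocast performs is pure overhead here.)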
        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
            response = self.compute(inputs)

        return response

    def compute(self, inputs: InferenceInput) -> str:
        """
        Does all the heavy lifting of the request.
        """
        # Load the seed image by ID
        init_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.seed_image_id}.png")
        print(init_image_path)

        if not init_image_path.is_file():
            return f"Invalid seed image: {inputs.seed_image_id}", 400
        init_image = PIL.Image.open(str(init_image_path)).convert("RGB")

        # Load the mask image by ID
        if inputs.mask_image_id:
            mask_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.mask_image_id}.png")
            if not mask_image_path.is_file():
                return f"Invalid mask image: {inputs.mask_image_id}", 400
            mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
        else:
            mask_image = None

        # Execute the model to get the spectrogram image
        image = self._model.riffuse(inputs, init_image=init_image, mask_image=mask_image)

        # Reconstruct audio from the image
        wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
        mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)

        # Encode the image as base64
        image_bytes = self.image_bytes_from_image(image, mode="JPEG")

        # Assemble the output dataclass
        output = InferenceOutput(
            image="data:image/jpeg;base64," + self.base64_encode(image_bytes),
            audio="data:audio/mpeg;base64," + self.base64_encode(mp3_bytes),
            duration_s=duration_s,
        )

        return json.dumps(dataclasses.asdict(output))
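
    # (wav_bytes_from_spectrogram_image re-synthesizes audio from the
    # spectrogram image; riffusion's audio code reconstructs the missing phase
    # with Griffin-Lim before the WAV is encoded to MP3.)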
def image_bytes_from_image(self, image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||
"""
|
||||
Convert a PIL image into bytes of the given image format.
|
||||
"""
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, mode)
|
||||
image_bytes.seek(0)
|
||||
return image_bytes
|
||||
|
||||
def base64_encode(self, buffer: io.BytesIO) -> str:
|
||||
"""
|
||||
Encode the given buffer as base64.
|
||||
"""
|
||||
return base64.encodebytes(buffer.getvalue()).decode("ascii")
|
|
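
# A minimal local smoke test might look like the following (a sketch only;
# Baseten normally constructs Model itself, and the "data_dir"/"config"
# values below are placeholder assumptions):
#
#   model = Model(data_dir=".", config={})
#   model.load()
#   body = model.predict({
#       "alpha": 0.25,
#       "num_inference_steps": 50,
#       "seed_image_id": "og_beat",
#       "mask_image_id": None,
#       "start": {"prompt": "lo-fi beat for the holidays", "seed": 906295, "denoising": 0.75, "guidance": 7},
#       "end": {"prompt": "lo-fi beat for the holidays", "seed": 906296, "denoising": 0.75, "guidance": 7},
#   })
#   print(json.loads(body)["duration_s"])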

@@ -72,7 +72,6 @@ class RiffusionPipeline(DiffusionPipeline):
         embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
         return embed
 
-    @torch.autocast("cuda")
     @torch.no_grad()
     def riffuse(
         self,

@@ -19,6 +19,8 @@ from flask_cors import CORS
 import PIL
+import torch
 
+from huggingface_hub import hf_hub_download
 
 from .audio import wav_bytes_from_spectrogram_image
 from .audio import mp3_bytes_from_wav_bytes
 from .datatypes import InferenceInput

@@ -84,6 +86,34 @@ def load_model(checkpoint: str):
         safety_checker=lambda images, **kwargs: (images, False),
     )
 
+    model = RiffusionPipeline.from_pretrained(
+        "riffusion/riffusion-model-v1",
+        revision="fp16",
+        torch_dtype=torch.float16,
+        # Disable the NSFW filter; it causes incorrect false positives
+        safety_checker=lambda images, **kwargs: (images, False),
+    )
+
+    @dataclasses.dataclass
+    class UNet2DConditionOutput:
+        sample: torch.FloatTensor
+
+    # Use the traced UNet from the Hugging Face hub
+    unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
+    unet_traced = torch.jit.load(unet_file)
+
+    class TracedUNet(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.in_channels = model.unet.in_channels
+            self.device = model.unet.device
+
+        def forward(self, latent_model_input, t, encoder_hidden_states):
+            sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+            return UNet2DConditionOutput(sample=sample)
+
+    model.unet = TracedUNet()
+
     model = model.to("cuda")
 
     return model