"""
This file can be used to build a Truss for deployment with Baseten.
If used, it should be renamed to model.py and placed alongside the other
files from /riffusion in the standard /model directory of the Truss.

For more on the Truss file format, see https://truss.baseten.co/
"""

import base64
import dataclasses
import io
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import dacite
import PIL
import PIL.Image
import torch
from huggingface_hub import hf_hub_download, snapshot_download

from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput, InferenceOutput
from .riffusion_pipeline import RiffusionPipeline


class Model:
    """
    Truss model wrapper around the Riffusion pipeline.

    `load()` builds the pipeline and `predict()` turns a request payload into
    a JSON string holding a spectrogram image, an MP3 clip and its duration.
    """

    # Hugging Face Hub repository that the pipeline weights, the traced UNet
    # and the seed images are all fetched from.
    _MODEL_ID = "riffusion/riffusion-model-v1"

    def __init__(self, **kwargs) -> None:
        """
        Store the provided settings and download the seed images.

        Args:
            **kwargs: Must contain the keys "data_dir" and "config".

        Raises:
            KeyError: If "data_dir" or "config" is missing.
        """
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._model = None
        # Never assigned anywhere else in this class.
        self._vae = None

        # Download entire seed image folder from huggingface hub
        self._seed_images_dir = snapshot_download(self._MODEL_ID, allow_patterns="*.png")

    def load(self):
        """
        Build the Riffusion pipeline and assign it to `self._model`.

        Without CUDA the default-precision pipeline is loaded as is. With CUDA
        the fp16 revision is loaded and its UNet is replaced by the traced
        UNet downloaded from the hub.
        """
        has_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if has_cuda else "cpu")

        if not has_cuda:
            # Use only if you don't have a GPU with fp16 support
            self._model = RiffusionPipeline.from_pretrained(
                self._MODEL_ID,
                safety_checker=lambda images, **kwargs: (images, False),
            ).to(device)
            return

        # Load the model with fp16. This will fail if run without a GPU with fp16 support
        pipe = RiffusionPipeline.from_pretrained(
            self._MODEL_ID,
            revision="fp16",
            torch_dtype=torch.float16,
            # Disable the NSFW filter, causes incorrect false positives
            safety_checker=lambda images, **kwargs: (images, False),
        ).to(device)

        # Deliberately not using channels_last as it resulted in a slower inference pipeline
        # pipe.unet.to(memory_format=torch.channels_last)

        @dataclasses.dataclass
        class UNet2DConditionOutput:
            # Wraps the traced UNet's first output so callers can read `.sample`.
            sample: torch.FloatTensor

        # Use traced unet from hf hub
        unet_file = hf_hub_download(
            self._MODEL_ID, filename="unet_traced.pt", subfolder="unet_traced"
        )
        unet_traced = torch.jit.load(unet_file)

        class TracedUNet(torch.nn.Module):
            """Adapter exposing the traced UNet through the original UNet's attributes."""

            def __init__(self):
                super().__init__()
                # Copied from the original UNet before it is replaced below.
                self.in_channels = pipe.unet.in_channels
                self.device = pipe.unet.device

            def forward(self, latent_model_input, t, encoder_hidden_states):
                sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
                return UNet2DConditionOutput(sample=sample)

        pipe.unet = TracedUNet()
        self._model = pipe

    def preprocess(self, request: Dict) -> Dict:
        """
        Incorporate pre-processing required by the model if desired here.

        These might be feature transformations that are tightly coupled to the model.
        Currently returns the request unchanged.
        """
        return request

    def postprocess(self, request: Dict) -> Dict:
        """
        Incorporate post-processing required by the model if desired here.

        Currently returns the request unchanged.
        """
        return request

    def predict(self, request: Dict) -> Union[str, Tuple[str, int]]:
        """
        This is the main function that is called.

        Args:
            request: Payload that is parsed into an `InferenceInput`.

        Returns:
            The JSON string produced by `compute()`, or a `(message, 400)`
            tuple when the payload cannot be parsed or names an unknown image.
        """
        # Example request:
        # {"alpha":0.25,"num_inference_steps":50,"seed_image_id":"og_beat","mask_image_id":None,"start":{"prompt":"lo-fi beat for the holidays","seed":906295,"denoising":0.75,"guidance":7},"end":{"prompt":"lo-fi beat for the holidays","seed":906296,"denoising":0.75,"guidance":7}}

        # Parse an InferenceInput dataclass from the payload
        try:
            inputs = dacite.from_dict(InferenceInput, request)
        except (
            dacite.exceptions.WrongTypeError,
            dacite.exceptions.MissingValueError,
        ) as exception:
            return str(exception), 400

        # NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
        # The context managers must be comma-separated: `a and b` evaluates to
        # `b` alone, which would silently skip inference mode.
        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
            response = self.compute(inputs)

        return response

    def _load_image(self, image_id: str) -> "Optional[PIL.Image.Image]":
        """
        Open `seed_images/<image_id>.png` from the downloaded snapshot as RGB.

        Returns None when the file does not exist, or when `image_id` is not a
        plain file name. IDs with path components are rejected so that a
        request cannot reach files outside the seed image folder.
        """
        if not isinstance(image_id, str) or Path(image_id).name != image_id:
            return None

        image_path = Path(self._seed_images_dir, f"seed_images/{image_id}.png")
        if not image_path.is_file():
            return None
        return PIL.Image.open(str(image_path)).convert("RGB")

    def compute(self, inputs: "InferenceInput") -> Union[str, Tuple[str, int]]:
        """
        Does all the heavy lifting of the request.

        Args:
            inputs: Parsed request; `seed_image_id` and `mask_image_id` are
                read here and the whole object is handed to the pipeline.

        Returns:
            A JSON string of the `InferenceOutput` fields (image and audio as
            base64 data URIs plus `duration_s`), or a `(message, 400)` tuple
            when the seed or mask image ID is invalid.
        """
        # Load the seed image by ID
        init_image = self._load_image(inputs.seed_image_id)
        if init_image is None:
            return f"Invalid seed image: {inputs.seed_image_id}", 400

        # Load the mask image by ID (optional)
        mask_image = None
        if inputs.mask_image_id:
            mask_image = self._load_image(inputs.mask_image_id)
            if mask_image is None:
                return f"Invalid mask image: {inputs.mask_image_id}", 400

        # Execute the model to get the spectrogram image
        image = self._model.riffuse(inputs, init_image=init_image, mask_image=mask_image)

        # Reconstruct audio from the image
        wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
        mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)

        # Compute the output as base64 encoded strings
        image_bytes = self.image_bytes_from_image(image, mode="JPEG")

        # Assemble the output dataclass
        output = InferenceOutput(
            image="data:image/jpeg;base64," + self.base64_encode(image_bytes),
            audio="data:audio/mpeg;base64," + self.base64_encode(mp3_bytes),
            duration_s=duration_s,
        )

        return json.dumps(dataclasses.asdict(output))

    def image_bytes_from_image(self, image: "PIL.Image.Image", mode: str = "PNG") -> io.BytesIO:
        """
        Convert a PIL image into bytes of the given image format.

        Args:
            image: Image to serialise; only its `save(fp, format)` method is used.
            mode: Image format name passed to `save`, e.g. "PNG" or "JPEG".

        Returns:
            An in-memory buffer rewound to position 0.
        """
        image_bytes = io.BytesIO()
        image.save(image_bytes, mode)
        image_bytes.seek(0)
        return image_bytes

    def base64_encode(self, buffer: io.BytesIO) -> str:
        """
        Encode the given buffer as base64.

        Uses `base64.encodebytes`, so the result contains a newline after
        every 76 characters of output and at the end.
        """
        return base64.encodebytes(buffer.getvalue()).decode("ascii")