Greatly simplify the server and baseten integration

With the new clean module structure, make it so the two servers
share all the important code. This makes the baseten integration
very small and simple, and paves the way for more integrations.

Topic: clean_rewrite
Hayk Martiros 2022-12-26 17:25:17 -08:00
parent 40a799a3d3
commit cbf473216b
6 changed files with 155 additions and 266 deletions
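
For orientation before the diff: both servers now funnel into a single compute_request function in riffusion.server. Below is a minimal sketch of that shared path — not part of the commit itself. The request payload is the example from the removed baseten file; the seed image path is a hypothetical local directory.

import dacite

from riffusion.datatypes import InferenceInput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.server import compute_request

# Load the pipeline once, exactly as both servers now do
pipeline = RiffusionPipeline.load_checkpoint(
    checkpoint="riffusion/riffusion-model-v1",
    use_traced_unet=True,
    device="cuda",
)

# Parse a request payload into the shared InferenceInput dataclass
inputs = dacite.from_dict(
    InferenceInput,
    {
        "alpha": 0.25,
        "num_inference_steps": 50,
        "seed_image_id": "og_beat",
        "mask_image_id": None,
        "start": {"prompt": "lo-fi beat for the holidays", "seed": 906295, "denoising": 0.75, "guidance": 7},
        "end": {"prompt": "lo-fi beat for the holidays", "seed": 906296, "denoising": 0.75, "guidance": 7},
    },
)

# All the heavy lifting is shared; only the pipeline and the seed image
# location differ between the Flask server and the Baseten Truss
response = compute_request(
    inputs=inputs,
    pipeline=pipeline,
    seed_images_dir="seed_images",  # hypothetical local path
)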

integrations/README.md (new file, +3)

@@ -0,0 +1,3 @@
# Integrations
This package contains integrations of Riffusion into third party apps and deployments.

integrations/__init__.py (new file, empty)

integrations/baseten.py (new file, +84)

@@ -0,0 +1,84 @@
"""
This file can be used to build a Truss for deployment with Baseten.
If used, it should be renamed to model.py and placed alongside the other
files from /riffusion in the standard /model directory of the Truss.
For more on the Truss file format, see https://truss.baseten.co/
"""
import typing as T
import torch
import dacite
from huggingface_hub import snapshot_download
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.server import compute_request
from riffusion.datatypes import InferenceInput
class Model:
"""
Baseten Truss model class for riffusion.
See: https://truss.baseten.co/reference/structure#model.py
"""
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._pipeline = None
self._vae = None
self.checkpoint_name = "riffusion/riffusion-model-v1"
# Download entire seed image folder from huggingface hub
self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
def load(self):
"""
Load the model. Guaranteed to be called before `predict`.
"""
self._pipeline = RiffusionPipeline.load_checkpoint(
checkpoint=self.checkpoint_name,
use_traced_unet=True,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
def preprocess(self, request: T.Dict) -> T.Dict:
"""
Incorporate pre-processing required by the model if desired here.
These might be feature transformations that are tightly coupled to the model.
"""
return request
def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
"""
This is the main function that is called.
"""
assert self._pipeline is not None, "Model pipeline not loaded"
try:
inputs = dacite.from_dict(InferenceInput, request)
except dacite.exceptions.WrongTypeError as exception:
return str(exception), 400
except dacite.exceptions.MissingValueError as exception:
return str(exception), 400
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
response = compute_request(
inputs=inputs,
pipeline=self._pipeline,
seed_images_dir=self._seed_images_dir,
)
return response
def postprocess(self, request: T.Dict) -> T.Dict:
"""
Incorporate post-processing required by the model if desired here.
"""
return request
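
Outside of a real Baseten deployment, the class above can be smoke-tested by driving the Truss lifecycle by hand. A sketch, with hypothetical stand-ins for the data_dir and config kwargs that the Truss runtime normally supplies:

# Hypothetical local smoke test; a CUDA GPU is assumed for reasonable speed
model = Model(data_dir="/tmp/truss_data", config={})
model.load()  # the Truss runtime guarantees this runs before predict()

response = model.predict(
    {
        "alpha": 0.25,
        "num_inference_steps": 50,
        "seed_image_id": "og_beat",
        "mask_image_id": None,
        "start": {"prompt": "lo-fi beat for the holidays", "seed": 906295, "denoising": 0.75, "guidance": 7},
        "end": {"prompt": "lo-fi beat for the holidays", "seed": 906296, "denoising": 0.75, "guidance": 7},
    }
)
# On success this is the JSON string produced by compute_request;
# on a malformed payload it is a (message, 400) tuple
print(response)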

riffusion/baseten.py (deleted, -175)

@@ -1,175 +0,0 @@
"""
This file can be used to build a Truss for deployment with Baseten.
If used, it should be renamed to model.py and placed alongside the other
files from /riffusion in the standard /model directory of the Truss.
For more on the Truss file format, see https://truss.baseten.co/
"""
import base64
import dataclasses
import json
import io
from pathlib import Path
from typing import Dict, List
import PIL
import torch
import dacite
from huggingface_hub import hf_hub_download, snapshot_download
from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput, InferenceOutput
from .riffusion_pipeline import RiffusionPipeline
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._model = None
self._vae = None
# Download entire seed image folder from huggingface hub
self._seed_images_dir = snapshot_download(
"riffusion/riffusion-model-v1", allow_patterns="*.png"
)
def load(self):
# Load Riffusion model here and assign to self._model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available() == False:
# Use only if you don't have a GPU with fp16 support
self._model = RiffusionPipeline.from_pretrained(
"riffusion/riffusion-model-v1",
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
else:
# Model loading the model with fp16. This will fail if ran without a GPU with fp16 support
pipe = RiffusionPipeline.from_pretrained(
"riffusion/riffusion-model-v1",
revision="fp16",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
).to(device)
# Deliberately not implementing channels_Last as it resulted in slower inference pipeline
# pipe.unet.to(memory_format=torch.channels_last)
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
# Use traced unet from hf hub
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = pipe.unet.in_channels
self.device = pipe.unet.device
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
pipe.unet = TracedUNet()
self._model = pipe
def preprocess(self, request: Dict) -> Dict:
"""
Incorporate pre-processing required by the model if desired here.
These might be feature transformations that are tightly coupled to the model.
"""
return request
def postprocess(self, request: Dict) -> Dict:
"""
Incorporate post-processing required by the model if desired here.
"""
return request
def predict(self, request: Dict) -> Dict[str, List]:
"""
This is the main function that is called.
"""
# Example request:
# {"alpha":0.25,"num_inference_steps":50,"seed_image_id":"og_beat","mask_image_id":None,"start":{"prompt":"lo-fi beat for the holidays","seed":906295,"denoising":0.75,"guidance":7},"end":{"prompt":"lo-fi beat for the holidays","seed":906296,"denoising":0.75,"guidance":7}}
# Parse an InferenceInput dataclass from the payload
try:
inputs = dacite.from_dict(InferenceInput, request)
except dacite.exceptions.WrongTypeError as exception:
# logging.info(json_data)
return str(exception), 400
except dacite.exceptions.MissingValueError as exception:
# logging.info(json_data)
return str(exception), 400
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
response = self.compute(inputs)
return response
def compute(self, inputs: InferenceInput) -> str:
"""
Does all the heavy lifting of the request.
"""
# Load the seed image by ID
init_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.seed_image_id}.png")
if not init_image_path.is_file():
return f"Invalid seed image: {inputs.seed_image_id}", 400
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
# Load the mask image by ID
if inputs.mask_image_id:
mask_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.mask_image_id}.png")
if not mask_image_path.is_file():
return f"Invalid mask image: {inputs.mask_image_id}", 400
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
else:
mask_image = None
# Execute the model to get the spectrogram image
image = self._model.riffuse(inputs, init_image=init_image, mask_image=mask_image)
# Reconstruct audio from the image
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
# Compute the output as base64 encoded strings
image_bytes = self.image_bytes_from_image(image, mode="JPEG")
# Assemble the output dataclass
output = InferenceOutput(
image="data:image/jpeg;base64," + self.base64_encode(image_bytes),
audio="data:audio/mpeg;base64," + self.base64_encode(mp3_bytes),
duration_s=duration_s,
)
return json.dumps(dataclasses.asdict(output))
def image_bytes_from_image(self, image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
"""
Convert a PIL image into bytes of the given image format.
"""
image_bytes = io.BytesIO()
image.save(image_bytes, mode)
image_bytes.seek(0)
return image_bytes
def base64_encode(self, buffer: io.BytesIO) -> str:
"""
Encode the given buffer as base64.
"""
return base64.encodebytes(buffer.getvalue()).decode("ascii")

riffusion/datatypes.py (modified)

@@ -1,6 +1,7 @@
"""
Data model for the riffusion API.
"""
from __future__ import annotations

from dataclasses import dataclass
import typing as T

@@ -58,6 +59,7 @@ class InferenceOutput:
    """
    Response from the model inference server.
    """

    # base64 encoded spectrogram image as a JPEG
    image: str
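
Since InferenceOutput travels as base64 data URIs, a client decodes it along these lines. A sketch — the field names come from the dataclass above, the file names are arbitrary, and base64.b64decode tolerates the newlines that base64.encodebytes inserts:

import base64
import json


def save_inference_output(response_json: str) -> None:
    """Decode the image/audio data URIs from a serialized InferenceOutput."""
    output = json.loads(response_json)

    # Each field is "data:<mime>;base64,<payload>"; split off the header
    image_b64 = output["image"].split(",", 1)[1]
    audio_b64 = output["audio"].split(",", 1)[1]

    with open("spectrogram.jpg", "wb") as f:
        f.write(base64.b64decode(image_b64))
    with open("clip.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))

    print(f"Wrote {output['duration_s']:.2f}s of audio")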

riffusion/server.py (modified)

@@ -1,8 +1,7 @@
"""
Inference server for the riffusion project.
Flask server that serves the riffusion model as an API.
"""
import base64
import dataclasses
import logging
import io
@@ -16,15 +15,13 @@ import flask
from flask_cors import CORS
import PIL
import torch
from huggingface_hub import hf_hub_download

from .audio import wav_bytes_from_spectrogram_image
from .audio import mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput
from .datatypes import InferenceOutput
from .riffusion_pipeline import RiffusionPipeline
from riffusion.datatypes import InferenceInput
from riffusion.datatypes import InferenceOutput
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.util import base64_util

# Flask app with CORS
app = flask.Flask(__name__)
@@ -35,7 +32,7 @@ logging.basicConfig(level=logging.INFO)
logging.getLogger().addHandler(logging.FileHandler("server.log"))

# Global variable for the model pipeline
MODEL = None
PIPELINE: T.Optional[RiffusionPipeline] = None

# Where built-in seed images are stored
SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
@@ -45,8 +42,9 @@ def run_app(
    *,
    checkpoint: str = "riffusion/riffusion-model-v1",
    no_traced_unet: bool = False,
    device: str = "cuda",
    host: str = "127.0.0.1",
    port: int = 3000,
    port: int = 3013,
    debug: bool = False,
    ssl_certificate: T.Optional[str] = None,
    ssl_key: T.Optional[str] = None,

@@ -55,8 +53,12 @@
    Run a flask API that serves the given riffusion model checkpoint.
    """
    # Initialize the model
    global MODEL
    MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet)
    global PIPELINE
    PIPELINE = RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )

    args = dict(
        debug=debug,

@@ -69,51 +71,7 @@
        assert ssl_key is not None
        args["ssl_context"] = (ssl_certificate, ssl_key)

    app.run(**args)
def load_model(checkpoint: str, traced_unet: bool = True):
    """
    Load the riffusion model pipeline.
    """
    assert torch.cuda.is_available()

    model = RiffusionPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=torch.float16,
        # Disable the NSFW filter, causes incorrect false positives
        safety_checker=lambda images, **kwargs: (images, False),
    ).to("cuda")

    # Set the traced unet if desired
    if checkpoint == "riffusion/riffusion-model-v1" and traced_unet:

        @dataclasses.dataclass
        class UNet2DConditionOutput:
            sample: torch.FloatTensor

        # Using traced unet from hf hub
        unet_file = hf_hub_download(
            checkpoint, filename="unet_traced.pt", subfolder="unet_traced"
        )
        unet_traced = torch.jit.load(unet_file)

        class TracedUNet(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.in_channels = model.unet.in_channels
                self.device = model.unet.device
                self.dtype = torch.float16

            def forward(self, latent_model_input, t, encoder_hidden_states):
                sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
                return UNet2DConditionOutput(sample=sample)

        model.unet = TracedUNet()
        model = model.to("cuda")

    return model

    app.run(**args)  # type: ignore
@app.route("/run_inference/", methods=["POST"])
@@ -145,7 +103,11 @@ def run_inference():
        logging.info(json_data)
        return str(exception), 400

    response = compute(inputs)
    response = compute_request(
        inputs=inputs,
        seed_images_dir=SEED_IMAGES_DIR,
        pipeline=PIPELINE,
    )

    # Log the total time
    logging.info(f"Request took {time.time() - start_time:.2f} s")
@@ -153,60 +115,73 @@
    return response


def compute(inputs: InferenceInput) -> str:
def compute_request(
    inputs: InferenceInput,
    pipeline: RiffusionPipeline,
    seed_images_dir: str,
) -> T.Union[str, T.Tuple[str, int]]:
    """
    Does all the heavy lifting of the request.

    Args:
        inputs: The input dataclass
        pipeline: The riffusion model pipeline
        seed_images_dir: The directory where seed images are stored
    """
    # Load the seed image by ID
    init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
    init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")

    if not init_image_path.is_file():
        return f"Invalid seed image: {inputs.seed_image_id}", 400
    init_image = PIL.Image.open(str(init_image_path)).convert("RGB")

    # Load the mask image by ID
    mask_image: T.Optional[PIL.Image.Image] = None
    if inputs.mask_image_id:
        mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
        mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
        if not mask_image_path.is_file():
            return f"Invalid mask image: {inputs.mask_image_id}", 400
        mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
    else:
        mask_image = None

    # Execute the model to get the spectrogram image
    image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
    image = pipeline.riffuse(
        inputs,
        init_image=init_image,
        mask_image=mask_image,
    )

    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
    params = SpectrogramParams(
        min_frequency=0,
        max_frequency=10000,
    )

    # Reconstruct audio from the image
    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
    mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)

    # TODO(hayk): It may help performance to cache this object
    converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
    segment = converter.audio_from_spectrogram_image(
        image,
        apply_filters=True,
    )

    # Compute the output as base64 encoded strings
    image_bytes = image_bytes_from_image(image, mode="JPEG")

    # Export audio to MP3 bytes
    mp3_bytes = io.BytesIO()
    segment.export(mp3_bytes, format="mp3")
    mp3_bytes.seek(0)

    # Export image to JPEG bytes
    image_bytes = io.BytesIO()
    image.save(image_bytes, exif=image.getexif(), format="JPEG")
    image_bytes.seek(0)

    # Assemble the output dataclass
    output = InferenceOutput(
        image="data:image/jpeg;base64," + base64_encode(image_bytes),
        audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
        duration_s=duration_s,
        image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
        audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
        duration_s=segment.duration_seconds,
    )
    return flask.jsonify(dataclasses.asdict(output))


def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
    """
    Convert a PIL image into bytes of the given image format.
    """
    image_bytes = io.BytesIO()
    image.save(image_bytes, mode)
    image_bytes.seek(0)
    return image_bytes


def base64_encode(buffer: io.BytesIO) -> str:
    """
    Encode the given buffer as base64.
    """
    return base64.encodebytes(buffer.getvalue()).decode("ascii")

    return json.dumps(dataclasses.asdict(output))

if __name__ == "__main__":
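
End to end, the updated server can be exercised with a plain HTTP request against the /run_inference/ route shown above. A sketch assuming the server is running locally on the new default port 3013; requests is not a project dependency, just a convenient client here:

import requests

payload = {
    "alpha": 0.25,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "mask_image_id": None,
    "start": {"prompt": "lo-fi beat for the holidays", "seed": 906295, "denoising": 0.75, "guidance": 7},
    "end": {"prompt": "lo-fi beat for the holidays", "seed": 906296, "denoising": 0.75, "guidance": 7},
}

response = requests.post("http://127.0.0.1:3013/run_inference/", json=payload)
response.raise_for_status()

output = response.json()  # InferenceOutput fields: image, audio, duration_s
print(f"Got {output['duration_s']:.2f}s of audio")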