Greatly simplify the server and baseten integration
With the new clean module structure, make it so the two servers share all the important code. This makes the baseten integration very small and simple, and paves the way for more integrations. Topic: clean_rewrite
This commit is contained in:
parent
40a799a3d3
commit
cbf473216b
|
@ -0,0 +1,3 @@
|
|||
# Integrations
|
||||
|
||||
This package contains integrations of Riffusion into third party apps and deployments.
|
|
@ -0,0 +1,84 @@
|
|||
"""
|
||||
This file can be used to build a Truss for deployment with Baseten.
|
||||
If used, it should be renamed to model.py and placed alongside the other
|
||||
files from /riffusion in the standard /model directory of the Truss.
|
||||
|
||||
For more on the Truss file format, see https://truss.baseten.co/
|
||||
"""
|
||||
|
||||
import typing as T
|
||||
|
||||
import torch
|
||||
import dacite
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.server import compute_request
|
||||
from riffusion.datatypes import InferenceInput
|
||||
|
||||
|
||||
class Model:
|
||||
"""
|
||||
Baseten Truss model class for riffusion.
|
||||
|
||||
See: https://truss.baseten.co/reference/structure#model.py
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._data_dir = kwargs["data_dir"]
|
||||
self._config = kwargs["config"]
|
||||
self._pipeline = None
|
||||
self._vae = None
|
||||
|
||||
self.checkpoint_name = "riffusion/riffusion-model-v1"
|
||||
|
||||
# Download entire seed image folder from huggingface hub
|
||||
self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
Load the model. Guaranteed to be called before `predict`.
|
||||
"""
|
||||
self._pipeline = RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=self.checkpoint_name,
|
||||
use_traced_unet=True,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
def preprocess(self, request: T.Dict) -> T.Dict:
|
||||
"""
|
||||
Incorporate pre-processing required by the model if desired here.
|
||||
|
||||
These might be feature transformations that are tightly coupled to the model.
|
||||
"""
|
||||
return request
|
||||
|
||||
def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
|
||||
"""
|
||||
This is the main function that is called.
|
||||
"""
|
||||
assert self._pipeline is not None, "Model pipeline not loaded"
|
||||
|
||||
try:
|
||||
inputs = dacite.from_dict(InferenceInput, request)
|
||||
except dacite.exceptions.WrongTypeError as exception:
|
||||
return str(exception), 400
|
||||
except dacite.exceptions.MissingValueError as exception:
|
||||
return str(exception), 400
|
||||
|
||||
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
|
||||
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
|
||||
response = compute_request(
|
||||
inputs=inputs,
|
||||
pipeline=self._pipeline,
|
||||
seed_images_dir=self._seed_images_dir,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def postprocess(self, request: T.Dict) -> T.Dict:
|
||||
"""
|
||||
Incorporate post-processing required by the model if desired here.
|
||||
"""
|
||||
return request
|
|
@ -1,175 +0,0 @@
|
|||
"""
|
||||
This file can be used to build a Truss for deployment with Baseten.
|
||||
If used, it should be renamed to model.py and placed alongside the other
|
||||
files from /riffusion in the standard /model directory of the Truss.
|
||||
|
||||
For more on the Truss file format, see https://truss.baseten.co/
|
||||
"""
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import json
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import dacite
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
|
||||
from .datatypes import InferenceInput, InferenceOutput
|
||||
from .riffusion_pipeline import RiffusionPipeline
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._data_dir = kwargs["data_dir"]
|
||||
self._config = kwargs["config"]
|
||||
self._model = None
|
||||
self._vae = None
|
||||
|
||||
# Download entire seed image folder from huggingface hub
|
||||
self._seed_images_dir = snapshot_download(
|
||||
"riffusion/riffusion-model-v1", allow_patterns="*.png"
|
||||
)
|
||||
|
||||
def load(self):
|
||||
# Load Riffusion model here and assign to self._model.
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if torch.cuda.is_available() == False:
|
||||
# Use only if you don't have a GPU with fp16 support
|
||||
self._model = RiffusionPipeline.from_pretrained(
|
||||
"riffusion/riffusion-model-v1",
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to(device)
|
||||
else:
|
||||
# Model loading the model with fp16. This will fail if ran without a GPU with fp16 support
|
||||
pipe = RiffusionPipeline.from_pretrained(
|
||||
"riffusion/riffusion-model-v1",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to(device)
|
||||
|
||||
# Deliberately not implementing channels_Last as it resulted in slower inference pipeline
|
||||
# pipe.unet.to(memory_format=torch.channels_last)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UNet2DConditionOutput:
|
||||
sample: torch.FloatTensor
|
||||
|
||||
# Use traced unet from hf hub
|
||||
unet_file = hf_hub_download(
|
||||
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
class TracedUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_channels = pipe.unet.in_channels
|
||||
self.device = pipe.unet.device
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
|
||||
pipe.unet = TracedUNet()
|
||||
|
||||
self._model = pipe
|
||||
|
||||
def preprocess(self, request: Dict) -> Dict:
|
||||
"""
|
||||
Incorporate pre-processing required by the model if desired here.
|
||||
|
||||
These might be feature transformations that are tightly coupled to the model.
|
||||
"""
|
||||
return request
|
||||
|
||||
def postprocess(self, request: Dict) -> Dict:
|
||||
"""
|
||||
Incorporate post-processing required by the model if desired here.
|
||||
"""
|
||||
return request
|
||||
|
||||
def predict(self, request: Dict) -> Dict[str, List]:
|
||||
"""
|
||||
This is the main function that is called.
|
||||
"""
|
||||
# Example request:
|
||||
# {"alpha":0.25,"num_inference_steps":50,"seed_image_id":"og_beat","mask_image_id":None,"start":{"prompt":"lo-fi beat for the holidays","seed":906295,"denoising":0.75,"guidance":7},"end":{"prompt":"lo-fi beat for the holidays","seed":906296,"denoising":0.75,"guidance":7}}
|
||||
|
||||
# Parse an InferenceInput dataclass from the payload
|
||||
try:
|
||||
inputs = dacite.from_dict(InferenceInput, request)
|
||||
except dacite.exceptions.WrongTypeError as exception:
|
||||
# logging.info(json_data)
|
||||
return str(exception), 400
|
||||
except dacite.exceptions.MissingValueError as exception:
|
||||
# logging.info(json_data)
|
||||
return str(exception), 400
|
||||
|
||||
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
|
||||
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
|
||||
response = self.compute(inputs)
|
||||
|
||||
return response
|
||||
|
||||
def compute(self, inputs: InferenceInput) -> str:
|
||||
"""
|
||||
Does all the heavy lifting of the request.
|
||||
"""
|
||||
# Load the seed image by ID
|
||||
init_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.seed_image_id}.png")
|
||||
|
||||
if not init_image_path.is_file():
|
||||
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# Load the mask image by ID
|
||||
if inputs.mask_image_id:
|
||||
mask_image_path = Path(self._seed_images_dir, f"seed_images/{inputs.mask_image_id}.png")
|
||||
if not mask_image_path.is_file():
|
||||
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
||||
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
||||
else:
|
||||
mask_image = None
|
||||
|
||||
# Execute the model to get the spectrogram image
|
||||
image = self._model.riffuse(inputs, init_image=init_image, mask_image=mask_image)
|
||||
|
||||
# Reconstruct audio from the image
|
||||
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
|
||||
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
|
||||
|
||||
# Compute the output as base64 encoded strings
|
||||
image_bytes = self.image_bytes_from_image(image, mode="JPEG")
|
||||
|
||||
# Assemble the output dataclass
|
||||
output = InferenceOutput(
|
||||
image="data:image/jpeg;base64," + self.base64_encode(image_bytes),
|
||||
audio="data:audio/mpeg;base64," + self.base64_encode(mp3_bytes),
|
||||
duration_s=duration_s,
|
||||
)
|
||||
|
||||
return json.dumps(dataclasses.asdict(output))
|
||||
|
||||
def image_bytes_from_image(self, image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||
"""
|
||||
Convert a PIL image into bytes of the given image format.
|
||||
"""
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, mode)
|
||||
image_bytes.seek(0)
|
||||
return image_bytes
|
||||
|
||||
def base64_encode(self, buffer: io.BytesIO) -> str:
|
||||
"""
|
||||
Encode the given buffer as base64.
|
||||
"""
|
||||
return base64.encodebytes(buffer.getvalue()).decode("ascii")
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Data model for the riffusion API.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import typing as T
|
||||
|
@ -58,6 +59,7 @@ class InferenceOutput:
|
|||
"""
|
||||
Response from the model inference server.
|
||||
"""
|
||||
|
||||
# base64 encoded spectrogram image as a JPEG
|
||||
image: str
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
"""
|
||||
Inference server for the riffusion project.
|
||||
Flask server that serves the riffusion model as an API.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import logging
|
||||
import io
|
||||
|
@ -16,15 +15,13 @@ import flask
|
|||
|
||||
from flask_cors import CORS
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .audio import wav_bytes_from_spectrogram_image
|
||||
from .audio import mp3_bytes_from_wav_bytes
|
||||
from .datatypes import InferenceInput
|
||||
from .datatypes import InferenceOutput
|
||||
from .riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.datatypes import InferenceInput
|
||||
from riffusion.datatypes import InferenceOutput
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import base64_util
|
||||
|
||||
# Flask app with CORS
|
||||
app = flask.Flask(__name__)
|
||||
|
@ -35,7 +32,7 @@ logging.basicConfig(level=logging.INFO)
|
|||
logging.getLogger().addHandler(logging.FileHandler("server.log"))
|
||||
|
||||
# Global variable for the model pipeline
|
||||
MODEL = None
|
||||
PIPELINE: T.Optional[RiffusionPipeline] = None
|
||||
|
||||
# Where built-in seed images are stored
|
||||
SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
|
||||
|
@ -45,8 +42,9 @@ def run_app(
|
|||
*,
|
||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
||||
no_traced_unet: bool = False,
|
||||
device: str = "cuda",
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3000,
|
||||
port: int = 3013,
|
||||
debug: bool = False,
|
||||
ssl_certificate: T.Optional[str] = None,
|
||||
ssl_key: T.Optional[str] = None,
|
||||
|
@ -55,8 +53,12 @@ def run_app(
|
|||
Run a flask API that serves the given riffusion model checkpoint.
|
||||
"""
|
||||
# Initialize the model
|
||||
global MODEL
|
||||
MODEL = load_model(checkpoint=checkpoint, traced_unet=not no_traced_unet)
|
||||
global PIPELINE
|
||||
PIPELINE = RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=checkpoint,
|
||||
use_traced_unet=not no_traced_unet,
|
||||
device=device,
|
||||
)
|
||||
|
||||
args = dict(
|
||||
debug=debug,
|
||||
|
@ -69,51 +71,7 @@ def run_app(
|
|||
assert ssl_key is not None
|
||||
args["ssl_context"] = (ssl_certificate, ssl_key)
|
||||
|
||||
app.run(**args)
|
||||
|
||||
|
||||
def load_model(checkpoint: str, traced_unet: bool = True):
|
||||
"""
|
||||
Load the riffusion model pipeline.
|
||||
"""
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
model = RiffusionPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="main",
|
||||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to("cuda")
|
||||
|
||||
# Set the traced unet if desired
|
||||
if checkpoint == "riffusion/riffusion-model-v1" and traced_unet:
|
||||
@dataclasses.dataclass
|
||||
class UNet2DConditionOutput:
|
||||
sample: torch.FloatTensor
|
||||
|
||||
# Using traced unet from hf hub
|
||||
unet_file = hf_hub_download(
|
||||
checkpoint, filename="unet_traced.pt", subfolder="unet_traced"
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
class TracedUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_channels = model.unet.in_channels
|
||||
self.device = model.unet.device
|
||||
self.dtype = torch.float16
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
|
||||
model.unet = TracedUNet()
|
||||
|
||||
model = model.to("cuda")
|
||||
|
||||
return model
|
||||
app.run(**args) # type: ignore
|
||||
|
||||
|
||||
@app.route("/run_inference/", methods=["POST"])
|
||||
|
@ -145,7 +103,11 @@ def run_inference():
|
|||
logging.info(json_data)
|
||||
return str(exception), 400
|
||||
|
||||
response = compute(inputs)
|
||||
response = compute_request(
|
||||
inputs=inputs,
|
||||
seed_images_dir=SEED_IMAGES_DIR,
|
||||
pipeline=PIPELINE,
|
||||
)
|
||||
|
||||
# Log the total time
|
||||
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
||||
|
@ -153,60 +115,73 @@ def run_inference():
|
|||
return response
|
||||
|
||||
|
||||
def compute(inputs: InferenceInput) -> str:
|
||||
def compute_request(
|
||||
inputs: InferenceInput,
|
||||
pipeline: RiffusionPipeline,
|
||||
seed_images_dir: str,
|
||||
) -> T.Union[str, T.Tuple[str, int]]:
|
||||
"""
|
||||
Does all the heavy lifting of the request.
|
||||
|
||||
Args:
|
||||
inputs: The input dataclass
|
||||
pipeline: The riffusion model pipeline
|
||||
seed_images_dir: The directory where seed images are stored
|
||||
"""
|
||||
# Load the seed image by ID
|
||||
init_image_path = Path(SEED_IMAGES_DIR, f"{inputs.seed_image_id}.png")
|
||||
init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")
|
||||
|
||||
if not init_image_path.is_file():
|
||||
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# Load the mask image by ID
|
||||
mask_image: T.Optional[PIL.Image.Image] = None
|
||||
if inputs.mask_image_id:
|
||||
mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
|
||||
mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
|
||||
if not mask_image_path.is_file():
|
||||
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
||||
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
||||
else:
|
||||
mask_image = None
|
||||
|
||||
# Execute the model to get the spectrogram image
|
||||
image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
|
||||
image = pipeline.riffuse(
|
||||
inputs,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||
params = SpectrogramParams(
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
# Reconstruct audio from the image
|
||||
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
|
||||
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
|
||||
# TODO(hayk): It may help performance to cache this object
|
||||
converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
|
||||
segment = converter.audio_from_spectrogram_image(
|
||||
image,
|
||||
apply_filters=True,
|
||||
)
|
||||
|
||||
# Compute the output as base64 encoded strings
|
||||
image_bytes = image_bytes_from_image(image, mode="JPEG")
|
||||
# Export audio to MP3 bytes
|
||||
mp3_bytes = io.BytesIO()
|
||||
segment.export(mp3_bytes, format="mp3")
|
||||
mp3_bytes.seek(0)
|
||||
|
||||
# Export image to JPEG bytes
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, exif=image.getexif(), format="JPEG")
|
||||
image_bytes.seek(0)
|
||||
|
||||
# Assemble the output dataclass
|
||||
output = InferenceOutput(
|
||||
image="data:image/jpeg;base64," + base64_encode(image_bytes),
|
||||
audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
|
||||
duration_s=duration_s,
|
||||
image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
|
||||
audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
|
||||
duration_s=segment.duration_seconds,
|
||||
)
|
||||
|
||||
return flask.jsonify(dataclasses.asdict(output))
|
||||
|
||||
|
||||
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||
"""
|
||||
Convert a PIL image into bytes of the given image format.
|
||||
"""
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, mode)
|
||||
image_bytes.seek(0)
|
||||
return image_bytes
|
||||
|
||||
|
||||
def base64_encode(buffer: io.BytesIO) -> str:
|
||||
"""
|
||||
Encode the given buffer as base64.
|
||||
"""
|
||||
return base64.encodebytes(buffer.getvalue()).decode("ascii")
|
||||
return json.dumps(dataclasses.asdict(output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue