Provide clip duration and encode base64 prefix type

This commit is contained in:
Hayk Martiros 2022-11-27 21:55:42 +00:00
parent 7f27705f81
commit 511defae99
3 changed files with 24 additions and 11 deletions

View File

@ -2,6 +2,7 @@
Audio processing tools to convert between spectrogram images and waveforms.
"""
import io
import typing as T
import numpy as np
from PIL import Image
@ -11,9 +12,9 @@ import torch
import torchaudio
def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
"""
Reconstruct a WAV audio clip from a spectrogram image.
Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
"""
max_volume = 50
@ -37,7 +38,7 @@ def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
hop_length = int(step_size_ms / 1000.0 * sample_rate)
win_length = int(window_duration_ms / 1000.0 * sample_rate)
waveform = waveform_from_spectrogram(
samples = waveform_from_spectrogram(
Sxx=Sxx,
n_fft=n_fft,
hop_length=hop_length,
@ -51,10 +52,12 @@ def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
)
wav_bytes = io.BytesIO()
wavfile.write(wav_bytes, sample_rate, waveform.astype(np.int16))
wavfile.write(wav_bytes, sample_rate, samples.astype(np.int16))
wav_bytes.seek(0)
return wav_bytes
duration_s = float(len(samples)) / sample_rate
return wav_bytes, duration_s
def spectrogram_from_image(

View File

@ -56,9 +56,13 @@ class InferenceInput:
@dataclass
class InferenceOutput:
"""
Response from the model server. Contains a base64 encoded spectrogram image and a base64
encoded MP3 audio clip.
Response from the model inference server.
"""
# base64 encoded spectrogram image as a JPEG
image: str
# base64 encoded audio clip as an MP3
audio: str
# The duration of the audio clip
duration_s: float

View File

@ -45,7 +45,7 @@ def run_app(
port: int = 3000,
debug: bool = False,
ssl_certificate: T.Optional[str] = None,
ssl_key: T.Optional[str] = None
ssl_key: T.Optional[str] = None,
):
"""
Run a flask API that serves the given riffusion model checkpoint.
@ -135,12 +135,18 @@ def run_inference():
image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
# Reconstruct audio from the image
wav_bytes = wav_bytes_from_spectrogram_image(image)
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
# Compute the output as base64 encoded strings
image_bytes = image_bytes_from_image(image, mode="JPEG")
output = InferenceOutput(image=base64_encode(image_bytes), audio=base64_encode(mp3_bytes))
# Assemble the output dataclass
output = InferenceOutput(
image="data:image/jpeg;base64," + base64_encode(image_bytes),
audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
duration_s=duration_s,
)
# Log the total time
logging.info(f"Request took {time.time() - start_time:.2f} s")