Provide clip duration and encode base64 prefix type
parent 7f27705f81
commit 511defae99
@@ -2,6 +2,7 @@
 Audio processing tools to convert between spectrogram images and waveforms.
 """
 import io
+import typing as T
 
 import numpy as np
 from PIL import Image
@@ -11,9 +12,9 @@ import torch
 import torchaudio
 
 
-def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
+def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
     """
-    Reconstruct a WAV audio clip from a spectrogram image.
+    Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
     """
 
     max_volume = 50
@@ -37,7 +38,7 @@ def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
     hop_length = int(step_size_ms / 1000.0 * sample_rate)
     win_length = int(window_duration_ms / 1000.0 * sample_rate)
 
-    waveform = waveform_from_spectrogram(
+    samples = waveform_from_spectrogram(
         Sxx=Sxx,
         n_fft=n_fft,
         hop_length=hop_length,
@@ -51,10 +52,12 @@ def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
     )
 
     wav_bytes = io.BytesIO()
-    wavfile.write(wav_bytes, sample_rate, waveform.astype(np.int16))
+    wavfile.write(wav_bytes, sample_rate, samples.astype(np.int16))
     wav_bytes.seek(0)
 
-    return wav_bytes
+    duration_s = float(len(samples)) / sample_rate
+
+    return wav_bytes, duration_s
 
 
 def spectrogram_from_image(
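
For reference, a minimal sketch of how a caller consumes the new two-element return value. The input path and the write-out step are illustrative, not part of this commit; duration_s is simply len(samples) / sample_rate as computed in the hunk above.

# Hypothetical caller of the updated helper. For example, 220500
# samples at 44100 Hz yields a reported duration of 5.0 seconds.
from PIL import Image

image = Image.open("spectrogram.png")  # assumed input file
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)

with open("clip.wav", "wb") as out:
    out.write(wav_bytes.getbuffer())
print(f"Wrote clip.wav, {duration_s:.2f} s of audio")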
@@ -56,9 +56,13 @@ class InferenceInput:
 @dataclass
 class InferenceOutput:
     """
-    Response from the model server. Contains a base64 encoded spectrogram image and a base64
-    encoded MP3 audio clip.
+    Response from the model inference server.
     """
 
+    # base64 encoded spectrogram image as a JPEG
     image: str
+
+    # base64 encoded audio clip as an MP3
     audio: str
+
+    # The duration of the audio clip
+    duration_s: float
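
Since the dataclass now holds only plain str/float fields, it serializes straight to JSON. A small sketch of that; the use of dataclasses.asdict and the dummy payloads are assumptions for illustration, not code from this commit.

# Illustrative only: build an InferenceOutput and dump it to JSON.
import dataclasses
import json

output = InferenceOutput(
    image="data:image/jpeg;base64,/9j/4AAQSkZJRg...",  # truncated dummy payload
    audio="data:audio/mpeg;base64,SUQzBAAAAAAA...",     # truncated dummy payload
    duration_s=5.12,
)
payload = json.dumps(dataclasses.asdict(output))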
@@ -45,7 +45,7 @@ def run_app(
     port: int = 3000,
     debug: bool = False,
     ssl_certificate: T.Optional[str] = None,
-    ssl_key: T.Optional[str] = None
+    ssl_key: T.Optional[str] = None,
 ):
     """
     Run a flask API that serves the given riffusion model checkpoint.
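
The two SSL arguments touched here are the usual Flask pattern of serving HTTPS from a (certificate, key) pair. A generic Flask sketch of that pattern, not the server code from this commit:

# Generic Flask HTTPS sketch; file paths are assumed placeholders.
from flask import Flask

app = Flask(__name__)

ssl_certificate = "cert.pem"
ssl_key = "key.pem"

app.run(
    host="127.0.0.1",
    port=3000,
    debug=False,
    ssl_context=(ssl_certificate, ssl_key) if ssl_certificate else None,
)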
@@ -135,12 +135,18 @@ def run_inference():
     image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
 
     # Reconstruct audio from the image
-    wav_bytes = wav_bytes_from_spectrogram_image(image)
+    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
     mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
 
     # Compute the output as base64 encoded strings
     image_bytes = image_bytes_from_image(image, mode="JPEG")
-    output = InferenceOutput(image=base64_encode(image_bytes), audio=base64_encode(mp3_bytes))
+
+    # Assemble the output dataclass
+    output = InferenceOutput(
+        image="data:image/jpeg;base64," + base64_encode(image_bytes),
+        audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
+        duration_s=duration_s,
+    )
 
     # Log the total time
     logging.info(f"Request took {time.time() - start_time:.2f} s")
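
On the consuming side, the new data-URI prefixes and duration field might be used like this. The endpoint URL, the requests library, and the empty request body are assumptions for illustration; only the response field names come from the diff.

# Hypothetical client: POST an inference request, then strip the
# "data:audio/mpeg;base64," prefix and decode the payload to an MP3 file.
import base64
import requests

request_body = {}  # InferenceInput fields go here (not shown in this diff)
resp = requests.post("http://localhost:3000/run_inference/", json=request_body).json()

_header, encoded = resp["audio"].split(",", 1)  # drop the data-URI prefix
with open("clip.mp3", "wb") as f:
    f.write(base64.b64decode(encoded))
print(f"Clip duration: {resp['duration_s']:.2f} s")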