Provide clip duration and encode base64 prefix type
This commit is contained in:
parent
7f27705f81
commit
511defae99
|
@ -2,6 +2,7 @@
|
||||||
Audio processing tools to convert between spectrogram images and waveforms.
|
Audio processing tools to convert between spectrogram images and waveforms.
|
||||||
"""
|
"""
|
||||||
import io
|
import io
|
||||||
|
import typing as T
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -11,9 +12,9 @@ import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
|
def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
|
||||||
"""
|
"""
|
||||||
Reconstruct a WAV audio clip from a spectrogram image.
|
Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_volume = 50
|
max_volume = 50
|
||||||
|
@ -37,7 +38,7 @@ def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
|
||||||
hop_length = int(step_size_ms / 1000.0 * sample_rate)
|
hop_length = int(step_size_ms / 1000.0 * sample_rate)
|
||||||
win_length = int(window_duration_ms / 1000.0 * sample_rate)
|
win_length = int(window_duration_ms / 1000.0 * sample_rate)
|
||||||
|
|
||||||
waveform = waveform_from_spectrogram(
|
samples = waveform_from_spectrogram(
|
||||||
Sxx=Sxx,
|
Sxx=Sxx,
|
||||||
n_fft=n_fft,
|
n_fft=n_fft,
|
||||||
hop_length=hop_length,
|
hop_length=hop_length,
|
||||||
|
@ -51,10 +52,12 @@ def wav_bytes_from_spectrogram_image(image: Image.Image) -> io.BytesIO:
|
||||||
)
|
)
|
||||||
|
|
||||||
wav_bytes = io.BytesIO()
|
wav_bytes = io.BytesIO()
|
||||||
wavfile.write(wav_bytes, sample_rate, waveform.astype(np.int16))
|
wavfile.write(wav_bytes, sample_rate, samples.astype(np.int16))
|
||||||
wav_bytes.seek(0)
|
wav_bytes.seek(0)
|
||||||
|
|
||||||
return wav_bytes
|
duration_s = float(len(samples)) / sample_rate
|
||||||
|
|
||||||
|
return wav_bytes, duration_s
|
||||||
|
|
||||||
|
|
||||||
def spectrogram_from_image(
|
def spectrogram_from_image(
|
||||||
|
|
|
@ -56,9 +56,13 @@ class InferenceInput:
|
||||||
@dataclass
|
@dataclass
|
||||||
class InferenceOutput:
|
class InferenceOutput:
|
||||||
"""
|
"""
|
||||||
Response from the model server. Contains a base64 encoded spectrogram image and a base64
|
Response from the model inference server.
|
||||||
encoded MP3 audio clip.
|
|
||||||
"""
|
"""
|
||||||
|
# base64 encoded spectrogram image as a JPEG, prefixed with "data:image/jpeg;base64,"
|
||||||
image: str
|
image: str
|
||||||
|
|
||||||
|
# base64 encoded audio clip as an MP3, prefixed with "data:audio/mpeg;base64,"
|
||||||
audio: str
|
audio: str
|
||||||
|
|
||||||
|
# The duration of the audio clip in seconds
|
||||||
|
duration_s: float
|
||||||
|
|
|
@ -45,7 +45,7 @@ def run_app(
|
||||||
port: int = 3000,
|
port: int = 3000,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
ssl_certificate: T.Optional[str] = None,
|
ssl_certificate: T.Optional[str] = None,
|
||||||
ssl_key: T.Optional[str] = None
|
ssl_key: T.Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Run a flask API that serves the given riffusion model checkpoint.
|
Run a flask API that serves the given riffusion model checkpoint.
|
||||||
|
@ -135,12 +135,18 @@ def run_inference():
|
||||||
image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
|
image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
|
||||||
|
|
||||||
# Reconstruct audio from the image
|
# Reconstruct audio from the image
|
||||||
wav_bytes = wav_bytes_from_spectrogram_image(image)
|
wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)
|
||||||
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
|
mp3_bytes = mp3_bytes_from_wav_bytes(wav_bytes)
|
||||||
|
|
||||||
# Compute the output as base64 encoded strings
|
# Compute the output as base64 encoded strings
|
||||||
image_bytes = image_bytes_from_image(image, mode="JPEG")
|
image_bytes = image_bytes_from_image(image, mode="JPEG")
|
||||||
output = InferenceOutput(image=base64_encode(image_bytes), audio=base64_encode(mp3_bytes))
|
|
||||||
|
# Assemble the output dataclass
|
||||||
|
output = InferenceOutput(
|
||||||
|
image="data:image/jpeg;base64," + base64_encode(image_bytes),
|
||||||
|
audio="data:audio/mpeg;base64," + base64_encode(mp3_bytes),
|
||||||
|
duration_s=duration_s,
|
||||||
|
)
|
||||||
|
|
||||||
# Log the total time
|
# Log the total time
|
||||||
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
||||||
|
|
Loading…
Reference in New Issue