parent 39dc247a1d
commit 420674148a
@@ -1,11 +1,7 @@
-import io
 
 import streamlit as st
 from PIL import Image
 
-from riffusion.spectrogram_image_converter import SpectrogramImageConverter
-from riffusion.spectrogram_params import SpectrogramParams
-from riffusion.streamlit import util as streamlit_util
 from riffusion.util.image_util import exif_from_image
 
 
@@ -26,13 +22,13 @@ def render_image_to_audio() -> None:
     st.write("Exif data:")
     st.write(exif)
 
-    device = "cuda"
+    # device = "cuda"
 
-    try:
-        params = SpectrogramParams.from_exif(exif=image.getexif())
-    except KeyError:
-        st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
-        params = SpectrogramParams()
+    # try:
+    #     params = SpectrogramParams.from_exif(exif=image.getexif())
+    # except KeyError:
+    #     st.warning("Could not find spectrogram parameters in exif data. Using defaults.")
+    #     params = SpectrogramParams()
 
     # segment = streamlit_util.audio_from_spectrogram_image(
     #     image=image,
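Note: on the image-to-audio page, the exif parameter handling above is commented out rather than removed, so the page temporarily loses its fallback to default parameters. For reference, a minimal sketch of the disabled pattern, assuming an uploaded spectrogram opened with PIL (the file name is illustrative, not part of this commit):

    from PIL import Image

    from riffusion.spectrogram_params import SpectrogramParams

    image = Image.open("spectrogram.png")  # hypothetical upload
    try:
        # Riffusion writes its spectrogram parameters into the image exif
        params = SpectrogramParams.from_exif(exif=image.getexif())
    except KeyError:
        # No riffusion exif present: fall back to default parameters
        params = SpectrogramParams()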
@@ -1,101 +1,32 @@
-import io
-from pathlib import Path
+import typing as T
 
-import dacite
-from diffusers import StableDiffusionPipeline
 import streamlit as st
-import torch
-from PIL import Image
 
-from riffusion.datatypes import InferenceInput
-from riffusion.spectrogram_image_converter import SpectrogramImageConverter
 from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
 
 
-@st.experimental_singleton
-def load_stable_diffusion_pipeline(
-    checkpoint: str = "riffusion/riffusion-model-v1",
-    device: str = "cuda",
-    dtype: torch.dtype = torch.float16,
-) -> StableDiffusionPipeline:
-    """
-    Load the riffusion pipeline.
-    """
-    if device == "cpu" or device.lower().startswith("mps"):
-        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
-        dtype = torch.float32
-
-    return StableDiffusionPipeline.from_pretrained(
-        checkpoint,
-        revision="main",
-        torch_dtype=dtype,
-        safety_checker=lambda images, **kwargs: (images, False),
-    ).to(device)
-
-
-@st.experimental_memo
-def run_txt2img(
-    prompt: str,
-    num_inference_steps: int,
-    guidance: float,
-    negative_prompt: str,
-    seed: int,
-    width: int,
-    height: int,
-    device: str = "cuda",
-) -> Image.Image:
-    """
-    Run the text to image pipeline with caching.
-    """
-    pipeline = load_stable_diffusion_pipeline(device=device)
-
-    generator = torch.Generator(device="cpu").manual_seed(seed)
-
-    output = pipeline(
-        prompt=prompt,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance,
-        negative_prompt=negative_prompt or None,
-        generator=generator,
-        width=width,
-        height=height,
-    )
-
-    return output["images"][0]
-
-
 def render_text_to_audio() -> None:
     """
     Render audio from text.
     """
     prompt = st.text_input("Prompt")
-    if not prompt:
-        st.info("Enter a prompt")
-        return
-
     negative_prompt = st.text_input("Negative prompt")
-    seed = st.sidebar.number_input("Seed", value=42)
-    num_inference_steps = st.sidebar.number_input("Inference steps", value=20)
-    width = st.sidebar.number_input("Width", value=512)
-    height = st.sidebar.number_input("Height", value=512)
+    seed = T.cast(int, st.sidebar.number_input("Seed", value=42))
+    num_inference_steps = T.cast(int, st.sidebar.number_input("Inference steps", value=50))
+    width = T.cast(int, st.sidebar.number_input("Width", value=512))
+    height = T.cast(int, st.sidebar.number_input("Height", value=512))
     guidance = st.sidebar.number_input(
         "Guidance", value=7.0, help="How much the model listens to the text prompt"
     )
 
-    default_device = "cpu"
-    if torch.cuda.is_available():
-        default_device = "cuda"
-    elif torch.backends.mps.is_available():
-        default_device = "mps"
+    if not prompt:
+        st.info("Enter a prompt")
+        return
 
-    device_options = ["cuda", "cpu", "mps"]
-    device = st.sidebar.selectbox(
-        "Device", options=device_options, index=device_options.index(default_device)
-    )
-    assert device is not None
+    device = streamlit_util.select_device(st.sidebar)
 
-    image = run_txt2img(
+    image = streamlit_util.run_txt2img(
         prompt=prompt,
         num_inference_steps=num_inference_steps,
         guidance=guidance,
@@ -105,7 +36,6 @@ def render_text_to_audio() -> None:
         height=height,
         device=device,
     )
 
     st.image(image)
-
     # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
@@ -114,16 +44,13 @@ def render_text_to_audio() -> None:
         max_frequency=10000,
     )
 
-    segment = streamlit_util.audio_from_spectrogram_image(
+    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
         image=image,
         params=params,
         device=device,
+        output_format="mp3",
     )
-
-    mp3_bytes = io.BytesIO()
-    segment.export(mp3_bytes, format="mp3")
-    mp3_bytes.seek(0)
-    st.audio(mp3_bytes)
+    st.audio(audio_bytes)
 
 
 if __name__ == "__main__":
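Note: after this refactor the text-to-audio page is a thin UI layer; device selection, txt2img, and spectrogram-to-audio conversion all live behind cached helpers in riffusion.streamlit.util. A minimal sketch of the resulting call flow, using only the helpers introduced in this diff (prompt and parameter values are illustrative):

    import streamlit as st

    from riffusion.spectrogram_params import SpectrogramParams
    from riffusion.streamlit import util as streamlit_util

    device = streamlit_util.select_device(st.sidebar)
    image = streamlit_util.run_txt2img(
        prompt="jazzy rap from the 80s",  # illustrative prompt
        num_inference_steps=50,
        guidance=7.0,
        negative_prompt="",
        seed=42,
        width=512,
        height=512,
        device=device,
    )
    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
        image=image,
        params=SpectrogramParams(max_frequency=10000),
        device=device,
        output_format="mp3",
    )
    st.audio(audio_bytes)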
@@ -1,9 +1,12 @@
 """
 Streamlit utilities (mostly cached wrappers around riffusion code).
 """
+import io
+import typing as T
 
-import pydub
 import streamlit as st
+import torch
+from diffusers import StableDiffusionPipeline
 from PIL import Image
 
 from riffusion.riffusion_pipeline import RiffusionPipeline
@@ -26,6 +29,63 @@ def load_riffusion_checkpoint(
         device=device,
     )
 
 
+@st.experimental_singleton
+def load_stable_diffusion_pipeline(
+    checkpoint: str = "riffusion/riffusion-model-v1",
+    device: str = "cuda",
+    dtype: torch.dtype = torch.float16,
+) -> StableDiffusionPipeline:
+    """
+    Load the riffusion pipeline.
+
+    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
+    """
+    if device == "cpu" or device.lower().startswith("mps"):
+        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
+        dtype = torch.float32
+
+    return StableDiffusionPipeline.from_pretrained(
+        checkpoint,
+        revision="main",
+        torch_dtype=dtype,
+        safety_checker=lambda images, **kwargs: (images, False),
+    ).to(device)
+
+
+@st.experimental_memo
+def run_txt2img(
+    prompt: str,
+    num_inference_steps: int,
+    guidance: float,
+    negative_prompt: str,
+    seed: int,
+    width: int,
+    height: int,
+    device: str = "cuda",
+) -> Image.Image:
+    """
+    Run the text to image pipeline with caching.
+    """
+    pipeline = load_stable_diffusion_pipeline(device=device)
+
+    generator_device = "cpu" if device.lower().startswith("mps") else device
+    generator = torch.Generator(device=generator_device).manual_seed(seed)
+
+    output = pipeline(
+        prompt=prompt,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance,
+        negative_prompt=negative_prompt or None,
+        generator=generator,
+        width=width,
+        height=height,
+    )
+
+    return output["images"][0]
+
+
 # class CachedSpectrogramImageConverter:
 
 #     def __init__(self, params: SpectrogramParams, device: str = "cuda"):
@@ -54,13 +114,39 @@ def spectrogram_image_converter(
 
 
 @st.experimental_memo
-def audio_from_spectrogram_image(
+def audio_bytes_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
     device: str = "cuda",
-) -> pydub.AudioSegment:
+    output_format: str = "mp3",
+) -> io.BytesIO:
     converter = spectrogram_image_converter(params=params, device=device)
-    return converter.audio_from_spectrogram_image(image)
+    segment = converter.audio_from_spectrogram_image(image)
+
+    audio_bytes = io.BytesIO()
+    segment.export(audio_bytes, format=output_format)
+    audio_bytes.seek(0)
+
+    return audio_bytes
+
+
+def select_device(container: T.Any = st.sidebar) -> str:
+    """
+    Dropdown to select a torch device, with an intelligent default.
+    """
+    default_device = "cpu"
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+
+    device_options = ["cuda", "cpu", "mps"]
+    device = container.selectbox(
+        "Device", options=device_options, index=device_options.index(default_device)
+    )
+    assert device is not None
+
+    return device
+
+
 # @st.experimental_memo
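Note: audio_bytes_from_spectrogram_image now exports inside the cached helper and returns an in-memory byte buffer that st.audio can consume directly, so callers no longer touch pydub segments. The new select_device helper takes the target Streamlit container as an argument, so the dropdown is not tied to the sidebar. A short usage sketch (the column layout is illustrative, not part of this commit):

    import streamlit as st

    from riffusion.streamlit import util as streamlit_util

    # Default: render the device dropdown in the sidebar
    device = streamlit_util.select_device()

    # Or render it in any other container, e.g. one of two columns
    left, right = st.columns(2)
    device = streamlit_util.select_device(left)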