"""
Streamlit utilities (mostly cached wrappers around riffusion code).
"""
import io
import threading
import typing as T

import pydub
import streamlit as st
import torch
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from PIL import Image

from riffusion.audio_splitter import AudioSplitter
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

# TODO(hayk): Add URL params

DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"

AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]

SCHEDULER_OPTIONS = [
    "DPMSolverMultistepScheduler",
    "PNDMScheduler",
    "DDIMScheduler",
    "LMSDiscreteScheduler",
    "EulerDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
]


@st.cache_resource
def load_riffusion_checkpoint(
    checkpoint: str = DEFAULT_CHECKPOINT,
    no_traced_unet: bool = False,
    device: str = "cuda",
) -> RiffusionPipeline:
    """
    Load the riffusion pipeline.
    """
    return RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )
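
# Illustrative usage (a sketch, not exercised in this module; `inputs` and
# `init_image` are assumed to come from riffusion.datatypes.InferenceInput and a
# seed spectrogram image, as in the server code):
#
#   pipeline = load_riffusion_checkpoint(device="cuda")
#   image = pipeline.riffuse(inputs, init_image=init_image)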


@st.cache_resource
def load_stable_diffusion_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionPipeline:
    """
    Load the stable diffusion pipeline.

    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    pipeline = StableDiffusionPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        # Disable the NSFW filter, always returning the images unchanged
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

    return pipeline


def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
    """
    Construct a denoising scheduler from a string.
    """
    if scheduler == "PNDMScheduler":
        from diffusers import PNDMScheduler

        return PNDMScheduler.from_config(config)
    elif scheduler == "DPMSolverMultistepScheduler":
        from diffusers import DPMSolverMultistepScheduler

        return DPMSolverMultistepScheduler.from_config(config)
    elif scheduler == "DDIMScheduler":
        from diffusers import DDIMScheduler

        return DDIMScheduler.from_config(config)
    elif scheduler == "LMSDiscreteScheduler":
        from diffusers import LMSDiscreteScheduler

        return LMSDiscreteScheduler.from_config(config)
    elif scheduler == "EulerDiscreteScheduler":
        from diffusers import EulerDiscreteScheduler

        return EulerDiscreteScheduler.from_config(config)
    elif scheduler == "EulerAncestralDiscreteScheduler":
        from diffusers import EulerAncestralDiscreteScheduler

        return EulerAncestralDiscreteScheduler.from_config(config)
    else:
        raise ValueError(f"Unknown scheduler {scheduler}")
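
# For illustration, a scheduler can be swapped on a live pipeline without reloading
# weights (a sketch; `pipeline` stands for any already-loaded diffusers pipeline):
#
#   pipeline.scheduler = get_scheduler("DDIMScheduler", config=pipeline.scheduler.config)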


@st.cache_resource
def pipeline_lock() -> threading.Lock:
    """
    Singleton lock used to prevent concurrent access to any model pipeline.
    """
    return threading.Lock()


@st.cache_resource
def load_stable_diffusion_img2img_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionImg2ImgPipeline:
    """
    Load the image to image pipeline.

    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        # Disable the NSFW filter, always returning the images unchanged
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

    return pipeline


@st.cache_data(persist=True)
def run_txt2img(
    prompt: str,
    num_inference_steps: int,
    guidance: float,
    negative_prompt: str,
    seed: int,
    width: int,
    height: int,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> Image.Image:
    """
    Run the text to image pipeline with caching.
    """
    with pipeline_lock():
        pipeline = load_stable_diffusion_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        # Seeded generators are not supported on MPS, so seed on the CPU instead
        generator_device = "cpu" if device.lower().startswith("mps") else device
        generator = torch.Generator(device=generator_device).manual_seed(seed)

        output = pipeline(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance,
            negative_prompt=negative_prompt or None,
            generator=generator,
            width=width,
            height=height,
        )

        return output["images"][0]
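
# Illustrative usage from a Streamlit page (a sketch with assumed parameter values;
# 512x512 matches the model's spectrogram resolution):
#
#   image = run_txt2img(
#       prompt="acoustic folk fiddle solo",
#       num_inference_steps=30,
#       guidance=7.0,
#       negative_prompt="",
#       seed=42,
#       width=512,
#       height=512,
#   )
#   st.image(image)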


@st.cache_resource
def spectrogram_image_converter(
    params: SpectrogramParams,
    device: str = "cuda",
) -> SpectrogramImageConverter:
    """
    Get a cached image converter for the given params and device.
    """
    return SpectrogramImageConverter(params=params, device=device)


@st.cache_data
def spectrogram_image_from_audio(
    segment: pydub.AudioSegment,
    params: SpectrogramParams,
    device: str = "cuda",
) -> Image.Image:
    """
    Compute a spectrogram image from an audio segment, with caching.
    """
    converter = spectrogram_image_converter(params=params, device=device)
    return converter.spectrogram_image_from_audio(segment)


@st.cache_data
def audio_segment_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
) -> pydub.AudioSegment:
    """
    Reconstruct an audio segment from a spectrogram image, with caching.
    """
    converter = spectrogram_image_converter(params=params, device=device)
    return converter.audio_from_spectrogram_image(image)


@st.cache_data
def audio_bytes_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
    output_format: str = "mp3",
) -> io.BytesIO:
    """
    Reconstruct audio from a spectrogram image and export it to a bytes buffer.
    """
    segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device)

    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=output_format)

    return audio_bytes
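
# Putting the cached pieces together (a sketch, assuming `image` is a spectrogram
# image from one of the pipelines above and the default parameters suffice):
#
#   params = SpectrogramParams()
#   audio_bytes = audio_bytes_from_spectrogram_image(image=image, params=params)
#   st.audio(audio_bytes)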


def select_device(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a torch device, with an intelligent default.
    """
    default_device = "cpu"
    if torch.cuda.is_available():
        default_device = "cuda"
    elif torch.backends.mps.is_available():
        default_device = "mps"

    device_options = ["cuda", "cpu", "mps"]
    device = container.selectbox(
        "Device",
        options=device_options,
        index=device_options.index(default_device),
        help="Which compute device to use. CUDA is recommended.",
    )
    assert device is not None

    return device


def select_audio_extension(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select an audio extension, with an intelligent default.
    """
    # Prefer mp3 when pydub has an ffmpeg converter configured, else fall back to wav
    default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
    extension = container.selectbox(
        "Output format",
        options=AUDIO_EXTENSIONS,
        index=AUDIO_EXTENSIONS.index(default),
    )
    assert extension is not None

    return extension


def select_scheduler(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a scheduler.
    """
    scheduler = container.selectbox(
        "Scheduler",
        options=SCHEDULER_OPTIONS,
        index=0,
        help="Which diffusion scheduler to use",
    )
    assert scheduler is not None

    return scheduler


def select_checkpoint(container: T.Any = st.sidebar) -> str:
    """
    Text input to provide a custom model checkpoint.
    """
    return container.text_input(
        "Custom Checkpoint",
        value=DEFAULT_CHECKPOINT,
        help="Provide a custom model checkpoint",
    )
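
# The select_* helpers compose into a typical settings sidebar (an illustrative
# sketch of how an app page might wire them up):
#
#   device = select_device(st.sidebar)
#   checkpoint = select_checkpoint(st.sidebar)
#   scheduler = select_scheduler(st.sidebar)
#   extension = select_audio_extension(st.sidebar)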


@st.cache_data
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
    """
    Load an audio segment from an uploaded file object, with caching.
    """
    return pydub.AudioSegment.from_file(audio_file)


@st.cache_resource
def get_audio_splitter(device: str = "cuda"):
    """
    Get a cached audio splitter for the given device.
    """
    return AudioSplitter(device=device)


@st.cache_resource
def load_magic_mix_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
):
    """
    Load the magic_mix community pipeline from diffusers.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        checkpoint,
        custom_pipeline="magic_mix",
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)

    return pipeline


@st.cache_data
def run_img2img_magic_mix(
    prompt: str,
    init_image: Image.Image,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    kmin: float,
    kmax: float,
    mix_factor: float,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
):
    """
    Run the magic mix pipeline for img2img.
    """
    with pipeline_lock():
        pipeline = load_magic_mix_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        return pipeline(
            init_image,
            prompt=prompt,
            kmin=kmin,
            kmax=kmax,
            mix_factor=mix_factor,
            seed=seed,
            guidance_scale=guidance_scale,
            steps=num_inference_steps,
        )


@st.cache_data
def run_img2img(
    prompt: str,
    init_image: Image.Image,
    denoising_strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    negative_prompt: T.Optional[str] = None,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
    progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
) -> Image.Image:
    """
    Run the img2img pipeline with caching.
    """
    with pipeline_lock():
        pipeline = load_stable_diffusion_img2img_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        # Seeded generators are not supported on MPS, so seed on the CPU instead
        generator_device = "cpu" if device.lower().startswith("mps") else device
        generator = torch.Generator(device=generator_device).manual_seed(seed)

        # Img2img only runs the last `denoising_strength` fraction of the steps,
        # so scale the progress reported to the callback accordingly
        num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)

        def callback(step: int, timestep: int, latents: torch.Tensor) -> None:
            if progress_callback is not None:
                progress_callback(step / num_expected_steps)

        result = pipeline(
            prompt=prompt,
            image=init_image,
            strength=denoising_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt or None,
            num_images_per_prompt=1,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )

        return result.images[0]
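
# Illustrative usage with a progress bar (a sketch; `init_image` is assumed to be
# a 512x512 spectrogram image):
#
#   progress = st.progress(0.0)
#   image = run_img2img(
#       prompt="jazzy saxophone",
#       init_image=init_image,
#       denoising_strength=0.6,
#       num_inference_steps=30,
#       guidance_scale=7.0,
#       seed=42,
#       progress_callback=progress.progress,
#   )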


class StreamlitCounter:
    """
    Simple counter stored in streamlit session state.
    """

    def __init__(self, key="_counter"):
        self.key = key
        if not st.session_state.get(self.key):
            st.session_state[self.key] = 0

    def increment(self):
        st.session_state[self.key] += 1

    @property
    def value(self):
        return st.session_state[self.key]
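
# Typical use (a sketch): drive reruns from a button without losing the count.
#
#   counter = StreamlitCounter(key="generate_counter")
#   st.button("Generate", on_click=counter.increment)
#   st.write(f"Generated {counter.value} clips")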


def display_and_download_audio(
    segment: pydub.AudioSegment,
    name: str,
    extension: str = "mp3",
) -> None:
    """
    Display the given audio segment and provide a button to download it with
    a proper file name, since st.audio doesn't support that.
    """
    mime_type = f"audio/{extension}"

    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=extension)
    st.audio(audio_bytes, format=mime_type)

    st.download_button(
        f"{name}.{extension}",
        data=audio_bytes,
        file_name=f"{name}.{extension}",
        mime=mime_type,
    )
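
# Illustrative call (a sketch, assuming `segment` came from one of the converters above):
#
#   display_and_download_audio(segment, name="my_riff", extension="mp3")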