"""
|
|
Streamlit utilities (mostly cached wrappers around riffusion code).
|
|
"""
|
|
import io
import threading
import typing as T

import pydub
import streamlit as st
import torch
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from PIL import Image

from riffusion.audio_splitter import AudioSplitter
from riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams

# TODO(hayk): Add URL params

DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"

AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]

SCHEDULER_OPTIONS = [
    "DPMSolverMultistepScheduler",
    "PNDMScheduler",
    "DDIMScheduler",
    "LMSDiscreteScheduler",
    "EulerDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
]


@st.cache_resource
def load_riffusion_checkpoint(
    checkpoint: str = DEFAULT_CHECKPOINT,
    no_traced_unet: bool = False,
    device: str = "cuda",
) -> RiffusionPipeline:
    """
    Load the riffusion pipeline.
    """
    return RiffusionPipeline.load_checkpoint(
        checkpoint=checkpoint,
        use_traced_unet=not no_traced_unet,
        device=device,
    )


@st.cache_resource
def load_stable_diffusion_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionPipeline:
    """
    Load the Stable Diffusion text-to-image pipeline.

    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    pipeline = StableDiffusionPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

    return pipeline


def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
    """
    Construct a denoising scheduler from a string.
    """
    if scheduler == "PNDMScheduler":
        from diffusers import PNDMScheduler

        return PNDMScheduler.from_config(config)
    elif scheduler == "DPMSolverMultistepScheduler":
        from diffusers import DPMSolverMultistepScheduler

        return DPMSolverMultistepScheduler.from_config(config)
    elif scheduler == "DDIMScheduler":
        from diffusers import DDIMScheduler

        return DDIMScheduler.from_config(config)
    elif scheduler == "LMSDiscreteScheduler":
        from diffusers import LMSDiscreteScheduler

        return LMSDiscreteScheduler.from_config(config)
    elif scheduler == "EulerDiscreteScheduler":
        from diffusers import EulerDiscreteScheduler

        return EulerDiscreteScheduler.from_config(config)
    elif scheduler == "EulerAncestralDiscreteScheduler":
        from diffusers import EulerAncestralDiscreteScheduler

        return EulerAncestralDiscreteScheduler.from_config(config)
    else:
        raise ValueError(f"Unknown scheduler {scheduler}")
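

# Example (illustrative, not called anywhere in this module): move an already
# loaded pipeline onto a different scheduler while keeping its current config.
#
#   pipeline.scheduler = get_scheduler("DDIMScheduler", config=pipeline.scheduler.config)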


@st.cache_resource
def pipeline_lock() -> threading.Lock:
    """
    Singleton lock used to prevent concurrent access to any model pipeline.
    """
    return threading.Lock()
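

# Sketch of the intended usage (the real call sites are run_txt2img and
# run_img2img below): hold the singleton lock around any pipeline work so
# concurrent Streamlit sessions don't run a model at the same time.
#
#   with pipeline_lock():
#       pipeline = load_stable_diffusion_pipeline()
#       ...  # run inference while holding the lock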


@st.cache_resource
def load_stable_diffusion_img2img_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionImg2ImgPipeline:
    """
    Load the image-to-image pipeline.

    TODO(hayk): Merge this into RiffusionPipeline to just load one model.
    """
    if device == "cpu" or device.lower().startswith("mps"):
        print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
        dtype = torch.float32

    pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
        checkpoint,
        revision="main",
        torch_dtype=dtype,
        safety_checker=lambda images, **kwargs: (images, False),
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

    return pipeline


@st.cache_data(persist=True)
def run_txt2img(
    prompt: str,
    num_inference_steps: int,
    guidance: float,
    negative_prompt: str,
    seed: int,
    width: int,
    height: int,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> Image.Image:
    """
    Run the text to image pipeline with caching.
    """
    with pipeline_lock():
        pipeline = load_stable_diffusion_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        generator_device = "cpu" if device.lower().startswith("mps") else device
        generator = torch.Generator(device=generator_device).manual_seed(seed)

        output = pipeline(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance,
            negative_prompt=negative_prompt or None,
            generator=generator,
            width=width,
            height=height,
        )

        return output["images"][0]
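

# Example call (illustrative; the prompt and parameter values are arbitrary):
#
#   image = run_txt2img(
#       prompt="jazzy rap from paris",
#       num_inference_steps=30,
#       guidance=7.0,
#       negative_prompt="",
#       seed=42,
#       width=512,
#       height=512,
#   )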


@st.cache_resource
def spectrogram_image_converter(
    params: SpectrogramParams,
    device: str = "cuda",
) -> SpectrogramImageConverter:
    return SpectrogramImageConverter(params=params, device=device)


@st.cache_data
def spectrogram_image_from_audio(
    segment: pydub.AudioSegment,
    params: SpectrogramParams,
    device: str = "cuda",
) -> Image.Image:
    converter = spectrogram_image_converter(params=params, device=device)
    return converter.spectrogram_image_from_audio(segment)


@st.cache_data
def audio_segment_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
) -> pydub.AudioSegment:
    converter = spectrogram_image_converter(params=params, device=device)
    return converter.audio_from_spectrogram_image(image)


@st.cache_data
def audio_bytes_from_spectrogram_image(
    image: Image.Image,
    params: SpectrogramParams,
    device: str = "cuda",
    output_format: str = "mp3",
) -> io.BytesIO:
    segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device)

    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=output_format)

    return audio_bytes
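

# Example (illustrative): play a generated spectrogram image as audio in a
# Streamlit app. Using SpectrogramParams() defaults here is an assumption.
#
#   params = SpectrogramParams()
#   audio_bytes = audio_bytes_from_spectrogram_image(image=image, params=params)
#   st.audio(audio_bytes, format="audio/mp3")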


def select_device(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a torch device, with an intelligent default.
    """
    default_device = "cpu"
    if torch.cuda.is_available():
        default_device = "cuda"
    elif torch.backends.mps.is_available():
        default_device = "mps"

    device_options = ["cuda", "cpu", "mps"]
    device = container.selectbox(
        "Device",
        options=device_options,
        index=device_options.index(default_device),
        help="Which compute device to use. CUDA is recommended.",
    )
    assert device is not None

    return device


def select_audio_extension(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select an audio extension, with an intelligent default.
    """
    default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
    extension = container.selectbox(
        "Output format",
        options=AUDIO_EXTENSIONS,
        index=AUDIO_EXTENSIONS.index(default),
    )
    assert extension is not None
    return extension


def select_scheduler(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a scheduler.
    """
    scheduler = container.selectbox(
        "Scheduler",
        options=SCHEDULER_OPTIONS,
        index=0,
        help="Which diffusion scheduler to use",
    )
    assert scheduler is not None
    return scheduler


def select_checkpoint(container: T.Any = st.sidebar) -> str:
    """
    Provide a custom model checkpoint.
    """
    return container.text_input(
        "Custom Checkpoint",
        value=DEFAULT_CHECKPOINT,
        help="Provide a custom model checkpoint",
    )


@st.cache_data
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
    return pydub.AudioSegment.from_file(audio_file)


@st.cache_resource
def get_audio_splitter(device: str = "cuda") -> AudioSplitter:
    return AudioSplitter(device=device)


@st.cache_resource
def load_magic_mix_pipeline(
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
) -> DiffusionPipeline:
    pipeline = DiffusionPipeline.from_pretrained(
        checkpoint,
        custom_pipeline="magic_mix",
    ).to(device)

    pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)

    return pipeline


@st.cache_data
def run_img2img_magic_mix(
    prompt: str,
    init_image: Image.Image,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    kmin: float,
    kmax: float,
    mix_factor: float,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
):
    """
    Run the magic mix pipeline for img2img.
    """
    with pipeline_lock():
        pipeline = load_magic_mix_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        return pipeline(
            init_image,
            prompt=prompt,
            kmin=kmin,
            kmax=kmax,
            mix_factor=mix_factor,
            seed=seed,
            guidance_scale=guidance_scale,
            steps=num_inference_steps,
        )


@st.cache_data
def run_img2img(
    prompt: str,
    init_image: Image.Image,
    denoising_strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    negative_prompt: T.Optional[str] = None,
    checkpoint: str = DEFAULT_CHECKPOINT,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
    progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
) -> Image.Image:
    with pipeline_lock():
        pipeline = load_stable_diffusion_img2img_pipeline(
            checkpoint=checkpoint,
            device=device,
            scheduler=scheduler,
        )

        generator_device = "cpu" if device.lower().startswith("mps") else device
        generator = torch.Generator(device=generator_device).manual_seed(seed)

        # Img2img runs fewer steps than requested, scaled by the denoising strength
        num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)

        # Diffusers invokes this with (step, timestep, latents); only the step
        # index is needed to report fractional progress.
        def callback(step: int, timestep: T.Any, latents: T.Any) -> None:
            if progress_callback is not None:
                progress_callback(step / num_expected_steps)

        result = pipeline(
            prompt=prompt,
            image=init_image,
            strength=denoising_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt or None,
            num_images_per_prompt=1,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )

        return result.images[0]
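

# Example (illustrative): wire progress_callback to a Streamlit progress bar.
# The prompt and parameter values are arbitrary; init_image is assumed to be a
# PIL image already in hand.
#
#   progress = st.progress(0.0)
#   result_image = run_img2img(
#       prompt="lo-fi beat for the holidays",
#       init_image=init_image,
#       denoising_strength=0.6,
#       num_inference_steps=50,
#       guidance_scale=7.0,
#       seed=42,
#       progress_callback=lambda p: progress.progress(min(p, 1.0)),
#   )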


class StreamlitCounter:
    """
    Simple counter stored in streamlit session state.
    """

    def __init__(self, key="_counter"):
        self.key = key
        if not st.session_state.get(self.key):
            st.session_state[self.key] = 0

    def increment(self):
        st.session_state[self.key] += 1

    @property
    def value(self):
        return st.session_state[self.key]
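

# Example (illustrative): a button that bumps the counter on each click, so the
# counter value can be fed into a cached function to force regeneration.
#
#   counter = StreamlitCounter()
#   st.button("Riff", on_click=counter.increment)
#   seed = counter.value  # changes per click, busting the caches above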


def display_and_download_audio(
    segment: pydub.AudioSegment,
    name: str,
    extension: str = "mp3",
) -> None:
    """
    Display the given audio segment and provide a button to download it with
    a proper file name, since st.audio doesn't support that.
    """
    mime_type = f"audio/{extension}"
    audio_bytes = io.BytesIO()
    segment.export(audio_bytes, format=extension)
    st.audio(audio_bytes, format=mime_type)

    st.download_button(
        f"{name}.{extension}",
        data=audio_bytes,
        file_name=f"{name}.{extension}",
        mime=mime_type,
    )