Support custom checkpoints in text_to_audio

Also add a toggle for 20kHz spectrogram mode

Topic: playground_custom_checkpoints_text_to_audio
Hayk Martiros 2023-01-29 22:00:09 +00:00
parent 8102bc3017
commit 38cce7ab00
2 changed files with 54 additions and 12 deletions
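
In short: the text-to-audio page gains a sidebar field for a custom model checkpoint and a "Use 20kHz" checkbox that switches the spectrogram to a near-full-range stereo configuration at a 44.1 kHz sample rate. The two parameter sets, condensed from the diff below into a hypothetical helper (params_for_mode is illustrative only, not part of the commit):

    from riffusion.spectrogram_params import SpectrogramParams

    def params_for_mode(use_20k: bool) -> SpectrogramParams:
        if use_20k:
            # 20kHz mode: ~10 Hz to 20 kHz, stereo, 44.1 kHz sample rate
            return SpectrogramParams(
                min_frequency=10,
                max_frequency=20000,
                sample_rate=44100,
                stereo=True,
            )
        # Default mode keeps the original mono 0-10 kHz range
        return SpectrogramParams(
            min_frequency=0,
            max_frequency=10000,
            stereo=False,
        )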

View File

@@ -28,6 +28,7 @@ def render_text_to_audio() -> None:
     device = streamlit_util.select_device(st.sidebar)
     extension = streamlit_util.select_audio_extension(st.sidebar)
+    checkpoint = streamlit_util.select_checkpoint(st.sidebar)
 
     with st.form("Inputs"):
         prompt = st.text_input("Prompt")
@@ -69,15 +70,25 @@ def render_text_to_audio() -> None:
     )
     assert scheduler is not None
 
+    use_20k = st.checkbox("Use 20kHz", value=False)
+
     if not prompt:
         st.info("Enter a prompt")
         return
 
-    # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
-    params = SpectrogramParams(
-        min_frequency=0,
-        max_frequency=10000,
-    )
+    if use_20k:
+        params = SpectrogramParams(
+            min_frequency=10,
+            max_frequency=20000,
+            sample_rate=44100,
+            stereo=True,
+        )
+    else:
+        params = SpectrogramParams(
+            min_frequency=0,
+            max_frequency=10000,
+            stereo=False,
+        )
 
     seed = starting_seed
     for i in range(1, num_clips + 1):
@@ -91,6 +102,7 @@ def render_text_to_audio() -> None:
             seed=seed,
             width=width,
             height=512,
+            checkpoint=checkpoint,
             device=device,
             scheduler=scheduler,
         )

View File

@@ -18,6 +18,8 @@ from riffusion.spectrogram_params import SpectrogramParams
 
 # TODO(hayk): Add URL params
 
+DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"
+
 AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
 IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
@@ -33,7 +35,7 @@ SCHEDULER_OPTIONS = [
 @st.experimental_singleton
 def load_riffusion_checkpoint(
-    checkpoint: str = "riffusion/riffusion-model-v1",
+    checkpoint: str = DEFAULT_CHECKPOINT,
     no_traced_unet: bool = False,
     device: str = "cuda",
 ) -> RiffusionPipeline:
@@ -49,7 +51,7 @@ def load_riffusion_checkpoint(
 @st.experimental_singleton
 def load_stable_diffusion_pipeline(
-    checkpoint: str = "riffusion/riffusion-model-v1",
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     dtype: torch.dtype = torch.float16,
     scheduler: str = SCHEDULER_OPTIONS[0],
@@ -117,7 +119,7 @@ def pipeline_lock() -> threading.Lock:
 @st.experimental_singleton
 def load_stable_diffusion_img2img_pipeline(
-    checkpoint: str = "riffusion/riffusion-model-v1",
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     dtype: torch.dtype = torch.float16,
     scheduler: str = SCHEDULER_OPTIONS[0],
@@ -152,6 +154,7 @@ def run_txt2img(
     seed: int,
     width: int,
     height: int,
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     scheduler: str = SCHEDULER_OPTIONS[0],
 ) -> Image.Image:
@@ -159,7 +162,11 @@ def run_txt2img(
     Run the text to image pipeline with caching.
     """
     with pipeline_lock():
-        pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)
+        pipeline = load_stable_diffusion_pipeline(
+            checkpoint=checkpoint,
+            device=device,
+            scheduler=scheduler,
+        )
 
     generator_device = "cpu" if device.lower().startswith("mps") else device
     generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -270,6 +277,18 @@ def select_scheduler(container: T.Any = st.sidebar) -> str:
     return scheduler
 
 
+def select_checkpoint(container: T.Any = st.sidebar) -> str:
+    """
+    Provide a custom model checkpoint.
+    """
+    custom_checkpoint = container.text_input(
+        "Custom Checkpoint",
+        value="",
+        help="Provide a custom model checkpoint",
+    )
+    return custom_checkpoint or DEFAULT_CHECKPOINT
+
+
 @st.experimental_memo
 def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
     return pydub.AudioSegment.from_file(audio_file)
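
The empty-input fallback in select_checkpoint() above relies on Python's "or": an empty string is falsy, so a blank sidebar field resolves to DEFAULT_CHECKPOINT. A standalone illustration of the same logic (choose_checkpoint and the fine-tune name are hypothetical; no Streamlit needed):

    DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"

    def choose_checkpoint(user_input: str) -> str:
        # Mirrors `custom_checkpoint or DEFAULT_CHECKPOINT`: empty input
        # falls back to the default model, any non-empty string wins
        return user_input or DEFAULT_CHECKPOINT

    assert choose_checkpoint("") == DEFAULT_CHECKPOINT
    assert choose_checkpoint("my-org/riffusion-fine-tune") == "my-org/riffusion-fine-tune"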
@@ -281,9 +300,13 @@ def get_audio_splitter(device: str = "cuda"):
 
 @st.experimental_singleton
-def load_magic_mix_pipeline(device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0]):
+def load_magic_mix_pipeline(
+    checkpoint: str = DEFAULT_CHECKPOINT,
+    device: str = "cuda",
+    scheduler: str = SCHEDULER_OPTIONS[0],
+):
     pipeline = DiffusionPipeline.from_pretrained(
-        "riffusion/riffusion-model-v1",
+        checkpoint,
         custom_pipeline="magic_mix",
     ).to(device)
@@ -302,6 +325,7 @@ def run_img2img_magic_mix(
     kmin: float,
     kmax: float,
     mix_factor: float,
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     scheduler: str = SCHEDULER_OPTIONS[0],
 ):
@@ -310,6 +334,7 @@ def run_img2img_magic_mix(
     """
     with pipeline_lock():
         pipeline = load_magic_mix_pipeline(
+            checkpoint=checkpoint,
             device=device,
             scheduler=scheduler,
         )
@@ -335,12 +360,17 @@ def run_img2img(
     guidance_scale: float,
     seed: int,
     negative_prompt: T.Optional[str] = None,
+    checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
     scheduler: str = SCHEDULER_OPTIONS[0],
     progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
 ) -> Image.Image:
     with pipeline_lock():
-        pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)
+        pipeline = load_stable_diffusion_img2img_pipeline(
+            checkpoint=checkpoint,
+            device=device,
+            scheduler=scheduler,
+        )
 
     generator_device = "cpu" if device.lower().startswith("mps") else device
     generator = torch.Generator(device=generator_device).manual_seed(seed)
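
A note on caching: because the loaders are wrapped in @st.experimental_singleton, Streamlit memoizes one pipeline per distinct argument tuple, so threading checkpoint through them means a new sidebar checkpoint loads and caches a second pipeline rather than replacing the first. A minimal sketch of that pattern (load_pipeline is hypothetical; assumes streamlit and diffusers are installed):

    import streamlit as st
    from diffusers import DiffusionPipeline

    @st.experimental_singleton
    def load_pipeline(checkpoint: str, device: str = "cuda"):
        # Cached per (checkpoint, device): the first call for a given pair
        # loads the model; later reruns reuse the same object
        return DiffusionPipeline.from_pretrained(checkpoint).to(device)

    pipe_a = load_pipeline("riffusion/riffusion-model-v1")
    pipe_b = load_pipeline("riffusion/riffusion-model-v1")
    assert pipe_a is pipe_b  # same cached instance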