Support custom checkpoints in text_to_audio
Also have a toggle for 20khz spectrogram mode Topic: playground_custom_checkpoints_text_to_audio
This commit is contained in:
parent
8102bc3017
commit
38cce7ab00
|
@ -28,6 +28,7 @@ def render_text_to_audio() -> None:
|
||||||
|
|
||||||
device = streamlit_util.select_device(st.sidebar)
|
device = streamlit_util.select_device(st.sidebar)
|
||||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||||
|
checkpoint = streamlit_util.select_checkpoint(st.sidebar)
|
||||||
|
|
||||||
with st.form("Inputs"):
|
with st.form("Inputs"):
|
||||||
prompt = st.text_input("Prompt")
|
prompt = st.text_input("Prompt")
|
||||||
|
@ -69,15 +70,25 @@ def render_text_to_audio() -> None:
|
||||||
)
|
)
|
||||||
assert scheduler is not None
|
assert scheduler is not None
|
||||||
|
|
||||||
|
use_20k = st.checkbox("Use 20kHz", value=False)
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
st.info("Enter a prompt")
|
st.info("Enter a prompt")
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
if use_20k:
|
||||||
params = SpectrogramParams(
|
params = SpectrogramParams(
|
||||||
min_frequency=0,
|
min_frequency=10,
|
||||||
max_frequency=10000,
|
max_frequency=20000,
|
||||||
)
|
sample_rate=44100,
|
||||||
|
stereo=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
params = SpectrogramParams(
|
||||||
|
min_frequency=0,
|
||||||
|
max_frequency=10000,
|
||||||
|
stereo=False,
|
||||||
|
)
|
||||||
|
|
||||||
seed = starting_seed
|
seed = starting_seed
|
||||||
for i in range(1, num_clips + 1):
|
for i in range(1, num_clips + 1):
|
||||||
|
@ -91,6 +102,7 @@ def render_text_to_audio() -> None:
|
||||||
seed=seed,
|
seed=seed,
|
||||||
width=width,
|
width=width,
|
||||||
height=512,
|
height=512,
|
||||||
|
checkpoint=checkpoint,
|
||||||
device=device,
|
device=device,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,6 +18,8 @@ from riffusion.spectrogram_params import SpectrogramParams
|
||||||
|
|
||||||
# TODO(hayk): Add URL params
|
# TODO(hayk): Add URL params
|
||||||
|
|
||||||
|
DEFAULT_CHECKPOINT = "riffusion/riffusion-model-v1"
|
||||||
|
|
||||||
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
|
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
|
||||||
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
|
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
|
||||||
|
|
||||||
|
@ -33,7 +35,7 @@ SCHEDULER_OPTIONS = [
|
||||||
|
|
||||||
@st.experimental_singleton
|
@st.experimental_singleton
|
||||||
def load_riffusion_checkpoint(
|
def load_riffusion_checkpoint(
|
||||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||||
no_traced_unet: bool = False,
|
no_traced_unet: bool = False,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> RiffusionPipeline:
|
) -> RiffusionPipeline:
|
||||||
|
@ -49,7 +51,7 @@ def load_riffusion_checkpoint(
|
||||||
|
|
||||||
@st.experimental_singleton
|
@st.experimental_singleton
|
||||||
def load_stable_diffusion_pipeline(
|
def load_stable_diffusion_pipeline(
|
||||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.float16,
|
||||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||||
|
@ -117,7 +119,7 @@ def pipeline_lock() -> threading.Lock:
|
||||||
|
|
||||||
@st.experimental_singleton
|
@st.experimental_singleton
|
||||||
def load_stable_diffusion_img2img_pipeline(
|
def load_stable_diffusion_img2img_pipeline(
|
||||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.float16,
|
||||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||||
|
@ -152,6 +154,7 @@ def run_txt2img(
|
||||||
seed: int,
|
seed: int,
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
|
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
@ -159,7 +162,11 @@ def run_txt2img(
|
||||||
Run the text to image pipeline with caching.
|
Run the text to image pipeline with caching.
|
||||||
"""
|
"""
|
||||||
with pipeline_lock():
|
with pipeline_lock():
|
||||||
pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)
|
pipeline = load_stable_diffusion_pipeline(
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
device=device,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
generator_device = "cpu" if device.lower().startswith("mps") else device
|
generator_device = "cpu" if device.lower().startswith("mps") else device
|
||||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
|
@ -270,6 +277,18 @@ def select_scheduler(container: T.Any = st.sidebar) -> str:
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def select_checkpoint(container: T.Any = st.sidebar) -> str:
|
||||||
|
"""
|
||||||
|
Provide a custom model checkpoint.
|
||||||
|
"""
|
||||||
|
custom_checkpoint = container.text_input(
|
||||||
|
"Custom Checkpoint",
|
||||||
|
value="",
|
||||||
|
help="Provide a custom model checkpoint",
|
||||||
|
)
|
||||||
|
return custom_checkpoint or DEFAULT_CHECKPOINT
|
||||||
|
|
||||||
|
|
||||||
@st.experimental_memo
|
@st.experimental_memo
|
||||||
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
|
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
|
||||||
return pydub.AudioSegment.from_file(audio_file)
|
return pydub.AudioSegment.from_file(audio_file)
|
||||||
|
@ -281,9 +300,13 @@ def get_audio_splitter(device: str = "cuda"):
|
||||||
|
|
||||||
|
|
||||||
@st.experimental_singleton
|
@st.experimental_singleton
|
||||||
def load_magic_mix_pipeline(device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0]):
|
def load_magic_mix_pipeline(
|
||||||
|
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||||
|
device: str = "cuda",
|
||||||
|
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||||
|
):
|
||||||
pipeline = DiffusionPipeline.from_pretrained(
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
"riffusion/riffusion-model-v1",
|
checkpoint,
|
||||||
custom_pipeline="magic_mix",
|
custom_pipeline="magic_mix",
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
|
@ -302,6 +325,7 @@ def run_img2img_magic_mix(
|
||||||
kmin: float,
|
kmin: float,
|
||||||
kmax: float,
|
kmax: float,
|
||||||
mix_factor: float,
|
mix_factor: float,
|
||||||
|
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||||
):
|
):
|
||||||
|
@ -310,6 +334,7 @@ def run_img2img_magic_mix(
|
||||||
"""
|
"""
|
||||||
with pipeline_lock():
|
with pipeline_lock():
|
||||||
pipeline = load_magic_mix_pipeline(
|
pipeline = load_magic_mix_pipeline(
|
||||||
|
checkpoint=checkpoint,
|
||||||
device=device,
|
device=device,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
@ -335,12 +360,17 @@ def run_img2img(
|
||||||
guidance_scale: float,
|
guidance_scale: float,
|
||||||
seed: int,
|
seed: int,
|
||||||
negative_prompt: T.Optional[str] = None,
|
negative_prompt: T.Optional[str] = None,
|
||||||
|
checkpoint: str = DEFAULT_CHECKPOINT,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||||
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
|
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
with pipeline_lock():
|
with pipeline_lock():
|
||||||
pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)
|
pipeline = load_stable_diffusion_img2img_pipeline(
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
device=device,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
generator_device = "cpu" if device.lower().startswith("mps") else device
|
generator_device = "cpu" if device.lower().startswith("mps") else device
|
||||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||||
|
|
Loading…
Reference in New Issue