diff --git a/riffusion/streamlit/pages/text_to_audio.py b/riffusion/streamlit/pages/text_to_audio.py
index 48ccbbf..16aeca2 100644
--- a/riffusion/streamlit/pages/text_to_audio.py
+++ b/riffusion/streamlit/pages/text_to_audio.py
@@ -28,11 +28,33 @@ def render_text_to_audio() -> None:
 
     device = streamlit_util.select_device(st.sidebar)
 
-    prompt = st.text_input("Prompt")
-    negative_prompt = st.text_input("Negative prompt")
+    with st.form("Inputs"):
+        prompt = st.text_input("Prompt")
+        negative_prompt = st.text_input("Negative prompt")
 
-    with st.sidebar.expander("Text to Audio Params", expanded=True):
-        seed = T.cast(int, st.number_input("Seed", value=42))
+        row = st.columns(4)
+        num_clips = T.cast(
+            int,
+            row[0].number_input(
+                "Number of clips",
+                value=1,
+                min_value=1,
+                max_value=25,
+                help="How many outputs to generate (seed gets incremented)",
+            ),
+        )
+        starting_seed = T.cast(
+            int,
+            row[1].number_input(
+                "Seed",
+                value=42,
+                help="Change this to generate different variations",
+            ),
+        )
+
+        st.form_submit_button("Riff", type="primary")
+
+    with st.sidebar:
         num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
         width = T.cast(int, st.number_input("Width", value=512))
         guidance = st.number_input(
@@ -43,32 +65,37 @@ def render_text_to_audio() -> None:
         st.info("Enter a prompt")
         return
 
-    image = streamlit_util.run_txt2img(
-        prompt=prompt,
-        num_inference_steps=num_inference_steps,
-        guidance=guidance,
-        negative_prompt=negative_prompt,
-        seed=seed,
-        width=width,
-        height=512,
-        device=device,
-    )
-
-    st.image(image)
-
     # TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
     params = SpectrogramParams(
         min_frequency=0,
         max_frequency=10000,
     )
 
-    audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
-        image=image,
-        params=params,
-        device=device,
-        output_format="mp3",
-    )
-    st.audio(audio_bytes)
+    seed = starting_seed
+    for i in range(1, num_clips + 1):
+        st.write(f"#### Riff {i} / {num_clips} - Seed {seed}")
+
+        image = streamlit_util.run_txt2img(
+            prompt=prompt,
+            num_inference_steps=num_inference_steps,
+            guidance=guidance,
+            negative_prompt=negative_prompt,
+            seed=seed,
+            width=width,
+            height=512,
+            device=device,
+        )
+        st.image(image)
+
+        audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
+            image=image,
+            params=params,
+            device=device,
+            output_format="mp3",
+        )
+        st.audio(audio_bytes)
+
+        seed += 1
 
 
 if __name__ == "__main__":