diff --git a/riffusion/streamlit/util.py b/riffusion/streamlit/util.py index c1fce73..7f4c865 100644 --- a/riffusion/streamlit/util.py +++ b/riffusion/streamlit/util.py @@ -2,6 +2,7 @@ Streamlit utilities (mostly cached wrappers around riffusion code). """ import io +import threading import typing as T import pydub @@ -106,6 +107,14 @@ def get_scheduler(scheduler: str, config: T.Any) -> T.Any: raise ValueError(f"Unknown scheduler {scheduler}") +@st.experimental_singleton +def pipeline_lock() -> threading.Lock: + """ + Singleton lock used to prevent concurrent access to any model pipeline. + """ + return threading.Lock() + + @st.experimental_singleton def load_stable_diffusion_img2img_pipeline( checkpoint: str = "riffusion/riffusion-model-v1", @@ -149,22 +158,23 @@ def run_txt2img( """ Run the text to image pipeline with caching. """ - pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler) + with pipeline_lock(): + pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler) - generator_device = "cpu" if device.lower().startswith("mps") else device - generator = torch.Generator(device=generator_device).manual_seed(seed) + generator_device = "cpu" if device.lower().startswith("mps") else device + generator = torch.Generator(device=generator_device).manual_seed(seed) - output = pipeline( - prompt=prompt, - num_inference_steps=num_inference_steps, - guidance_scale=guidance, - negative_prompt=negative_prompt or None, - generator=generator, - width=width, - height=height, - ) + output = pipeline( + prompt=prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance, + negative_prompt=negative_prompt or None, + generator=generator, + width=width, + height=height, + ) - return output["images"][0] + return output["images"][0] @st.experimental_singleton @@ -269,31 +279,32 @@ def run_img2img( scheduler: str = SCHEDULER_OPTIONS[0], progress_callback: T.Optional[T.Callable[[float], T.Any]] = None, ) -> Image.Image: - pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler) + with pipeline_lock(): + pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler) - generator_device = "cpu" if device.lower().startswith("mps") else device - generator = torch.Generator(device=generator_device).manual_seed(seed) + generator_device = "cpu" if device.lower().startswith("mps") else device + generator = torch.Generator(device=generator_device).manual_seed(seed) - num_expected_steps = max(int(num_inference_steps * denoising_strength), 1) + num_expected_steps = max(int(num_inference_steps * denoising_strength), 1) - def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None: - if progress_callback is not None: - progress_callback(step / num_expected_steps) + def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None: + if progress_callback is not None: + progress_callback(step / num_expected_steps) - result = pipeline( - prompt=prompt, - image=init_image, - strength=denoising_strength, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - negative_prompt=negative_prompt or None, - num_images_per_prompt=1, - generator=generator, - callback=callback, - callback_steps=1, - ) + result = pipeline( + prompt=prompt, + image=init_image, + strength=denoising_strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt or None, + num_images_per_prompt=1, + generator=generator, + callback=callback, + callback_steps=1, + ) - return result.images[0] + return result.images[0] class StreamlitCounter: