From a9fdb3de9ee4f247cf95b7ceba48236cb1abc36a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 21 Sep 2022 22:25:27 +0200 Subject: [PATCH] Return Flax scheduler state (#601) * Optionally return state in from_config. Useful for Flax schedulers. * has_state is now a property, make check more strict. I don't check the class is `SchedulerMixin` to prevent circular dependencies. It should be enough that the class name starts with "Flax" the object declares it "has_state" and the "create_state" exists too. * Use state in pipeline from_pretrained. * Make style --- src/diffusers/configuration_utils.py | 11 +++++++++-- src/diffusers/pipeline_flax_utils.py | 4 ++-- src/diffusers/schedulers/scheduling_ddim_flax.py | 4 ++++ src/diffusers/schedulers/scheduling_pndm_flax.py | 4 ++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1c5c3d7a..7c9e4e46 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -160,12 +160,19 @@ class ConfigMixin: if "dtype" in unused_kwargs: init_dict["dtype"] = unused_kwargs.pop("dtype") + # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) + return_tuple = (model,) + + # Flax schedulers have a state, so return it. + if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): + state = model.create_state() + return_tuple += (state,) if return_unused_kwargs: - return model, unused_kwargs + return return_tuple + (unused_kwargs,) else: - return model + return return_tuple if len(return_tuple) > 1 else model @classmethod def get_config_dict( diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index b7de33d2..6cfd7ae3 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -437,8 +437,8 @@ class FlaxDiffusionPipeline(ConfigMixin): loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params elif issubclass(class_obj, SchedulerMixin): - loaded_sub_model = load_method(loadable_folder) - params[name] = loaded_sub_model.create_state() + loaded_sub_model, scheduler_state = load_method(loadable_folder) + params[name] = scheduler_state else: loaded_sub_model = load_method(loadable_folder) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 30c873b4..d81d6660 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 4c8c4381..83445056 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self,