diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1c5c3d7a..7c9e4e46 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -160,12 +160,19 @@ class ConfigMixin: if "dtype" in unused_kwargs: init_dict["dtype"] = unused_kwargs.pop("dtype") + # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) + return_tuple = (model,) + + # Flax schedulers have a state, so return it. + if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): + state = model.create_state() + return_tuple += (state,) if return_unused_kwargs: - return model, unused_kwargs + return return_tuple + (unused_kwargs,) else: - return model + return return_tuple if len(return_tuple) > 1 else model @classmethod def get_config_dict( diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index b7de33d2..6cfd7ae3 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -437,8 +437,8 @@ class FlaxDiffusionPipeline(ConfigMixin): loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params elif issubclass(class_obj, SchedulerMixin): - loaded_sub_model = load_method(loadable_folder) - params[name] = loaded_sub_model.create_state() + loaded_sub_model, scheduler_state = load_method(loadable_folder) + params[name] = scheduler_state else: loaded_sub_model = load_method(loadable_folder) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 30c873b4..d81d6660 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 4c8c4381..83445056 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self,