Return Flax scheduler state (#601)

* Optionally return state in from_config.

Useful for Flax schedulers.

* has_state is now a property, make check more strict.

I don't check the class is `SchedulerMixin` to prevent circular
dependencies. It should be enough that the class name starts with "Flax"
the object declares it "has_state" and the "create_state" exists too.

* Use state in pipeline from_pretrained.

* Make style
This commit is contained in:
Pedro Cuenca 2022-09-21 22:25:27 +02:00 committed by GitHub
parent e72f1a8a71
commit a9fdb3de9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 4 deletions

View File

@ -160,12 +160,19 @@ class ConfigMixin:
if "dtype" in unused_kwargs: if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype") init_dict["dtype"] = unused_kwargs.pop("dtype")
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict) model = cls(**init_dict)
return_tuple = (model,)
# Flax schedulers have a state, so return it.
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
state = model.create_state()
return_tuple += (state,)
if return_unused_kwargs: if return_unused_kwargs:
return model, unused_kwargs return return_tuple + (unused_kwargs,)
else: else:
return model return return_tuple if len(return_tuple) > 1 else model
@classmethod @classmethod
def get_config_dict( def get_config_dict(

View File

@ -437,8 +437,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
params[name] = loaded_params params[name] = loaded_params
elif issubclass(class_obj, SchedulerMixin): elif issubclass(class_obj, SchedulerMixin):
loaded_sub_model = load_method(loadable_folder) loaded_sub_model, scheduler_state = load_method(loadable_folder)
params[name] = loaded_sub_model.create_state() params[name] = scheduler_state
else: else:
loaded_sub_model = load_method(loadable_folder) loaded_sub_model = load_method(loadable_folder)

View File

@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
stable diffusion. stable diffusion.
""" """
@property
def has_state(self):
return True
@register_to_config @register_to_config
def __init__( def __init__(
self, self,

View File

@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
stable diffusion. stable diffusion.
""" """
@property
def has_state(self):
return True
@register_to_config @register_to_config
def __init__( def __init__(
self, self,