Return Flax scheduler state (#601)
* Optionally return state in from_config. Useful for Flax schedulers. * has_state is now a property, make check more strict. I don't check the class is `SchedulerMixin` to prevent circular dependencies. It should be enough that the class name starts with "Flax" the object declares it "has_state" and the "create_state" exists too. * Use state in pipeline from_pretrained. * Make style
This commit is contained in:
parent
e72f1a8a71
commit
a9fdb3de9e
|
@ -160,12 +160,19 @@ class ConfigMixin:
|
|||
if "dtype" in unused_kwargs:
|
||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||
|
||||
# Return model and optionally state and/or unused_kwargs
|
||||
model = cls(**init_dict)
|
||||
return_tuple = (model,)
|
||||
|
||||
# Flax schedulers have a state, so return it.
|
||||
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
|
||||
state = model.create_state()
|
||||
return_tuple += (state,)
|
||||
|
||||
if return_unused_kwargs:
|
||||
return model, unused_kwargs
|
||||
return return_tuple + (unused_kwargs,)
|
||||
else:
|
||||
return model
|
||||
return return_tuple if len(return_tuple) > 1 else model
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
|
|
|
@ -437,8 +437,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
|
||||
params[name] = loaded_params
|
||||
elif issubclass(class_obj, SchedulerMixin):
|
||||
loaded_sub_model = load_method(loadable_folder)
|
||||
params[name] = loaded_sub_model.create_state()
|
||||
loaded_sub_model, scheduler_state = load_method(loadable_folder)
|
||||
params[name] = scheduler_state
|
||||
else:
|
||||
loaded_sub_model = load_method(loadable_folder)
|
||||
|
||||
|
|
|
@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
stable diffusion.
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
stable diffusion.
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue