Return Flax scheduler state (#601)
* Optionally return state in from_config. Useful for Flax schedulers. * has_state is now a property, make check more strict. I don't check the class is `SchedulerMixin` to prevent circular dependencies. It should be enough that the class name starts with "Flax" the object declares it "has_state" and the "create_state" exists too. * Use state in pipeline from_pretrained. * Make style
This commit is contained in:
parent
e72f1a8a71
commit
a9fdb3de9e
|
@ -160,12 +160,19 @@ class ConfigMixin:
|
||||||
if "dtype" in unused_kwargs:
|
if "dtype" in unused_kwargs:
|
||||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||||
|
|
||||||
|
# Return model and optionally state and/or unused_kwargs
|
||||||
model = cls(**init_dict)
|
model = cls(**init_dict)
|
||||||
|
return_tuple = (model,)
|
||||||
|
|
||||||
|
# Flax schedulers have a state, so return it.
|
||||||
|
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
|
||||||
|
state = model.create_state()
|
||||||
|
return_tuple += (state,)
|
||||||
|
|
||||||
if return_unused_kwargs:
|
if return_unused_kwargs:
|
||||||
return model, unused_kwargs
|
return return_tuple + (unused_kwargs,)
|
||||||
else:
|
else:
|
||||||
return model
|
return return_tuple if len(return_tuple) > 1 else model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_dict(
|
def get_config_dict(
|
||||||
|
|
|
@ -437,8 +437,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||||
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
|
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
|
||||||
params[name] = loaded_params
|
params[name] = loaded_params
|
||||||
elif issubclass(class_obj, SchedulerMixin):
|
elif issubclass(class_obj, SchedulerMixin):
|
||||||
loaded_sub_model = load_method(loadable_folder)
|
loaded_sub_model, scheduler_state = load_method(loadable_folder)
|
||||||
params[name] = loaded_sub_model.create_state()
|
params[name] = scheduler_state
|
||||||
else:
|
else:
|
||||||
loaded_sub_model = load_method(loadable_folder)
|
loaded_sub_model = load_method(loadable_folder)
|
||||||
|
|
||||||
|
|
|
@ -105,6 +105,10 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
stable diffusion.
|
stable diffusion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_state(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -113,6 +113,10 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
stable diffusion.
|
stable diffusion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_state(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue