[PNDM Scheduler] format timesteps attrs to np arrays (#273)
* format timesteps attrs to np arrays in pndm scheduler because lists don't get formatted to tensors in `self.set_format` * convert to long type to use timesteps as indices for tensors * add scheduler set_format test * fix `_timesteps` type * make style with black 22.3.0 and isort 5.10.1 Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
7eb6dfc607
commit
b64c522759
|
@ -103,22 +103,24 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
|
||||
)
|
||||
self._offset = offset
|
||||
self._timesteps = [t + self._offset for t in self._timesteps]
|
||||
self._timesteps = np.array([t + self._offset for t in self._timesteps])
|
||||
|
||||
if self.config.skip_prk_steps:
|
||||
# for some models like stable diffusion the prk steps can/should be skipped to
|
||||
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
|
||||
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
|
||||
self.prk_timesteps = []
|
||||
self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
|
||||
self.prk_timesteps = np.array([])
|
||||
self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy()
|
||||
else:
|
||||
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
|
||||
)
|
||||
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
|
||||
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
|
||||
self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
|
||||
self.plms_timesteps = self._timesteps[:-3][
|
||||
::-1
|
||||
].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
|
||||
|
||||
self.timesteps = self.prk_timesteps + self.plms_timesteps
|
||||
self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
|
||||
|
||||
self.ets = []
|
||||
self.counter = 0
|
||||
|
|
|
@ -485,6 +485,35 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
def test_set_format(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler_pt.set_timesteps(num_inference_steps)
|
||||
|
||||
for key, value in vars(scheduler).items():
|
||||
# we only allow `ets` attr to be a list
|
||||
assert not isinstance(value, list) or key in [
|
||||
"ets"
|
||||
], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
|
||||
|
||||
# check if `scheduler.set_format` does convert correctly attrs to pt format
|
||||
for key, value in vars(scheduler_pt).items():
|
||||
# we only allow `ets` attr to be a list
|
||||
assert not isinstance(value, list) or key in [
|
||||
"ets"
|
||||
], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
|
||||
assert not isinstance(
|
||||
value, np.ndarray
|
||||
), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue