[PNDM Scheduler] format timesteps attrs to np arrays (#273)

* format timesteps attrs to np arrays in pndm scheduler
because lists don't get formatted to tensors in `self.set_format`

* convert to long type to use timesteps as indices for tensors

* add scheduler set_format test

* fix `_timesteps` type

* make style with black 22.3.0 and isort 5.10.1

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Nouamane Tazi 2022-08-31 13:12:08 +01:00 committed by GitHub
parent 7eb6dfc607
commit b64c522759
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 6 deletions

View File

@@ -103,22 +103,24 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self._offset = offset
self._timesteps = [t + self._offset for t in self._timesteps]
self._timesteps = np.array([t + self._offset for t in self._timesteps])
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
self.prk_timesteps = []
self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
self.prk_timesteps = np.array([])
self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy()
else:
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
self.plms_timesteps = self._timesteps[:-3][
::-1
].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
self.timesteps = self.prk_timesteps + self.plms_timesteps
self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
self.ets = []
self.counter = 0

View File

@@ -485,6 +485,35 @@ class PNDMSchedulerTest(SchedulerCommonTest):
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_set_format(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
for key, value in vars(scheduler).items():
# we only allow `ets` attr to be a list
assert not isinstance(value, list) or key in [
"ets"
], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
# check if `scheduler.set_format` does convert correctly attrs to pt format
for key, value in vars(scheduler_pt).items():
# we only allow `ets` attr to be a list
assert not isinstance(value, list) or key in [
"ets"
], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
assert not isinstance(
value, np.ndarray
), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)