Fix formula for noise levels in Karras scheduler and tests (#627)

fix formula for noise levels in karras scheduler and tests
This commit is contained in:
Grigory Sizov 2022-09-24 18:24:08 +02:00 committed by GitHub
parent d0aa899f0e
commit 35e9209601
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 3 deletions

View File

@ -110,7 +110,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
(
self.config.sigma_max
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in self.timesteps

View File

@ -113,7 +113,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
schedule = [
(
self.config.sigma_max
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in timesteps

View File

@ -1104,7 +1104,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow