Fix formula for noise levels in Karras scheduler and tests (#627)
fix formula for noise levels in karras scheduler and tests
This commit is contained in:
parent
d0aa899f0e
commit
35e9209601
|
@ -110,7 +110,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
||||
self.schedule = [
|
||||
(
|
||||
self.config.sigma_max
|
||||
self.config.sigma_max**2
|
||||
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
|
||||
)
|
||||
for i in self.timesteps
|
||||
|
|
|
@ -113,7 +113,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
|
||||
schedule = [
|
||||
(
|
||||
self.config.sigma_max
|
||||
self.config.sigma_max**2
|
||||
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
|
||||
)
|
||||
for i in timesteps
|
||||
|
|
|
@ -1104,7 +1104,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
|
||||
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@slow
|
||||
|
|
Loading…
Reference in New Issue