diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 224fe471..6850d9ca 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -278,24 +278,19 @@ class EMAModel:
         self.decay = decay
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        value = (1 + optimization_step) / (10 + optimization_step)
-        return 1 - min(self.decay, value)
-
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
 
         self.optimization_step += 1
-        self.decay = self.get_decay(self.optimization_step)
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                tmp = self.decay * (s_param - param)
-                s_param.sub_(tmp)
+                s_param.sub_(one_minus_decay * (s_param - param))
             else:
                 s_param.copy_(param)
 