diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7c098adb..56356461 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -179,7 +179,7 @@ class ModelTesterMixin: loss.backward() ema_model.step(model) - def test_scheduler_outputs_equivalence(self): + def test_outputs_equivalence(self): def set_nan_tensor_to_zero(t): # Temporary fallback until `aten::_index_put_impl_` is implemented in mps # Track progress in https://github.com/pytorch/pytorch/issues/77764