diff --git a/utils/isolate_rng.py b/utils/isolate_rng.py index 711d629..879f4fe 100644 --- a/utils/isolate_rng.py +++ b/utils/isolate_rng.py @@ -34,7 +34,11 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: "python": python_get_rng_state(), } if include_cuda: - states["torch.cuda"] = torch.cuda.get_rng_state_all() + try: + states["torch.cuda"] = torch.cuda.get_rng_state_all() + except RuntimeError: + # CUDA initialization failure. + pass return states