diff --git a/train.py b/train.py index 018d851..0bbcf39 100644 --- a/train.py +++ b/train.py @@ -677,7 +677,8 @@ def main(args): """ handles sigterm """ - if threading.current_thread().__class__.__name__ == '_MainThread': + is_main_thread = (torch.utils.data.get_worker_info() == None) + if is_main_thread: global interrupted if not interrupted: interrupted=True