diff --git a/modules/safe.py b/modules/safe.py index 479c8b86d..ec23a53c4 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -103,7 +103,7 @@ def check_pt(filename, extra_handler): def load(filename, *args, **kwargs): - return load_with_extra(filename, *args, **kwargs) + return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs) def load_with_extra(filename, extra_handler=None, *args, **kwargs): @@ -151,5 +151,42 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): return unsafe_torch_load(filename, *args, **kwargs) +class Extra: + """ + A class for temporarily setting the global handler for when you can't explicitly call load_with_extra + (because it's not your code making the torch.load call). The intended use is like this: + +``` +import torch +from modules import safe + +def handler(module, name): + if module == 'torch' and name in ['float64', 'float16']: + return getattr(torch, name) + + return None + +with safe.Extra(handler): + x = torch.load('model.pt') +``` + """ + + def __init__(self, handler): + self.handler = handler + + def __enter__(self): + global global_extra_handler + + assert global_extra_handler is None, 'already inside an Extra() block' + global_extra_handler = self.handler + + def __exit__(self, exc_type, exc_val, exc_tb): + global global_extra_handler + + global_extra_handler = None + + unsafe_torch_load = torch.load torch.load = load +global_extra_handler = None +