Merge pull request #5327 from smirkingface/master
Fixed safety checker for ckpt files written with pytorch >=1.13
This commit is contained in:
commit
c3777777d0
|
@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
||||
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
|
||||
|
||||
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
|
||||
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
|
||||
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||
|
||||
def check_zip_filenames(filename, names):
|
||||
for name in names:
|
||||
if name in allowed_zip_names:
|
||||
continue
|
||||
if allowed_zip_names_re.match(name):
|
||||
continue
|
||||
|
||||
|
@ -82,8 +80,14 @@ def check_pt(filename, extra_handler):
|
|||
# new pytorch format is a zip file
|
||||
with zipfile.ZipFile(filename) as z:
|
||||
check_zip_filenames(filename, z.namelist())
|
||||
|
||||
with z.open('archive/data.pkl') as file:
|
||||
|
||||
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||
if len(data_pkl_filenames) == 0:
|
||||
raise Exception(f"data.pkl not found in {filename}")
|
||||
if len(data_pkl_filenames) > 1:
|
||||
raise Exception(f"Multiple data.pkl found in {filename}")
|
||||
with z.open(data_pkl_filenames[0]) as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
unpickler.load()
|
||||
|
|
Loading…
Reference in New Issue