Merge pull request #5327 from smirkingface/master
Fixed safety checker for ckpt files written with pytorch >=1.13
This commit is contained in:
commit
c3777777d0
|
@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||||
raise Exception(f"global '{module}/{name}' is forbidden")
|
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||||
|
|
||||||
|
|
||||||
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
|
||||||
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
|
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
|
||||||
|
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||||
|
|
||||||
def check_zip_filenames(filename, names):
|
def check_zip_filenames(filename, names):
|
||||||
for name in names:
|
for name in names:
|
||||||
if name in allowed_zip_names:
|
|
||||||
continue
|
|
||||||
if allowed_zip_names_re.match(name):
|
if allowed_zip_names_re.match(name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -82,8 +80,14 @@ def check_pt(filename, extra_handler):
|
||||||
# new pytorch format is a zip file
|
# new pytorch format is a zip file
|
||||||
with zipfile.ZipFile(filename) as z:
|
with zipfile.ZipFile(filename) as z:
|
||||||
check_zip_filenames(filename, z.namelist())
|
check_zip_filenames(filename, z.namelist())
|
||||||
|
|
||||||
with z.open('archive/data.pkl') as file:
|
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||||
|
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||||
|
if len(data_pkl_filenames) == 0:
|
||||||
|
raise Exception(f"data.pkl not found in {filename}")
|
||||||
|
if len(data_pkl_filenames) > 1:
|
||||||
|
raise Exception(f"Multiple data.pkl found in {filename}")
|
||||||
|
with z.open(data_pkl_filenames[0]) as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
unpickler = RestrictedUnpickler(file)
|
||||||
unpickler.extra_handler = extra_handler
|
unpickler.extra_handler = extra_handler
|
||||||
unpickler.load()
|
unpickler.load()
|
||||||
|
|
Loading…
Reference in New Issue