Allow trailing comma in learning rate

This commit is contained in:
Muhammad Rizqi Nur 2022-10-29 15:37:24 +07:00
parent 35c45df28b
commit a5f3adbdd7
1 changed files with 20 additions and 13 deletions

View File

@ -11,23 +11,30 @@ class LearnScheduleIterator:
self.rates = []
self.it = 0
self.maxit = 0
for i, pair in enumerate(pairs):
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
try:
for i, pair in enumerate(pairs):
if not pair.strip():
continue
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
elif step == -1:
else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
assert self.rates
except (ValueError, AssertionError):
raise Exception("Invalid learning rate schedule")
def __iter__(self):
return self