Allow trailing comma in learning rate

This commit is contained in:
Muhammad Rizqi Nur 2022-10-29 15:37:24 +07:00
parent 35c45df28b
commit a5f3adbdd7
1 changed files with 20 additions and 13 deletions

View File

@ -11,23 +11,30 @@ class LearnScheduleIterator:
self.rates = [] self.rates = []
self.it = 0 self.it = 0
self.maxit = 0 self.maxit = 0
for i, pair in enumerate(pairs): try:
tmp = pair.split(':') for i, pair in enumerate(pairs):
if len(tmp) == 2: if not pair.strip():
step = int(tmp[1]) continue
if step > cur_step: tmp = pair.split(':')
self.rates.append((float(tmp[0]), min(step, max_steps))) if len(tmp) == 2:
self.maxit += 1 step = int(tmp[1])
if step > max_steps: if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return return
elif step == -1: else:
self.rates.append((float(tmp[0]), max_steps)) self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1 self.maxit += 1
return return
else: assert self.rates
self.rates.append((float(tmp[0]), max_steps)) except (ValueError, AssertionError):
self.maxit += 1 raise Exception("Invalid learning rate schedule")
return
def __iter__(self): def __iter__(self):
return self return self