import tqdm

class LearnScheduleIterator:
    """Iterator over the ``(learn_rate, end_step)`` stages of a schedule.

    The schedule is parsed once in ``__init__``; iteration then yields each
    stage in order as a ``(rate, end_step)`` tuple.

    Attributes:
        rates: list of ``(float rate, int end_step)`` stages still to run.
        it: index of the next stage that ``__next__`` will return.
        maxit: number of stages in ``rates``.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000

        Args:
            learn_rate: schedule string of comma-separated ``rate:step``
                entries. An entry without ``:step`` (or with step ``-1``)
                applies until ``max_steps`` and ends the schedule. A plain
                number is also accepted and treated as a constant rate.
            max_steps: last step of the run; every end step is clamped to it.
            cur_step: step already reached; entries whose end step is not
                greater than this are skipped.

        Raises:
            Exception: if an entry cannot be parsed or no stage remains
                after skipping the ones that already ended.
        """
        self.rates = []
        self.it = 0
        self.maxit = 0
        try:
            # str() lets callers hand in a bare float/int as a constant rate.
            self._parse(str(learn_rate), max_steps, cur_step)
        except ValueError as err:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from err
        self.maxit = len(self.rates)

    def _parse(self, schedule, max_steps, cur_step):
        """Fill ``self.rates`` from *schedule*; raise ValueError if invalid."""
        for pair in schedule.split(','):
            if not pair.strip():
                # Tolerate empty entries such as a trailing comma.
                continue

            rate_text, sep, step_text = pair.partition(':')
            if not sep:
                # Bare rate: holds until the end of training, so nothing
                # after it can ever apply.
                self.rates.append((float(rate_text), max_steps))
                return

            # An entry with extra colons leaves ':' in step_text, which int()
            # rejects, instead of being silently read as a bare rate.
            step = int(step_text)
            if step > cur_step:
                self.rates.append((float(rate_text), min(step, max_steps)))
                if step > max_steps:
                    # This stage already reaches the end of the run.
                    return
            elif step == -1:
                # -1 is shorthand for "until max_steps".
                self.rates.append((float(rate_text), max_steps))
                return

        # Explicit check rather than `assert`, which is removed under -O.
        if not self.rates:
            raise ValueError('schedule contains no stage ending after cur_step')

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(rate, end_step)`` stage or raise StopIteration."""
        if self.it >= self.maxit:
            raise StopIteration
        stage = self.rates[self.it]
        self.it += 1
        return stage

class LearnRateScheduler:
    """Track the active stage of a learning-rate schedule and apply it.

    Attributes:
        schedules: iterator yielding ``(learn_rate, end_step)`` stages.
        learn_rate: rate of the stage currently in effect.
        end_step: step at which the current stage ends.
        verbose: whether stage changes are reported.
        finished: set to True once the last stage has ended.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        """Parse the schedule and load its first remaining stage.

        Args:
            learn_rate: schedule specification passed to
                ``LearnScheduleIterator``.
            max_steps: last step of the run.
            cur_step: step already reached when scheduling starts.
            verbose: print a message for the initial stage and for every
                later stage change.
        """
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        (self.learn_rate, self.end_step) = next(self.schedules)
        self.verbose = verbose

        if self.verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def step(self, step_number):
        """Advance to the next stage if the current one has ended.

        Args:
            step_number: the step the caller has reached.

        Returns:
            True if a new stage was loaded (``learn_rate`` / ``end_step``
            changed); False if the current stage is still running or the
            schedule is exhausted, in which case ``finished`` is set.
        """
        if step_number < self.end_step:
            return False

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        """Write the new rate into *optimizer* when the stage changes.

        Args:
            optimizer: object exposing ``param_groups``, an iterable of
                dicts whose ``'lr'`` entry is overwritten.
            step_number: the step the caller has reached.
        """
        if not self.step(step_number):
            return

        if self.verbose:
            # tqdm.write keeps the message from clobbering a progress bar.
            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        for pg in optimizer.param_groups:
            pg['lr'] = self.learn_rate