转载自:https://blog.csdn.net/u011622208/article/details/118555590
代码:
import numpy as np
import matplotlib.pyplot as plt
def Warmup_poly(warm_start_lr=1e-4, warm_steps=5, lr0=1e-2, nsteps=500,
                power=0.9, plot=True):
    """Build (and optionally plot) an exponential-warmup + poly-decay LR schedule.

    Iterations ``0 .. warm_steps - 1`` grow geometrically from
    ``warm_start_lr`` toward ``lr0``. The growth factor is chosen so the ramp
    would hit ``lr0`` exactly at iteration ``warm_steps``. Iterations
    ``warm_steps .. nsteps - 1`` then decay polynomially::

        lr_j = lr0 * (1 - (j - warm_steps) / (nsteps - warm_steps)) ** power

    With the defaults this reproduces the original hard-coded schedule.

    Args:
        warm_start_lr: learning rate at iteration 0.
        warm_steps: number of warmup iterations; 0 disables warmup.
        lr0: peak learning rate, used at iteration ``warm_steps``.
        nsteps: total number of iterations in the schedule.
        power: exponent of the polynomial decay.
        plot: if True, plot the schedule with matplotlib and call
            ``plt.show()``.

    Returns:
        1-D float64 numpy array of length
        ``warm_steps + max(0, nsteps - warm_steps)``.

    Raises:
        ValueError: if ``warm_steps`` is negative.
    """
    if warm_steps < 0:
        raise ValueError(f"warm_steps must be >= 0, got {warm_steps}")

    if warm_steps > 0:
        # factor**warm_steps == lr0 / warm_start_lr, so the ramp hands off
        # smoothly to the decay phase, whose first value is exactly lr0.
        warm_factor = (lr0 / warm_start_lr) ** (1 / warm_steps)
        warm_lrs = warm_start_lr * warm_factor ** np.arange(warm_steps)
    else:
        # No warmup: skip the factor computation, which divides by warm_steps.
        warm_lrs = np.empty(0)

    decay_steps = nsteps - warm_steps
    if decay_steps > 0:
        # Fraction of the decay phase already completed, in [0, 1).
        progress = np.arange(decay_steps) / decay_steps
        decay_lrs = lr0 * (1 - progress) ** power
    else:
        decay_lrs = np.empty(0)

    lrs = np.concatenate([warm_lrs, decay_lrs])

    if plot:
        plt.plot(lrs)
        plt.xlabel("Iters")
        plt.ylabel("lr")
        plt.show()
    return lrs
if __name__ == "__main__":
    # Script entry point: build the default schedule and display its plot.
    Warmup_poly()
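结果:(原文此处为学习率曲线图,转载时图片缺失。曲线形状:第 0–4 步从 1e-4 按等比增长,第 5 步达到峰值 1e-2,之后按 power=0.9 的多项式衰减,第 499 步约为 3.8e-5。)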