Warmup in PyTorch: implementation with LR plotting

import math

import torch
from torch.nn import Linear, Sequential
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Generate toy data: y = x^2 + noise
x_data = torch.linspace(0, 50, 100)
x_data = torch.unsqueeze(x_data, 0)            # shape (1, 100)
y_data = x_data ** 2 + torch.randn(100) * 20   # noisy quadratic target
x_data = x_data.permute(1, 0)                  # shape (100, 1)
y_data = y_data.permute(1, 0)                  # shape (100, 1)
print(x_data.shape)
print(y_data.shape)

train_dataset = TensorDataset(x_data, y_data)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
train_x, train_y = next(iter(train_loader))   # peek at one batch to sanity-check shapes

class Net(torch.nn.Module):
    def __init__(self, hidden_num=10):
        super(Net, self).__init__()
        self.layer = Sequential(
            Linear(1, hidden_num),
            Linear(hidden_num, 1),
        )

    def forward(self, x):
        x = self.layer(x)
        return x

net = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=10)   # base lr; LambdaLR multiplies this by the factor returned by lr_lambda
# Alternative: plain exponential decay without warmup
# scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)

# Linear warmup to the base lr over the first 4 epochs, then cosine decay over the remaining 20 - 4 epochs.
# (lr_lambda returns a multiplier on the base lr; at epoch 0 the multiplier is 0.)
warm_up_with_cosine_lr = lambda epoch: epoch / 4 if epoch <= 4 else 0.5 * (math.cos((epoch - 4) / (20 - 4) * math.pi) + 1)
# Hugging Face-style linear-decay variant; note that num_training_steps, current_step and num_warmup_steps are not defined in this script:
# warm_up_with_cosine_lr = lambda epoch: epoch / 4 if epoch <= 4 else max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)

# Alternative: warmup followed by step decay (multiply by 0.1 at epochs 20 and 40)
# warm_up_with_multistep_lr = lambda epoch: epoch / 4 if epoch <= 4 else 0.1 ** len([m for m in [20, 40] if m <= epoch])
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)
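# Self-contained sketch of the linear-decay variant mentioned above; the warmup/total step counts
# here (4 and 20) are assumptions chosen to match this 20-epoch demo, not values from the original script:
# num_warmup_steps, num_training_steps = 4, 20
# warm_up_with_linear_lr = lambda epoch: epoch / num_warmup_steps if epoch <= num_warmup_steps else max(0.0, float(num_training_steps - epoch) / float(max(1, num_training_steps - num_warmup_steps)))
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_linear_lr)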


lr_list = []
for epoch in range(20):
    lr_list.append(scheduler.get_last_lr()[0])   # record the lr used in this epoch
    print(scheduler.get_last_lr())
    for x, y in train_loader:
        y_pred = net(x)
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()   # advance the schedule once per epoch

# Plot the lr curve over the 20 epochs
plt.plot(lr_list)
plt.legend(labels=['LambdaLR'])
plt.show()
print(scheduler)

Warmup is just learning-rate scheduling: the lr is updated once per epoch, which is exactly what the code above implements. For the first few epochs the lr ramps up gradually until it reaches the initial lr; after that, a second rule (cosine, linear or multi-step decay here) gradually lowers it again.
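To see the shape of the schedule without running any training, the lambda can be evaluated directly. A minimal sketch, assuming the same base lr of 10 and the 4-epoch warmup / 20-epoch horizon used above (warm_up_with_cosine_lr is the lambda defined earlier):

import math

base_lr = 10
warm_up_with_cosine_lr = lambda epoch: epoch / 4 if epoch <= 4 else 0.5 * (math.cos((epoch - 4) / (20 - 4) * math.pi) + 1)

for epoch in range(20):
    factor = warm_up_with_cosine_lr(epoch)   # multiplier applied to the base lr at this epoch
    print(f"epoch {epoch:2d}: factor = {factor:.3f}, lr = {base_lr * factor:.3f}")

The printed factors climb linearly from 0 to 1 over epochs 0-4, then follow the cosine back down towards 0 by epoch 20, which is exactly the curve the plot above shows.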
