本文介绍了使用paddle AIstudio时遇到的坑~做一个避坑记录。
也可以到AI studio的实际项目里体验一下~
我发现了一篇高质量的实训项目,使用免费算力即可一键运行,还能额外获取8小时免费GPU运行时长,快来Fork一下体验吧。
【华为比赛】车道渲染数据智能质检Rank56方案
有任何疑问和建议欢迎大家交流~
一、optimizer
paddle.optimizer 目录下包含飞桨框架支持的优化器算法相关的API与学习率衰减相关的API。在构建模型训练过程时直接调用生成对应的实例即可使用,目前热门的深度学习框架基本都是如此方便。
例如:
optimizer = optim.AdamW(parameters = model.parameters(),learning_rate=0.001)
之后在训练过程中,通过其内部的step(),clear_grad()函数(只在只在动态图中使用),既可以完成参数的更新。
如官方文档中的例子:
import paddle
linear = paddle.nn.Linear(10, 10)
inp = paddle.rand([10,10], dtype="float32")
out = linear(inp)
loss = paddle.mean(out)
beta1 = paddle.to_tensor([0.9], dtype="float32")
beta2 = paddle.to_tensor([0.99], dtype="float32")
adam = paddle.optimizer.AdamW(learning_rate=0.1,
parameters=linear.parameters(),
beta1=beta1,
beta2=beta2,
weight_decay=0.01)
out.backward()
adam.step()
adam.clear_grad()
# Note that the learning_rate of linear_2 is 0.01.
linear_1 = paddle.nn.Linear(10, 10)
linear_2 = paddle.nn.Linear(10, 10)
inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
out = linear_1(inp)
out = linear_2(out)
loss = paddle.mean(out)
adam = paddle.optimizer.AdamW(
learning_rate=0.1,
parameters=[{
'params': linear_1.parameters()
}, {
'params': linear_2.parameters(),
'weight_decay': 0.001,
'learning_rate': 0.1, #此处学习率为0.1*0.1 = 0.01
# 错误示范
# 'learning_rate':0.1*scheduler
'beta1': 0.8
}],
weight_decay=0.01,
beta1=0.9)
#反向传播→更新权重→清除梯度
out.backward()
adam.step()
adam.clear_grad()
二、使用学习率衰减策略
该部分结构也是非常的简单易用,只需要调用相关的类,生成对应实例,将实例作为learning_rate参数传入optimizer就可以完成。
例如分段设置学习率。
lr = 10e-4
# 模型优化器
scheduler = optim.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[lr, 0.5*lr, 0.1*lr, 0.01*lr], verbose=True)
optimizer = optim.AdamW(parameters = model.parameters(),learning_rate=scheduler)
# learning_rate = lr if epoch < 3
# learning_rate = 0.5*lr if 3 <= epoch < 6
# learning_rate = 0.1*lr if 6 <= epoch < 9
# learning_rate = 0.01*lr if 9 <= epoch
三、不同参数不同学习率
optimizer还提供了不同参数组使用不同学习率的方式,具体可看上方的官方例子,只要将传入的参数组通过一个字典的列表传入,就可以对不同的参数组使用不同的学习率进行更新。
但是这里要注意的是,传入参数组学习率表示基本学习率的比例。 所以此时如果传入为scheduler的话,就会产生报错。
四、其他错误
当出现错误:List out of index时,需要考虑是否传入的参数组列表为空。