PyTorch 模型进阶训练技巧
- 自定义损失函数 如 cross_entropy + L2正则化
- 动态调整学习率 如每十次 *0.1
典型案例:loss上下震荡
1、自定义损失函数
-
1、PyTorch已经提供了很多常用的损失函数,但是有些非通用的损失函数并未提供,比如:DiceLoss、HuberLoss…等
-
2、模型如果出现loss震荡,在经过调整数据集或超参后,现象依然存在,非通用损失函数或自定义损失函数针对特定模型会有更好的效果
比如:DiceLoss是医学影像分割常用的损失函数,定义如下:
-
Dice系数, 是一种集合相似度度量函数,通常用于计算两个样本的相似度(值范围为 [0, 1]):
-
∣X∩Y∣表示X和Y之间的交集,