不平衡数据分类网络-Pytorch试验

不平衡数据分类网络-Pytorch试验

注意:本试验在参考此代码的基础上。为方便起见,之后简称A

1. 准备数据(CIFAR-10数据集

1.1 制作不平衡数据集 (下载的为平衡数据集)

脚本:cifar10_to_png.py脚本:image2train_test.py

直接从原始CIFAR-10采样,通过控制每一类采样的个数,就可以产生类别不平衡的训练数据。
步骤
1)在A提取图片的基础上 ;

2)将数据集分成训练集和测试集 ;

3)在训练集中根据自定义的类别占比,采样不同数量的类别,得到不平衡训练集;

4)在测试集中,采样相同小数量的类别,得到平衡测试集。

PS:为了尽可能近似实际项目中的情况,故训练集中的样本数量设置的比较少。
且第二步的意义是为了防止数据泄露。

2. 数据加载 (参考A)

3. 搭建网络 (参考A)

采用的VGG16网络 参考此博客介绍

4. 训练网络

4.1 训练普通交叉熵损失函数的网络

loss = celoss(outputs, labels)  # 计算损失值

4.2 训练Class-Balanced Loss 的网络

Class-Balanced Loss Based on Effective Number of Samples论文解读参考此博客
在这里插入图片描述
β \beta β为常数,论文中设置为 ( N − 1 ) / N (N-1)/N (N1)/N N N N 为总样本数目。 n y n_y ny 为第 y y y 类的样本数目。

训练时遇到bugUserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at …\c10/core/TensorImpl.h:1156.) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

解决办法:
这是pytorch1.9的bug,下个版本将修复,我将pytorch降级成1.8就不报这个错了。

5. 训练结果

5.1 第一组试验

数据集:
1)训练集:10类不平衡样本按如下比例分配

trainnum = 1000
class_ratio = [19, 17, 15, 13, 11, 9, 7, 5, 3, 1]

2)测试集:10类平衡样本每类数量为:

testnum = 50

混淆矩阵如有不懂参考此博客具体代码实现

e p o c h = 500 epoch = 500 epoch=500 时,在测试集上得到的混淆矩阵如下:
在这里插入图片描述
e p o c h = 500 epoch = 500 epoch=500 时,利用类平衡损失函数,在测试集得到的混淆矩阵为:
在这里插入图片描述

图1 交叉熵损失函数
图2 类平衡损失函数

e p o c h = 2500 epoch = 2500 epoch=2500 时,在测试集上得到的混淆矩阵如下:在这里插入图片描述
e p o c h = 2000 epoch = 2000 epoch=2000 时,利用类平衡损失函数,在测试集得到的混淆矩阵为:
在这里插入图片描述

图1 交叉熵损失函数
图2 类平衡损失函数

结论:类平衡损失函数效果不明显。

可能有如下原因:

1)整体样本数量不是特别多,同类样本之间的特征不是特别统一。后续补做试验

2)没根据Loss去判断网络是否收敛。后续修改程序
5.2 第二组实验
将训练集扩大至5000个,测试集仍然是 50 ∗ 10 50*10 5010,加入loss曲线, e p o c h = 1000 epoch = 1000 epoch=1000 利用类平衡损失函数,结果如下:
loss曲线

在这里插入图片描述

图1 训练集=1000
图2 训练集=5000

召回率与精度如下:

labelname  recall-5000  recall-1000  precision-5000	precision-1000
airplane  	52.0%  	 		24.3%		44.8%			36.0%
automobile  90.0%  	 		35.5%		48.9%			76.0%
bird  		32.0%  	 		29.3%		69.6%			34.0%
cat  		38.0%  	 		30.4%		24.4%			34.0%
deer  		46.0%  	 		31.3%		46.9%			42.0%
dog  		76.0%  	 		44.1%		34.9%			30.0%
frog  		26.0%  	 		51.6%		65.0%			64.0%
horse  		32.0%  	 		32.0%		80.0%			16.0%
ship  		48.0%  	 		84.6%		57.1%			22.0%
truck  		14.0%  	 		25.0%		77.8%			2.0%

结论:扩大训练集,训练效果更好。

5.3 第三组实验
利用图像增强将不平衡训练集,调整至平衡数据集。
在这里插入图片描述
在这里插入图片描述

图1 训练集=5000利用类平衡损失函数
图2 训练集=5000用图像增强实现重采样

召回率与精度如下:

labelname  recall-lossblance  recall-resample  precision-lossbalance	precision-resample
airplane  		52.0%  	 		92.0%				44.8%						24.9%
automobile  	90.0%  	 		92.0%				48.9%						59.0%
bird  			32.0%  	 		38.0%				69.6%						65.5%
cat  			38.0%  	 		56.0%				24.4%						32.6%
deer  			46.0%  	 		48.0%				46.9%						60.0%
dog  			76.0%  	 		30.0%				34.9%						65.2%
frog  			26.0%  	 		38.0%				65.0%						76.0%
horse  			32.0%  	 		32.0%				80.0%						76.2%
ship  			48.0%  	 		20.0%				57.1%						83.3%
truck  			14.0%  	 		2.0%				77.8%						100.0%

结论:综合来说,重采样效果更好,不过也可能是由于重采样的原因,导致小类样本训练可能存在过拟合(对有限的样本特征学习的很好,反而不敢预测),导致其召回率很低,精度可以。

5.4 第四组实验

结合重采样和重加权的方法,进行训练,结果如下:
在这里插入图片描述

在这里插入图片描述

图1 用图像增强实现重采样
图2 重采样和重加权结合

召回率与精度如下:

labelname  recall-lossblance-rs  recall-resample  precision-lossbalance-rs	precision-resample
airplane  		52.0%  	 		92.0%				35.1%						24.9%
automobile  	90.0%  	 		92.0%				45.9%						59.0%
bird  			50.0%  	 		38.0%				42.4%						65.5%
cat  			56.0%  	 		56.0%				27.2%						32.6%
deer  			48.0%  	 		48.0%				33.3%						60.0%
dog  			22.0%  	 		30.0%				55.0%						65.2%
frog  			56.0%  	 		38.0%				60.9%						76.0%
horse  			34.0%  	 		32.0%				70.8%						76.2%
ship  			4.0%  	 		20.0%				66.7%						83.3%
truck  			2.0%  	 		2.0%				100.0%						100.0%

结论:效果似乎没有单纯重采样效果好
原因:???
下一步寻找针对样本不平衡问题的评价指标

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值