开源仓库
CNN自编码器
前面的文章中我们模仿MAE的思路用CNN做了一个掩码自编码器
下游任务
我们的CNN_Masked_Autoencoder是通过手写数字数据集训练的,那我们就继续沿用,拿来做一点改造,实现手写数字的分类
网络架构设计
为了实现比简单的线性分类网络更强的精度,我决定先使用一个Autoencoder来先学习图片的特征。于是这个分类器的构建可以看作是两个阶段的任务,第一个阶段是对掩码图像编码学习还原图像,另外一个阶段是通过预训练编码器编码完整图片信息训练线性分类器。对此我设计的网络结构如下:
图中红色的数据流代表预训练阶段,这个阶段先对输入的图片做随机的掩码,然后用卷积网络对掩码后的图片做卷积,这里面同时调用了残差块,然后再用一个对称的网络对图片做界面,直接解码输出的是还原的图像(这是与MAE设计的一个不同点,原本的MAE输出未被掩码的区域,但是这里用的不是注意力机制,卷积核只能顾及一块的区域,如果只单独输出Patch的话效果会不太好,模型不能很好理解块之间的关系)。还原图像与生成的图像做MSELoss。绿色数据流是分类网络训练阶段,这里调用的是训练好的encoder,其中的参数都是固定的,对encoder的输出展开然后用一个简单的二层线性网络做图片分类预测,这一步分网络是可训练的,这一块利用的是交叉熵损失。
训练和验证
Autoencoder预训练
左图是训练过程中记录的训练和测试的loss,可以发现40epoch过后loss就没有出现明显下降,处于波动状态。右侧图是每过10轮保存一次的还原可视化结果,可以看到后期模型已经可以比较准确还原图片。
微调训练分类网络
使用预训练100个epoch后的encoder做100轮的微调训练。左图是线性分类器训练时的交叉熵,可以看测试的loss到也是20epoch后就趋于平稳。在训练时程序10轮保存一次模型,右侧是每个模型的各项指标。下面的大图是可视化分类器的结果
线性分类器对照实验
制作了一个类似参数量的线性网络,网络直接把每张图片28*28的像素展开然后输入类似架构的三层线性网络(28*28:10:10),输出做预测,训练100个epoch后各项指标对比如下:
Accuracy | Precision | Recall | F1 | 参数量 | |
带预训练编码器 | 0.9296 | 0.9305 | 0.9296 | 0.9296 | 1W |
简单线性模型 | 0.9127 | 0.9126 | 0.9127 | 0.9124 | 7W |