说明:VGG16和CF-VGG11是论文《A 3D Fluorescence Classification and Component Prediction Method Based on VGG Convolutional Neural Network and PARAFAC Analysis Method》使用的两种主要模型。其对应代码仓库提供了实验使用的数据集、平行因子分析结果和CNN模型。论文和代码仓库是本文实验使用的基本材料。
目录
论文摘要
- 三维荧光的研究目前主要采用平行因子分析(PARAFAC)、荧光区域积分(FRI)和主成分分析(PCA)等方法。
- 目前结合卷积神经网络(CNN)的研究也很多,但在CNN与三维荧光分析相结合的方法中,还没有一种方法被认为是最有效的。
- 本文在已有研究基础上,从实际环境中采集了一些样品进行三维荧光数据的测量,并从互联网中获得了一批公开数据集。
- 首先对数据进行预处理(包括PARAFAC分析和CNN数据集生成两步),然后提出了基于VGG16和VGG11卷积神经网络的三维荧光分类方法和分量拟合方法。
- 使用VGG16网络对三维荧光数据进行分类,训练准确率为99.6%(与PCA + SVM方法同样准确)。
- 对于分量图拟合网络,我们综合比较了改进的LeNet网络、改进的AlexNet网络和改进的VGG11网络,PCA + SVM改进的VGG11网络。
- 在改进的VGG11网络训练中,我们使用MSE损失函数和余弦相似度来判断模型的优劣,网络训练的MSE损失达到4.6×10−4,训练结果的余弦相似度达到0.99。(由此可见,)网络性能非常出色。
- 实验表明,CNN在三维荧光分析中具有很大的应用价值。
数据集信息
以下表格中的Samples,Number,Train,Validate,Test,Total Samples after Expansion
列来自论文的Table 3。
Samples | Number | Train | Validate | VGG16/main | VGG11/train | Test | VGG11/test | VGG16/test | Total Samples after Expansion |
---|---|---|---|---|---|---|---|---|---|
FU | 45 | 27 | 9 | 35 | 35 | 9 | 7 | 35 | 135 |
F | 105 | 63 | 21 | 81 | 81 | 21 | 21 | 81 | 315 |
P | 206 | 124 | 41 | 161 | 161 | 41 | 42 | 161 | 618 |
PU | 60 | 36 | 12 | 45 | 45 | 12 | 12 | 45 | 180 |
论文中的数据扩充说明:在实际训练过程中,我们通过色域失真和镜像翻转来扩展图像。
表格分析:Train+Validate与2个网络的训练集大小相近,VGG16的测试集扩充了4倍,Total Samples after Expansion=3*Number。
环境配置
- 安装CUDA 12.1
- 安装cuDnn
- 新建环境:
conda create -n 3deem python=3.10
- 安装
torch-2.2.1+cu121-cp310-cp310-win_amd64.whl
- 在虚拟环境中安装matplotlib,opencv
pip3 install torchvision --index-url https://download.pytorch.org/whl/cu121
分类实验(工作目录:代码仓库/VGG16)
修改代码
- 新增annotation_generator.py
import os
from utils.utils import get_classes
classes_path = 'model_data/cls_classes.txt'
img_root_path = 'datasets/main/'
txt_path = 'model_data/cls_train.txt'
assert os.path.exists(img_root_path)
txt = open(txt_path,'w')
class_names, num_classes = get_classes(classes_path)
for tag_index in range(0,num_classes):
class_name = class_names[tag_index]
img_path = img_root_path + class_name + '/'
files = os.listdir(img_path)
for img_file in files:
line = str(tag_index) + ';' + img_path + img_file + '\n'
txt.write(line)
-
nets/mobilenet.py、nets/resnet50.py、nets/vgg16.py
torchvision.models.utils → torch.hub
-
修改vit.py
# 在第9行插入以下4行
from torch.hub import load_state_dict_from_url
model_urls = {
'vit': 'https://download.pytorch.org/models/vit_b_16-c867db91.pth',
}
# 修改vit函数如下
def vit(input_shape=[224