为往圣继绝学
前言
本文主要介绍了如何利用pytorch构建vgg网络,并且用于简单的训练和测试。为了更好地服务新手使用,可能部分内容会比较 啰嗦 ,烦请见谅!另外,若有错误,请大家指出!Thanks♪(・ω・)ノ
目标对象:那些“还对vgg网络不熟悉,但是想要快速上手尝试使用vgg网络或复现某些论文的vgg网络"的读者使用。而对于想要更加深入理解vgg原理的读者可能没有很好的效果。
注意!:本文主要目的是帮助大家快速地利用pytorch工具使用vgg网络,对于相关原理不会做过多介绍,如果读者想要更深入地理解vgg网络请尝试参考其他文章。以下代码用二分类举例 (完整代码整理好后,会发在github上)
一、相关环境
语言环境:Python3.9
IDE:Pycharm 2021.1
GPU:NVIDIA GeForce GTX 1660 Ti
二、使用步骤
1.引入库
代码如下(示例):
import numpy as np
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import models
import torch
import torch as nn
import pandas as pd #用于数据提取,不是必要的库
2.读入数据
读取数据,具体方式请根据自身情况选择:
data_path = './'#请输入文件路径、文件名,以下代码三选一
data = pd.read_excel(data_path)#xlsx文件
data = pd.read_csv(data_path)#csv文件
data = np.loadtxt(data_path, delimiter='\t')#txt文件
以上为一维数据的读取,若读取图片请使用其他的读入方法。
但无论哪一种读取方式,得到的前两个维度应该是[样本数,通道数]
举例:[100,3,128,128]为我们执行下面这条代码得到的维度:其中100为样本数量,3为通道数量,两个128是尺寸。
print(data.shape)#检查数据维度
3.构建网络
1.设备选择:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')#用于选择cpu还是gpu
2.创建库中模型:
vgg16 = models.vgg16().to(device)
vgg16()可更改成其他网络,如下图所示。(“bn” 表示 Batch Normalization,即批量归一化,vgg+bn意为在原始的网络的每个卷积层之后使用批量归一化技术。)
3.加载模型(非必选):
方式请根据手中的模型形式决定
#第一种读取字典(推荐)
vgg16 = models.vgg16(pretrained=False).to(device)
vgg16.load_state_dict(torch.load('model.pth'))
#第二种读取模型
vgg16 = torch.load('model.pth')
4.查看网络:
直接对网络进行输出即可,这里用vgg16演示。
print(vgg16)
演示结果:用于在对网络调整后的观察确认
5.修改网络(增删改冻):
值得留意的是:(三个问题代码都在“改”的部分)
1.features中的第一层卷积神经网络中的3,请根据自己的数据的通道数做修改。
2.classifier中的最后一层Linear的out_features要与自己的想要分类的种类数一致。
3.classifier中的第一层Linear的in_features要与avgpool后得到的矩阵列数相同。(报错时根据所给数据更改即可)
———————————————————————————————————————————————————————————————
以下皆为举例说明,读者可用查看网络的方式来验证,非必要不用执行。
增:
vgg16.features.add_module('name',nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1))#features层末尾增加一个名为name的二维卷积层
vgg16.classifier.add_module('name',nn.Linear(4096, 2))#classifier层末尾增加一个名为name的Linear层
删:
vgg16.features[1] = nn.Sequential()#对features的1位置的网络层取空,即进行删除
改(对应上面三个问题):
#整个替换
vgg16.features[0] = torch.nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1)
#参数修改
vgg16.classifier[6].out_features=2
vgg16.classifier[0].in_features = 25088
冻:
vgg16.features[0].weight.requires_grad = False#反向传播的时候,不会更新该层的参数
4.训练模型
1.训练前对数据的处理(请根据自身数据情况处理):
#以下X,Y分别对应inputs、labels。为了用于pytoch,需要将numpy转换成tensor
X=torch.tensor(X)
Y=torch.tensor(Y)
#将数据放入GPU训练
X=X.to(device)
Y=Y.to(device)
#X = np.expand_dims(X, 1)#如果数据维度不满足要求,可以该代码来增加维度
data = TensorDataset(X,Y)
data_loader = DataLoader(data,batch_size=8,shuffle=False)#shuffle代表是否乱序,batchsize自己设定即可
2.优化器:
#优化器根据自身情况自行选择即可,不一定是下面几种
optimizer = torch.optim.Adam(vgg16.parameters(), lr=0.001)#adam
optimizer = torch.optim.SGD(vgg16.parameters(), lr=0.001)#SGD
3.损失函数:
#损失函数根据自身情况自行选择即可,不一定是下面几种
loss_func = torch.nn.CrossEntropyLoss().to(device)
loss_func = torch.nn.BCELoss().to(device)
4.训练及模型保存:
#start train
epoch = 1
for i in range(epoch):
loss_num = 0.0
right=0.0
total=0.0
best_acc=0.0
for step ,data in enumerate(data_loader):
inputs,labels =data
out = vgg16(inputs)
labels=labels.to(torch.int64)#损失函数对于labels要求int64形式时用该代码修改
loss= loss_func(out,labels)
right += ((out.argmax(1) == labels).sum()).item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
total += labels.size(0)
acc=right/total
if acc > best_acc:#保存模型
best_acc = acc
torch.save(net.state_dict(), save_path) # save_path请自己设置
print('Epoch:',i+1,' correct:',right,' acc:',right/total,' loss: ',running_loss)
#train done!
5.测试数据
vgg16.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in data_loader:#这里的data_loader是测试数据!!
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Accuracy on test set: {100 * correct / total}%","correct: ",correct)
三、最后提醒
1.读取数据请按照的部分,请按照自身数据的情况选择读取方式,只要最后能录入data_loader即可。
2.如果是处理一维数据,建议修改卷积层、池化层等,方式可参考修改网络的部分代码。
3.修改网络中提到的三个注意事项,对于读者训练自己的数据是非常必要修改的。
4.对样本、标签数据的数据类型修改是为了更好地适用于损失函数的使用,所以读者可以根据报错情况来修改自己的数据类型,不必按照文中步骤。
5.对于想要复现一些论文中vgg16结构的读者,可以重点阅读文中构建网络的部分。
6.文中损失函数和优化器的代码只是为了提供模板,具体选择请读者按照自身情况抉择。
7.待续…(等待大家的批评)
总结
本文代码内容演示全部都是以vgg16网络为基础。
感谢论坛大神们对于vgg网络的使用进行经验分享!
新手第一次写文,多多包涵,若有帮助烦请看官点个赞再走吧!