1. Training and testing code
To use it, just change the training and test data paths and set the number of output classes to whatever you need; here it is a 6-class problem.
Dataset layout:
Under train and test there are 6 folders each: 0 1 2 3 4 5. The folder name is the class label (see the layout sketch just below).
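Using the training path from the code below, the layout expected by torchvision.datasets.ImageFolder looks like this:
/mnt/nas/cv_data/imagequality/testiq/
    train/
        0/  1/  2/  3/  4/  5/    # each class folder contains that class's images
    test/
        0/  1/  2/  3/  4/  5/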
The code is as follows:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
plt.switch_backend('agg')
def loadtraindata():
    # path = r"/mnt/nas/cv_data/imagequality/waterloo_de20_all/train"
    path = r"/mnt/nas/cv_data/imagequality/testiq/train"
    trainset = torchvision.datasets.ImageFolder(path,
                                                transform=transforms.Compose([
                                                    transforms.Resize((32, 32)),
                                                    transforms.CenterCrop(32),
                                                    transforms.ToTensor()]))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    return trainloader
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 32x32 input -> conv1 -> 28x28 -> pool -> 14x14 -> conv2 -> 10x10 -> pool -> 5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 6)  # 6 output classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
classes = ('0','1','2','3','4','5')
def loadtestdata():
    # path = r"/mnt/nas/cv_data/imagequality/waterloo_de20_all/test"
    path = r"/mnt/nas/cv_data/imagequality/testiq/test"
    testset = torchvision.datasets.ImageFolder(path,
                                               transform=transforms.Compose([
                                                   transforms.Resize((32, 32)),
                                                   transforms.ToTensor()]))
    testloader = torch.utils.data.DataLoader(testset, batch_size=25,
                                             shuffle=True, num_workers=2)
    return testloader
def trainandsave():
    trainloader = loadtraindata()
    net = Net()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    # train
    for epoch in range(5):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data
            # wrap them in Variable
            inputs, labels = Variable(inputs), Variable(labels)
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # running_loss += loss.data[0]  # old API, see problem 3 below
            running_loss += loss.item()
            if i % 200 == 199:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
    print('Finished Training')
    torch.save(net, 'net.pkl')                      # save the whole model
    torch.save(net.state_dict(), 'net_params.pkl')  # save only the parameters
def reload_net():
    trainednet = torch.load('net.pkl')
    return trainednet

def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()  # with the 'agg' backend nothing is displayed; use plt.savefig(...) to keep the grid

def test():
    testloader = loadtestdata()
    net = reload_net()
    dataiter = iter(testloader)
    images, labels = next(dataiter)  # dataiter.next() no longer works in recent PyTorch versions
    imshow(torchvision.utils.make_grid(images, nrow=5))
    print('GroundTruth: ', " ".join('%5s' % classes[labels[j]] for j in range(2)))
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    print('Predicted: ', " ".join('%5s' % classes[predicted[j]] for j in range(2)))

if __name__ == '__main__':  # guard needed when DataLoader workers use the 'spawn' start method
    trainandsave()
    test()
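The script reloads the full pickled model from net.pkl. If you prefer to restore from the saved parameters instead, a minimal sketch is (reload_net_from_params is not part of the original code):
def reload_net_from_params():
    net = Net()                                         # rebuild the same architecture
    net.load_state_dict(torch.load('net_params.pkl'))   # load only the weights
    net.eval()                                          # evaluation mode for testing
    return net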
2. Problems encountered
(1) Problem 1
raise NotImplementedError
Fix:
The spaces/indentation in the .py file were wrong. If forward is mis-indented it no longer overrides nn.Module's forward, so the base class raises NotImplementedError when the model is called.
(2) Problem 2
AttributeError: 'NoneType' object has no attribute 'log_softmax'
Fix:
Checking the code showed that forward had no return statement, so the model returned None and nn.CrossEntropyLoss failed when it tried to apply log_softmax to it.
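A minimal sketch of the mistake (BadNet is a made-up example, not part of the original script); adding return x at the end of forward fixes it:
import torch
import torch.nn as nn

class BadNet(nn.Module):
    def __init__(self):
        super(BadNet, self).__init__()
        self.fc = nn.Linear(10, 6)

    def forward(self, x):
        x = self.fc(x)
        # missing "return x": calling the model yields None,
        # and nn.CrossEntropyLoss then raises the error above

net = BadNet()
out = net(torch.randn(4, 10))   # out is None here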
(3) Problem 3
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
Fix: this is a version-upgrade issue (since PyTorch 0.4 the loss is a 0-dim tensor);
change loss.data[0] to loss.item()
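In the training loop above this is the change:
# old API: loss used to be a 1-element tensor
# running_loss += loss.data[0]

# current API: loss is a 0-dim tensor, read the Python number with .item()
running_loss += loss.item()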
(4) Problem 4
raise RuntimeError('Invalid DISPLAY variable')
RuntimeError: Invalid DISPLAY variable
Fix:
matplotlib's default backend is TkAgg, and backends such as FltkAgg, GTK, GTKAgg, GTKCairo, TkAgg, Wx or WxAgg all require a GUI. The Linux machine I ran on has no graphical display, hence the error.
Switch to a backend that does not need a GUI (Agg, Cairo, PS, PDF or SVG):
import matplotlib.pyplot as plt
plt.switch_backend('agg')
Reference: https://www.cnblogs.com/bymo/p/7447409.html
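An equivalent option is to select the backend before pyplot is imported:
import matplotlib
matplotlib.use('Agg')            # must run before "import matplotlib.pyplot"
import matplotlib.pyplot as plt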
(5) Problem 5
print('GroundTruth: ', " ".join('%5s' % classes[labels[j]] for j in range(25)))
IndexError: index 12 is out of bounds for dimension 0 with size 12
Fix:
The display index goes out of bounds: the argument to range() must not exceed the number of images actually in the batch (the last batch can hold fewer than batch_size images, here only 12).
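One way to avoid this is to derive the range from the batch that was actually returned, sketched on the same print statements:
n = labels.size(0)   # actual batch size, e.g. 12 in the last batch
print('GroundTruth: ', " ".join('%5s' % classes[labels[j]] for j in range(n)))
print('Predicted:   ', " ".join('%5s' % classes[predicted[j]] for j in range(n)))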
3. Results
I tested one image and the prediction was correct.
The generated model files are net.pkl (the whole pickled model) and net_params.pkl (the parameters only).
References
Code: https://blog.csdn.net/a738833592/article/details/80900250
Errors: https://blog.csdn.net/terry_zeng/article/details/25985419