猫狗识别基于PyTorch
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision.transforms as transform
file_path = "CAT/img"  # directory containing the raw cat/dog image files

# Shared preprocessing pipeline for every sample:
#   ToTensor  - HWC uint8 image -> CHW float tensor scaled to [0, 1]
#   Normalize - standardize with mean 0.5 / std 0.5, mapping values to [-1, 1]
_preprocess_steps = [
    transform.ToTensor(),
    transform.Normalize((0.5,), (0.5,)),
]
trans = transform.Compose(_preprocess_steps)
class CatDogDataset(Dataset):
    """Cat/dog image dataset with a fixed train/test split.

    Expects files named "<label>.<something>.<ext>" (e.g. "0.123.jpg") where
    the leading field is the integer class label — presumably 0 = cat,
    1 = dog (TODO confirm against the data directory).

    Items are (image_tensor, label_tensor, one_hot_tensor):
      image_tensor  - CHW float tensor, normalized by the module-level `trans`
      label_tensor  - scalar float tensor holding the class index
      one_hot_tensor- length-2 float tensor, one-hot encoding of the label
    """

    def __init__(self, file_path, is_training=True):
        super(CatDogDataset, self).__init__()
        # List of (full_path, label_string) tuples.
        self.data = []
        # sorted() makes the listing deterministic: os.listdir order is
        # arbitrary and platform-dependent, which would make the index-based
        # train/test split below non-reproducible.
        for name in sorted(os.listdir(file_path)):
            label, _, _ = name.split(".")  # leading field is the label
            self.data.append((os.path.join(file_path, name), label))
        # Fixed split: indices [5000, 7000) are the test set, the rest train.
        if is_training:
            self.data = self.data[:5000] + self.data[7000:]
        else:
            self.data = self.data[5000:7000]

    def __len__(self):
        """Number of samples in this split."""
        return len(self.data)

    def __getitem__(self, item):
        """Load, normalize and label-encode the sample at index `item`."""
        full_path, label = self.data[item]
        img = cv2.imread(full_path)
        if img is None:
            # cv2.imread silently returns None for missing/corrupt files,
            # which would otherwise surface as a confusing error in trans().
            raise FileNotFoundError("cannot read image: %s" % full_path)
        img_tensor = trans(img)  # HWC -> CHW, scaled and normalized
        label = int(label)
        # One-hot encode the label (2 classes).
        one_hot = np.zeros(2)
        one_hot[label] = 1
        label_tensor = torch.tensor(label, dtype=torch.float32)
        one_hot_tensor = torch.tensor(one_hot, dtype=torch.float32)
        return img_tensor, label_tensor, one_hot_tensor
train_data = CatDogDataset(file_path, True)
test_data = CatDogDataset(file_path, False)

# Sanity check: report the number of images in each split (safe to delete).
for split in (train_data, test_data):
    print(len(split))
结果:
10000 (训练集图片数)
2000 (测试集图片数)
# DataLoader batches the dataset and feeds it to the network during training.
# shuffle=True randomizes sample order every epoch; drop_last=True discards
# the final partial batch so every batch has exactly batch_size samples.
# Reference: https://pytorch.org/docs/stable/data.html
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=20, shuffle=True, drop_last=True)

# Same settings for the evaluation loader.
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=20, shuffle=True, drop_last=True)
# Peek at one training batch to verify the pipeline (safe to delete).
data_iter = iter(train_loader)
first_batch = next(data_iter)
print(first_batch)
结果:
第一个tensor:图片
第二个tensor:label
第三个tensor:onehot编码
[tensor([[[[ 0.5373, 0.5294, 0.5059, ..., -0.9765, -0.9922, -1.0000],
[ 0.5843, 0.5529, 0.5294, ..., -0.9059, -0.9137, -0.9294],
[ 0.6078, 0.5765, 0.5529, ..., -0.7098, -0.7176, -0.7255],
...,
[-0.3412, -0.3255, -0.3098, ..., -0.9765, -0.9765, -0.9686],
[-0.3412, -0.3333, -0.3098, ..., -0.9765, -0.9765, -0.9686],
[-0.3490, -0.3333, -0.3176, ..., -0.9765, -0.9765, -0.9686]],
[[ 0.2392, 0.2314, 0.2157, ..., -0.9294, -0.9451, -0.9529],
[ 0.2627, 0.2549, 0.2314, ..., -0.8588, -0.8667, -0.8824],
[ 0.2784, 0.2706, 0.2392, ..., -0.6706, -0.6784, -0.6863],
...,
[-0.4431, -0.4275, -0.4118, ..., -0.9608, -0.9608, -0.9529],
[-0.4431, -0.4353, -0.4118, ..., -0.9608, -0.9608, -0.9529],
[-0.4510, -0.4353, -0.4196, ..., -0.9608, -0.9608, -0.9529]],
[[ 0.2078, 0.2000, 0.1843, ..., -0.8431, -0.8588, -0.8667],
[ 0.2392, 0.2157, 0.2000, ..., -0.7725, -0.7804, -0.7961],
[ 0.2392, 0.2078, 0.2000, ..., -0.6000, -0.6078, -0.6157],
...,
[-0.3804, -0.3647, -0.3490, ..., -0.9608, -0.9608, -0.9529],
[-0.3804, -0.3725, -0.3490, ..., -0.9608, -0.9608, -0.9529],
[-0.3882, -0.3725, -0.3569, ..., -0.9608, -0.9608, -0.9529]]],
[[[-0.7882, -0.7333, -0.6784, ..., -0.5843, -0.5608, -0.5451],
[-0.8118, -0.7569, -0.7255, ..., -0.6078, -0.5765, -0.5608],
[-0.8353, -0.7647, -0.7333, ..., -0.6235, -0.5922, -0.5765],
...,
[-0.1137, -0.1294, -0.1529, ..., -0.3176, -0.3255, -0.3569],
[-0.1216, -0.1137, -0.1216, ..., -0.2941, -0.3098, -0.3412],
[-0.0902, -0.0902, -0.1059, ..., -0.2784, -0.3020, -0.3569]],
[[-0.5608, -0.5137, -0.4902, ..., -0.5686, -0.5451, -0.5294],
[-0.5843, -0.5373, -0.5373, ..., -0.5922, -0.5608, -0.5451],
[-0.6078, -0.5451, -0.5451, ..., -0.6078, -0.5765, -0.5608],
...,
[ 0.0902, 0.0588, 0.0275, ..., -0.7647, -0.7725, -0.8039],
[ 0.0824, 0.0745, 0.0588, ..., -0.7412, -0.7569, -0.7882],
[ 0.1137, 0.0980, 0.0745, ..., -0.7255, -0.7490, -0.8039]],
[[-0.2706, -0.2392, -0.2549, ..., -0.6157, -0.5922, -0.5765],
[-0.2941, -0.2627, -0.3020, ..., -0.6392, -0.6078, -0.5922],
[-0.3176, -0.2706, -0.3098, ..., -0.6549, -0.6235, -0.6078],
...,
[ 0.4510, 0.4196, 0.3804, ..., -0.9294, -0.9373, -0.9686],
[ 0.4510, 0.4353, 0.4118, ..., -0.9059, -0.9216, -0.9529],
[ 0.4824, 0.4588, 0.4275, ..., -0.8902, -0.9137, -0.9686]]],
[[[-0.5765, -0.4902, -0.3882, ..., -0.7176, -0.8431, -0.9059],
[-0.4824, -0.4980, -0.4980, ..., -0.8118, -0.9216, -0.9373],
[-0.5529, -0.4980, -0.4745, ..., -0.8902, -0.9922, -0.9294],
...,
[ 0.3647, 0.2941, 0.3255, ..., 0.9686, 0.9686, 0.9686],
[ 0.3569, 0.3255, 0.3569, ..., 0.9686, 0.9686, 0.9686],
[ 0.2549, 0.3098, 0.4039, ..., 0.9686, 0.9686, 0.9686]],
[[-0.6235, -0.5373, -0.4353, ..., -0.7412, -0.8667, -0.9294],
[-0.5529, -0.5686, -0.5686, ..., -0.8118, -0.9216, -0.9373],
[-0.6392, -0.5843, -0.5608, ..., -0.8431, -0.9451, -0.8824],
...,
[ 0.3961, 0.3333, 0.3333, ..., 0.9686, 0.9686, 0.9686],
[ 0.3882, 0.3647, 0.3647, ..., 0.9686, 0.9686, 0.9686],
[ 0.2863, 0.3490, 0.4118, ..., 0.9686, 0.9686, 0.9686]],
[[-0.7098, -0.6235, -0.5216, ..., -0.6235, -0.7490, -0.8118],
[-0.6314, -0.6471, -0.6471, ..., -0.7020, -0.8118, -0.8275],
[-0.7020, -0.6471, -0.6235, ..., -0.7569, -0.8588, -0.7961],
...,
[ 0.3490, 0.2627, 0.2627, ..., 0.9686, 0.9686, 0.9686],
[ 0.3412, 0.2941, 0.2941, ..., 0.9686, 0.9686, 0.9686],
[ 0.2392, 0.2784, 0.3412, ..., 0.9686, 0.9686, 0.9686]]],
...,
[[[ 0.4431, 0.4039, 0.3020, ..., 0.3412, 0.3020, 0.2784],
[ 0.4118, 0.4353, 0.3569, ..., 0.4510, 0.3882, 0.3412],
[ 0.4510, 0.5137, 0.4510, ..., 0.6392, 0.5765, 0.5294],
...,
[-0.2392, -0.1765, -0.0667, ..., -0.0510, -0.1451, 0.0510],
[-0.3176, -0.2471, -0.1373, ..., -0.0353, -0.1922, -0.0588],
[-0.3725, -0.3020, -0.1922, ..., -0.1216, -0.3412, -0.2784]],
[[ 0.0039, -0.0275, -0.1059, ..., 0.0588, 0.0196, -0.0039],
[-0.0196, 0.0039, -0.0510, ..., 0.1765, 0.1059, 0.0667],
[ 0.0431, 0.1137, 0.0667, ..., 0.3882, 0.3176, 0.2784],
...,
[-0.5216, -0.4588, -0.3490, ..., -0.2157, -0.3098, -0.1137],
[-0.6000, -0.5216, -0.4118, ..., -0.2000, -0.3569, -0.2235],
[-0.6471, -0.5686, -0.4667, ..., -0.2863, -0.5059, -0.4431]],
[[-0.2706, -0.3333, -0.4667, ..., -0.1765, -0.2157, -0.2392],
[-0.3098, -0.3020, -0.4196, ..., -0.0824, -0.1294, -0.1922],
[-0.2706, -0.2235, -0.3255, ..., 0.1137, 0.0588, 0.0039],
...,
[-0.5686, -0.5216, -0.4275, ..., -0.3882, -0.4824, -0.2863],
[-0.6627, -0.6000, -0.5137, ..., -0.3725, -0.5294, -0.3961],
[-0.7255, -0.6706, -0.5686, ..., -0.4588, -0.6784, -0.6157]]],
[[[-0.1686, -0.2235, -0.2784, ..., -0.4588, -0.5686, -0.4039],
[-0.1765, -0.3333, -0.3020, ..., -0.4196, -0.4667, -0.3882],
[-0.2078, -0.3961, -0.3098, ..., -0.4039, -0.3804, -0.4039],
...,
[-0.4667, -0.5216, -0.4353, ..., -0.3020, -0.3647, -0.4118],
[-0.4980, -0.4980, -0.4353, ..., -0.4353, -0.4431, -0.5059],
[-0.5529, -0.4667, -0.4353, ..., -0.3098, -0.2784, -0.5059]],
[[ 0.1451, 0.0902, 0.0039, ..., 0.0431, -0.0510, 0.1137],
[ 0.1294, -0.0275, -0.0118, ..., 0.0588, 0.0118, 0.1137],
[ 0.1059, -0.0902, -0.0196, ..., 0.0431, 0.0667, 0.0431],
...,
[-0.3020, -0.3804, -0.3255, ..., 0.0667, 0.0196, -0.0275],
[-0.3333, -0.3569, -0.3255, ..., -0.0667, -0.0588, -0.1059],
[-0.3882, -0.3255, -0.3255, ..., 0.0745, 0.1059, -0.1059]],
[[-0.1843, -0.2314, -0.2784, ..., -0.4667, -0.5608, -0.3961],
[-0.2314, -0.3804, -0.3255, ..., -0.4431, -0.4902, -0.3961],
[-0.3255, -0.4980, -0.3804, ..., -0.4353, -0.4118, -0.4353],
...,
[-0.1294, -0.2000, -0.1529, ..., -0.3176, -0.3725, -0.4196],
[-0.1608, -0.1765, -0.1529, ..., -0.4510, -0.4510, -0.4980],
[-0.2157, -0.1451, -0.1529, ..., -0.3176, -0.2863, -0.4980]]],
[[[ 0.9059, 0.7647, 0.8824, ..., 0.0510, 0.0118, 0.0353],
[ 0.8510, 0.9137, 0.8118, ..., 0.0275, -0.0275, -0.0431],
[ 0.7569, 0.8824, 0.7490, ..., -0.0039, -0.0510, -0.0902],
...,
[-0.0039, -0.1137, -0.0902, ..., 0.2549, 0.1843, 0.1843],
[-0.2000, -0.2392, -0.1765, ..., 0.3490, 0.2627, 0.2078],
[-0.1922, -0.1843, -0.1373, ..., 0.3098, 0.2235, 0.1216]],
[[ 0.8275, 0.7020, 0.8118, ..., 0.1216, 0.0824, 0.1059],
[ 0.7725, 0.8510, 0.7412, ..., 0.0980, 0.0431, 0.0275],
[ 0.6784, 0.8196, 0.6784, ..., 0.0667, 0.0196, -0.0196],
...,
[ 0.0667, -0.0431, -0.0196, ..., 0.2863, 0.2235, 0.2235],
[-0.1294, -0.1686, -0.1059, ..., 0.3804, 0.3020, 0.2627],
[-0.1216, -0.1137, -0.0667, ..., 0.3569, 0.2627, 0.1765]],
[[ 0.7725, 0.6471, 0.7804, ..., 0.1922, 0.1529, 0.1765],
[ 0.7176, 0.7961, 0.7098, ..., 0.1686, 0.1137, 0.0980],
[ 0.6235, 0.7647, 0.6471, ..., 0.1373, 0.0902, 0.0510],
...,
[ 0.1765, 0.0667, 0.0902, ..., 0.3725, 0.3412, 0.3412],
[-0.0196, -0.0588, 0.0039, ..., 0.4667, 0.4196, 0.3804],
[-0.0118, -0.0039, 0.0431, ..., 0.4431, 0.3804, 0.2941]]]]), tensor([0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
1., 0.]), tensor([[1., 0.],
[1., 0.],
[0., 1.],
[0., 1.],
[1., 0.],
[1., 0.],
[0., 1.],
[1., 0.],
[1., 0.],
[0., 1.],
[0., 1.],
[0., 1.],
[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.],
[0., 1.],
[1., 0.]])]
# Denormalize and display one training image (safe to delete).
oneimg, label, one_hot = train_data[0]
print(len(oneimg))         # number of channels (3)
print(len(oneimg[0][0]))   # image width (100 -> images are 100x100)
# CHW tensor -> HWC numpy array, then undo Normalize((0.5,), (0.5,)):
# x_denorm = x * std + mean (broadcast over all channels).
oneimg = oneimg.numpy().transpose(1, 2, 0)
std = [0.5]
mean = [0.5]
oneimg = oneimg * std + mean
# cv2.imread loads images as BGR, but plt.imshow expects RGB — without this
# flip the colors are swapped. clip guards against float rounding outside [0, 1].
oneimg = np.clip(oneimg[..., ::-1], 0.0, 1.0)
plt.imshow(oneimg)
plt.show()
结果:
3(图片通道数)
100(图片大小为100*100)
# Define the CNN classifier.
import torch.nn.functional as F

class CNN(nn.Module):
    """Two-conv-layer classifier for 3x100x100 cat/dog images.

    Spatial sizes with kernel_size=23, stride=1, padding=1 on a 100x100
    input: conv1 -> 80, pool -> 40, conv2 -> 20, pool -> 10.
    Hence fc1 expects 16 * 10 * 10 = 1600 input features.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=23, stride=1, padding=1)   # conv layer 1
        self.pool = nn.MaxPool2d(2, 2)                                      # shared 2x2 max-pool
        self.conv2 = nn.Conv2d(6, 16, kernel_size=23, stride=1, padding=1)  # conv layer 2
        # 10x10 spatial size after the two conv+pool stages (see class docstring);
        # the original comment claiming 7x7 was wrong.
        self.fc1 = nn.Linear(16 * 10 * 10, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 2)

    def forward(self, x):
        """Forward pass; returns raw logits (CrossEntropyLoss applies softmax)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample; keeping the batch dimension explicit (instead of
        # view(-1, ...)) raises a clear error if the feature count is ever wrong.
        x = x.view(x.size(0), 16 * 10 * 10)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# .view( )是一个tensor的方法,使得tensor改变size但是元素的总数是不变的。
# 第一个参数-1是说这个参数由另一个参数确定, 比如矩阵在元素总数一定的情况下,确定列数就能确定行数。
# 那么为什么这里只关心列数不关心行数呢,因为马上就要进入全连接层了,而全连接层说白了就是矩阵乘法,
# 你会发现第一个全连接层的首参数是16*10*10,所以要保证能够相乘,在矩阵乘法之前就要把x调到正确的size
import torch.optim as optim

net = CNN()

# Cross-entropy loss (expects raw logits) plus SGD with momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Alternative optimizer:
# optimizer = torch.optim.Adam(net.parameters(),lr=1e-2)

# Curves collected while training.
train_accs = []
train_loss = []
test_accs = []

# Run on the GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)
# Train for 10 epochs over train_loader.
for epoch in range(10):
    running_loss = 0.0 # running sum of batch losses for the periodic printout
    for i,data in enumerate(train_loader,0):# 0 is the start index (enumerate's default)
        # data is the batch yielded by train_loader: [inputs, labels, one_hot];
        # only the first two entries are used here.
        # inputs,labels = data
        inputs,labels = data[0].to(device), data[1].to(device)
        # Clear gradients accumulated from the previous batch.
        optimizer.zero_grad()
        # Forward + backward + optimize.
        outputs = net(inputs)
        loss = criterion(outputs,labels.long())
        loss.backward()# back-propagate the loss
        optimizer.step() # update the parameters from the freshly computed gradients
        # Print the average loss once every 100 batches.
        running_loss += loss.item()
        if i%100 == 99:
            print('[%d,%5d] loss :%.3f' %
                  (epoch+1,i+1,running_loss/100))
            running_loss = 0.0
            train_loss.append(loss.item())
            # Training-curve bookkeeping: accuracy of the current batch only.
            correct = 0
            total = 0
            _, predicted = torch.max(outputs.data, 1)
            total = labels.size(0)# batch size (number of labels)
            correct = (predicted == labels).sum().item() # number of correct predictions
            train_accs.append(100*correct/total)
print('Finished Training')
结果:
[1, 100] loss :0.693
[1, 200] loss :0.693
[1, 300] loss :0.691
[1, 400] loss :0.688
[1, 500] loss :0.682
[2, 100] loss :0.684
[2, 200] loss :0.677
[2, 300] loss :0.678
[2, 400] loss :0.657
[2, 500] loss :0.661
[3, 100] loss :0.641
[3, 200] loss :0.647
[3, 300] loss :0.638
[3, 400] loss :0.639
[3, 500] loss :0.633
[4, 100] loss :0.610
[4, 200] loss :0.620
[4, 300] loss :0.614
[4, 400] loss :0.615
[4, 500] loss :0.601
[5, 100] loss :0.601
[5, 200] loss :0.606
[5, 300] loss :0.602
[5, 400] loss :0.598
[5, 500] loss :0.593
[6, 100] loss :0.554
[6, 200] loss :0.580
[6, 300] loss :0.573
[6, 400] loss :0.591
[6, 500] loss :0.575
[7, 100] loss :0.553
[7, 200] loss :0.564
[7, 300] loss :0.557
[7, 400] loss :0.572
[7, 500] loss :0.544
[8, 100] loss :0.525
[8, 200] loss :0.531
[8, 300] loss :0.547
[8, 400] loss :0.555
[8, 500] loss :0.536
[9, 100] loss :0.520
[9, 200] loss :0.515
[9, 300] loss :0.523
[9, 400] loss :0.509
[9, 500] loss :0.533
[10, 100] loss :0.479
[10, 200] loss :0.492
[10, 300] loss :0.495
[10, 400] loss :0.486
[10, 500] loss :0.518
Finished Training
# Plot the training loss/accuracy curves (safe to delete).
def draw_train_process(title, iters, costs, accs, label_cost, lable_acc):
    """Plot training loss and accuracy against iteration on one axis.

    Parameter names (including the misspelled `lable_acc`) are kept for
    backward compatibility with existing callers:
      title      - figure title
      iters      - x-axis values (iteration indices)
      costs      - loss values (plotted red, legend `label_cost`)
      accs       - accuracy values (plotted green, legend `lable_acc`)
    """
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=20)
    # Fix: the original "acc(\%)" contained an invalid escape sequence;
    # "%" needs no escaping in a Python string literal.
    plt.ylabel("acc(%)", fontsize=20)
    plt.plot(iters, costs, color='red', label=label_cost)
    plt.plot(iters, accs, color='green', label=lable_acc)
    plt.legend()
    plt.grid()
    plt.show()
# Loss and accuracy were sampled at the same iterations, so one x-axis fits both.
train_iters = range(len(train_accs))
draw_train_process('training', train_iters, train_loss, train_accs,
                   'training loss', 'training acc')
# Visualize one test batch with its ground-truth labels (safe to delete).
from torchvision import datasets, transforms, utils

dataiter = iter(test_loader)
# Fix: use next(iterator) — the .next() method does not exist in Python 3
# and was removed from DataLoader iterators in modern PyTorch.
images, labels, one_hot = next(dataiter)
# Arrange the batch into a single image grid, then undo the normalization
# (x_denorm = x * std + mean; the original used a bare 0.5 instead of mean).
test_img = utils.make_grid(images)
test_img = test_img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
test_img = test_img * std + mean
# cv2 loaded the images as BGR; flip to RGB for matplotlib and clip rounding noise.
test_img = np.clip(test_img[..., ::-1], 0.0, 1.0)
plt.imshow(test_img)
plt.show()
# One label per image in the batch (batch_size with drop_last=True, so this
# prints the same 20 values as the previous hard-coded range(20)).
print('GroundTruth: ', ' '.join('%d' % labels[j] for j in range(len(labels))))
GroundTruth: 0 1 0 1 1 1 1 1 1 0 1 0 0 1 0 0 0 0 0 0