猫狗识别基于Pytroch

猫狗识别基于Pytroch

import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as transform
import matplotlib.pyplot as plt
 
file_path = "CAT/img"
 
trans = transform.Compose([
    transform.ToTensor(),  # 归一化并将HWC转换为CHW
    transform.Normalize((0.5,), (0.5,))  # 做均值为0.5, 标准差为0.5的标准化
])
 
 
class CatDogDataset(Dataset):
    """整理数据集"""
 
    def __init__(self, file_path, is_training=True):
        super(CatDogDataset, self).__init__()
        # 定义数据列表,装载图片路径和标签的元组
        self.data = []
        for path in os.listdir(file_path):
            full_path = os.path.join(file_path, path)
            label, _, _ = path.split(".") # 取出标签
            self.data.append((full_path, label))
        # 切分训练集和测试集
        if is_training:
            self.data = [self.data[i] for i in range(len(self.data)) if i < 5000 or i >= 7000]
        else:
            self.data = [self.data[i] for i in range(len(self.data)) if i >= 5000 and i < 7000]
 
    def __len__(self):
        return len(self.data)
 
    def __getitem__(self, item):
        full_path, label = self.data[item]
        # 读出图片数据并归一化
        img = cv2.imread(full_path)
        img_tensor = trans(img)  # HWC转CHW并归一化、标准化
 
        # label one hot 编码
        one_hot = np.zeros(2)
        one_hot[int(label)] = 1
        label = int(label)
 
        # 将需要的数据转换为tensor
        label_tensor = torch.tensor(label, dtype=torch.float32)
        one_hot_tensor = torch.tensor(one_hot, dtype=torch.float32)
        #print(one_hot_tensor)
 
        return img_tensor, label_tensor, one_hot_tensor
train_data = CatDogDataset(file_path, True)
test_data = CatDogDataset(file_path, False)
#打印训练集与测试集内图片数量(可以删掉不用)
print(len(train_data))
print(len(test_data))
结果:
10000
2000
# trainloader其实是一个比较重要的东西,我们后面就是通过trainloader把数据传入网
    # 络,当然这里的trainloader其实是个变量名,可以随便取,重点是他是由后面的
    # torch.utils.data.DataLoader()定义的,这个东西来源于torch.utils.data模块,
    #  网页链接http://pytorch.org/docs/0.3.0/data.html
train_loader = torch.utils.data.DataLoader(train_data,batch_size=20,
                                          shuffle=True,drop_last=True)
#和上面一样
test_loader = torch.utils.data.DataLoader(test_data,batch_size=20,
                                          shuffle=True,drop_last=True)
#数据可视化(可以删掉)
data_iter = iter(train_loader)
print(next(data_iter))
结果:
第一个·tensor:图片
第二个tensr:label
第三个tensor:onehot编码
[tensor([[[[ 0.5373,  0.5294,  0.5059,  ..., -0.9765, -0.9922, -1.0000],
          [ 0.5843,  0.5529,  0.5294,  ..., -0.9059, -0.9137, -0.9294],
          [ 0.6078,  0.5765,  0.5529,  ..., -0.7098, -0.7176, -0.7255],
          ...,
          [-0.3412, -0.3255, -0.3098,  ..., -0.9765, -0.9765, -0.9686],
          [-0.3412, -0.3333, -0.3098,  ..., -0.9765, -0.9765, -0.9686],
          [-0.3490, -0.3333, -0.3176,  ..., -0.9765, -0.9765, -0.9686]],

         [[ 0.2392,  0.2314,  0.2157,  ..., -0.9294, -0.9451, -0.9529],
          [ 0.2627,  0.2549,  0.2314,  ..., -0.8588, -0.8667, -0.8824],
          [ 0.2784,  0.2706,  0.2392,  ..., -0.6706, -0.6784, -0.6863],
          ...,
          [-0.4431, -0.4275, -0.4118,  ..., -0.9608, -0.9608, -0.9529],
          [-0.4431, -0.4353, -0.4118,  ..., -0.9608, -0.9608, -0.9529],
          [-0.4510, -0.4353, -0.4196,  ..., -0.9608, -0.9608, -0.9529]],

         [[ 0.2078,  0.2000,  0.1843,  ..., -0.8431, -0.8588, -0.8667],
          [ 0.2392,  0.2157,  0.2000,  ..., -0.7725, -0.7804, -0.7961],
          [ 0.2392,  0.2078,  0.2000,  ..., -0.6000, -0.6078, -0.6157],
          ...,
          [-0.3804, -0.3647, -0.3490,  ..., -0.9608, -0.9608, -0.9529],
          [-0.3804, -0.3725, -0.3490,  ..., -0.9608, -0.9608, -0.9529],
          [-0.3882, -0.3725, -0.3569,  ..., -0.9608, -0.9608, -0.9529]]],


        [[[-0.7882, -0.7333, -0.6784,  ..., -0.5843, -0.5608, -0.5451],
          [-0.8118, -0.7569, -0.7255,  ..., -0.6078, -0.5765, -0.5608],
          [-0.8353, -0.7647, -0.7333,  ..., -0.6235, -0.5922, -0.5765],
          ...,
          [-0.1137, -0.1294, -0.1529,  ..., -0.3176, -0.3255, -0.3569],
          [-0.1216, -0.1137, -0.1216,  ..., -0.2941, -0.3098, -0.3412],
          [-0.0902, -0.0902, -0.1059,  ..., -0.2784, -0.3020, -0.3569]],

         [[-0.5608, -0.5137, -0.4902,  ..., -0.5686, -0.5451, -0.5294],
          [-0.5843, -0.5373, -0.5373,  ..., -0.5922, -0.5608, -0.5451],
          [-0.6078, -0.5451, -0.5451,  ..., -0.6078, -0.5765, -0.5608],
          ...,
          [ 0.0902,  0.0588,  0.0275,  ..., -0.7647, -0.7725, -0.8039],
          [ 0.0824,  0.0745,  0.0588,  ..., -0.7412, -0.7569, -0.7882],
          [ 0.1137,  0.0980,  0.0745,  ..., -0.7255, -0.7490, -0.8039]],

         [[-0.2706, -0.2392, -0.2549,  ..., -0.6157, -0.5922, -0.5765],
          [-0.2941, -0.2627, -0.3020,  ..., -0.6392, -0.6078, -0.5922],
          [-0.3176, -0.2706, -0.3098,  ..., -0.6549, -0.6235, -0.6078],
          ...,
          [ 0.4510,  0.4196,  0.3804,  ..., -0.9294, -0.9373, -0.9686],
          [ 0.4510,  0.4353,  0.4118,  ..., -0.9059, -0.9216, -0.9529],
          [ 0.4824,  0.4588,  0.4275,  ..., -0.8902, -0.9137, -0.9686]]],


        [[[-0.5765, -0.4902, -0.3882,  ..., -0.7176, -0.8431, -0.9059],
          [-0.4824, -0.4980, -0.4980,  ..., -0.8118, -0.9216, -0.9373],
          [-0.5529, -0.4980, -0.4745,  ..., -0.8902, -0.9922, -0.9294],
          ...,
          [ 0.3647,  0.2941,  0.3255,  ...,  0.9686,  0.9686,  0.9686],
          [ 0.3569,  0.3255,  0.3569,  ...,  0.9686,  0.9686,  0.9686],
          [ 0.2549,  0.3098,  0.4039,  ...,  0.9686,  0.9686,  0.9686]],

         [[-0.6235, -0.5373, -0.4353,  ..., -0.7412, -0.8667, -0.9294],
          [-0.5529, -0.5686, -0.5686,  ..., -0.8118, -0.9216, -0.9373],
          [-0.6392, -0.5843, -0.5608,  ..., -0.8431, -0.9451, -0.8824],
          ...,
          [ 0.3961,  0.3333,  0.3333,  ...,  0.9686,  0.9686,  0.9686],
          [ 0.3882,  0.3647,  0.3647,  ...,  0.9686,  0.9686,  0.9686],
          [ 0.2863,  0.3490,  0.4118,  ...,  0.9686,  0.9686,  0.9686]],

         [[-0.7098, -0.6235, -0.5216,  ..., -0.6235, -0.7490, -0.8118],
          [-0.6314, -0.6471, -0.6471,  ..., -0.7020, -0.8118, -0.8275],
          [-0.7020, -0.6471, -0.6235,  ..., -0.7569, -0.8588, -0.7961],
          ...,
          [ 0.3490,  0.2627,  0.2627,  ...,  0.9686,  0.9686,  0.9686],
          [ 0.3412,  0.2941,  0.2941,  ...,  0.9686,  0.9686,  0.9686],
          [ 0.2392,  0.2784,  0.3412,  ...,  0.9686,  0.9686,  0.9686]]],


        ...,


        [[[ 0.4431,  0.4039,  0.3020,  ...,  0.3412,  0.3020,  0.2784],
          [ 0.4118,  0.4353,  0.3569,  ...,  0.4510,  0.3882,  0.3412],
          [ 0.4510,  0.5137,  0.4510,  ...,  0.6392,  0.5765,  0.5294],
          ...,
          [-0.2392, -0.1765, -0.0667,  ..., -0.0510, -0.1451,  0.0510],
          [-0.3176, -0.2471, -0.1373,  ..., -0.0353, -0.1922, -0.0588],
          [-0.3725, -0.3020, -0.1922,  ..., -0.1216, -0.3412, -0.2784]],

         [[ 0.0039, -0.0275, -0.1059,  ...,  0.0588,  0.0196, -0.0039],
          [-0.0196,  0.0039, -0.0510,  ...,  0.1765,  0.1059,  0.0667],
          [ 0.0431,  0.1137,  0.0667,  ...,  0.3882,  0.3176,  0.2784],
          ...,
          [-0.5216, -0.4588, -0.3490,  ..., -0.2157, -0.3098, -0.1137],
          [-0.6000, -0.5216, -0.4118,  ..., -0.2000, -0.3569, -0.2235],
          [-0.6471, -0.5686, -0.4667,  ..., -0.2863, -0.5059, -0.4431]],

         [[-0.2706, -0.3333, -0.4667,  ..., -0.1765, -0.2157, -0.2392],
          [-0.3098, -0.3020, -0.4196,  ..., -0.0824, -0.1294, -0.1922],
          [-0.2706, -0.2235, -0.3255,  ...,  0.1137,  0.0588,  0.0039],
          ...,
          [-0.5686, -0.5216, -0.4275,  ..., -0.3882, -0.4824, -0.2863],
          [-0.6627, -0.6000, -0.5137,  ..., -0.3725, -0.5294, -0.3961],
          [-0.7255, -0.6706, -0.5686,  ..., -0.4588, -0.6784, -0.6157]]],


        [[[-0.1686, -0.2235, -0.2784,  ..., -0.4588, -0.5686, -0.4039],
          [-0.1765, -0.3333, -0.3020,  ..., -0.4196, -0.4667, -0.3882],
          [-0.2078, -0.3961, -0.3098,  ..., -0.4039, -0.3804, -0.4039],
          ...,
          [-0.4667, -0.5216, -0.4353,  ..., -0.3020, -0.3647, -0.4118],
          [-0.4980, -0.4980, -0.4353,  ..., -0.4353, -0.4431, -0.5059],
          [-0.5529, -0.4667, -0.4353,  ..., -0.3098, -0.2784, -0.5059]],

         [[ 0.1451,  0.0902,  0.0039,  ...,  0.0431, -0.0510,  0.1137],
          [ 0.1294, -0.0275, -0.0118,  ...,  0.0588,  0.0118,  0.1137],
          [ 0.1059, -0.0902, -0.0196,  ...,  0.0431,  0.0667,  0.0431],
          ...,
          [-0.3020, -0.3804, -0.3255,  ...,  0.0667,  0.0196, -0.0275],
          [-0.3333, -0.3569, -0.3255,  ..., -0.0667, -0.0588, -0.1059],
          [-0.3882, -0.3255, -0.3255,  ...,  0.0745,  0.1059, -0.1059]],

         [[-0.1843, -0.2314, -0.2784,  ..., -0.4667, -0.5608, -0.3961],
          [-0.2314, -0.3804, -0.3255,  ..., -0.4431, -0.4902, -0.3961],
          [-0.3255, -0.4980, -0.3804,  ..., -0.4353, -0.4118, -0.4353],
          ...,
          [-0.1294, -0.2000, -0.1529,  ..., -0.3176, -0.3725, -0.4196],
          [-0.1608, -0.1765, -0.1529,  ..., -0.4510, -0.4510, -0.4980],
          [-0.2157, -0.1451, -0.1529,  ..., -0.3176, -0.2863, -0.4980]]],


        [[[ 0.9059,  0.7647,  0.8824,  ...,  0.0510,  0.0118,  0.0353],
          [ 0.8510,  0.9137,  0.8118,  ...,  0.0275, -0.0275, -0.0431],
          [ 0.7569,  0.8824,  0.7490,  ..., -0.0039, -0.0510, -0.0902],
          ...,
          [-0.0039, -0.1137, -0.0902,  ...,  0.2549,  0.1843,  0.1843],
          [-0.2000, -0.2392, -0.1765,  ...,  0.3490,  0.2627,  0.2078],
          [-0.1922, -0.1843, -0.1373,  ...,  0.3098,  0.2235,  0.1216]],

         [[ 0.8275,  0.7020,  0.8118,  ...,  0.1216,  0.0824,  0.1059],
          [ 0.7725,  0.8510,  0.7412,  ...,  0.0980,  0.0431,  0.0275],
          [ 0.6784,  0.8196,  0.6784,  ...,  0.0667,  0.0196, -0.0196],
          ...,
          [ 0.0667, -0.0431, -0.0196,  ...,  0.2863,  0.2235,  0.2235],
          [-0.1294, -0.1686, -0.1059,  ...,  0.3804,  0.3020,  0.2627],
          [-0.1216, -0.1137, -0.0667,  ...,  0.3569,  0.2627,  0.1765]],

         [[ 0.7725,  0.6471,  0.7804,  ...,  0.1922,  0.1529,  0.1765],
          [ 0.7176,  0.7961,  0.7098,  ...,  0.1686,  0.1137,  0.0980],
          [ 0.6235,  0.7647,  0.6471,  ...,  0.1373,  0.0902,  0.0510],
          ...,
          [ 0.1765,  0.0667,  0.0902,  ...,  0.3725,  0.3412,  0.3412],
          [-0.0196, -0.0588,  0.0039,  ...,  0.4667,  0.4196,  0.3804],
          [-0.0118, -0.0039,  0.0431,  ...,  0.4431,  0.3804,  0.2941]]]]), tensor([0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
        1., 0.]), tensor([[1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.]])]
#(可以删掉)

oneimg,label,one_hot = train_data[0]
print(len(oneimg))
print(len(oneimg[0][0]))#显示一张图片的大小
oneimg = oneimg.numpy().transpose(1,2,0) 
std = [0.5]
mean = [0.5]
oneimg = oneimg * std + mean
plt.imshow(oneimg)
plt.show()
结果:
3(图片通道数)
100(图片大小为100*100)

在这里插入图片描述

#.定义一个CNN网络
import torch.nn.functional as F
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Conv2d(3,6,kernel_size=23,stride=1,padding=1)#卷积层1
        self.pool = nn.MaxPool2d(2,2)#池化层
        self.conv2 = nn.Conv2d(6,16,kernel_size=23,stride=1,padding=1)#卷积层2
        self.fc1 = nn.Linear(16*10*10,1024)#两个池化,所以是7*7而不是14*14
        self.fc2 = nn.Linear(1024,512)
        self.fc3 = nn.Linear(512,2)
#         self.dp = nn.Dropout(p=0.5)
    def forward(self,x):# 这里定义前向传播的方法
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        x = x.view(-1, 16 * 10* 10)#将数据平整为一维的 
        x = F.relu(self.fc1(x))#激活函数
#         x = self.fc3(x)
#         self.dp(x)
        x = F.relu(self.fc2(x))   
        x = self.fc3(x)  
#         x = F.log_softmax(x,dim=1) NLLLoss()才需要,交叉熵不需要
        return x
# .view( )是一个tensor的方法,使得tensor改变size但是元素的总数是不变的。
#  第一个参数-1是说这个参数由另一个参数确定, 比如矩阵在元素总数一定的情况下,确定列数就能确定行数。
#  那么为什么这里只关心列数不关心行数呢,因为马上就要进入全连接层了,而全连接层说白了就是矩阵乘法,
#  你会发现第一个全连接层的首参数是16*5*5,所以要保证能够相乘,在矩阵乘法之前就要把x调到正确的size
net = CNN()
import torch.optim as optim

criterion = nn.CrossEntropyLoss()#同样是用到了神经网络工具箱 nn 中的交叉熵损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#也可以选择Adam优化方法
# optimizer = torch.optim.Adam(net.parameters(),lr=1e-2)
train_accs = []
train_loss = []
test_accs = []
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)
for epoch in range(10):
    running_loss = 0.0 #定义一个变量方便我们对loss进行输出
    for i,data in enumerate(train_loader,0):#0是下标起始位置默认为0
        # 这里我们遇到了第一步中出现的trailoader,代码传入数据
         # enumerate是python的内置函数,既获得索引也获得数据,
        # data 的格式[[inputs, labels]]       
#         inputs,labels = data
        inputs,labels = data[0].to(device), data[1].to(device)
        #初始为0,清除上个batch的梯度信息
        optimizer.zero_grad()         

        #前向+后向+优化     
        outputs = net(inputs)
        loss = criterion(outputs,labels.long())
        loss.backward()# loss进行反向传播,
        optimizer.step() # 当执行反向传播之后,把优化器的参数进行更新,以便进行下一轮

        # loss 的输出,每个一百个batch输出,平均的loss
        running_loss += loss.item()
        if i%100 == 99:
            print('[%d,%5d] loss :%.3f' %
                 (epoch+1,i+1,running_loss/100))
            running_loss = 0.0
        train_loss.append(loss.item())

        # 训练曲线的绘制 一个batch中的准确率
        correct = 0
        total = 0
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)# labels 的长度
        correct = (predicted == labels).sum().item() # 预测正确的数目
        train_accs.append(100*correct/total)

print('Finished Training')
结果:
[1,  100] loss :0.693
[1,  200] loss :0.693
[1,  300] loss :0.691
[1,  400] loss :0.688
[1,  500] loss :0.682
[2,  100] loss :0.684
[2,  200] loss :0.677
[2,  300] loss :0.678
[2,  400] loss :0.657
[2,  500] loss :0.661
[3,  100] loss :0.641
[3,  200] loss :0.647
[3,  300] loss :0.638
[3,  400] loss :0.639
[3,  500] loss :0.633
[4,  100] loss :0.610
[4,  200] loss :0.620
[4,  300] loss :0.614
[4,  400] loss :0.615
[4,  500] loss :0.601
[5,  100] loss :0.601
[5,  200] loss :0.606
[5,  300] loss :0.602
[5,  400] loss :0.598
[5,  500] loss :0.593
[6,  100] loss :0.554
[6,  200] loss :0.580
[6,  300] loss :0.573
[6,  400] loss :0.591
[6,  500] loss :0.575
[7,  100] loss :0.553
[7,  200] loss :0.564
[7,  300] loss :0.557
[7,  400] loss :0.572
[7,  500] loss :0.544
[8,  100] loss :0.525
[8,  200] loss :0.531
[8,  300] loss :0.547
[8,  400] loss :0.555
[8,  500] loss :0.536
[9,  100] loss :0.520
[9,  200] loss :0.515
[9,  300] loss :0.523
[9,  400] loss :0.509
[9,  500] loss :0.533
[10,  100] loss :0.479
[10,  200] loss :0.492
[10,  300] loss :0.495
[10,  400] loss :0.486
[10,  500] loss :0.518
Finished Training
#计算准确准确率(可以删掉)
def draw_train_process(title,iters,costs,accs,label_cost,lable_acc):
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=20)
    plt.ylabel("acc(\%)", fontsize=20)
    plt.plot(iters, costs,color='red',label=label_cost) 
    plt.plot(iters, accs,color='green',label=lable_acc) 
    plt.legend()
    plt.grid()
    plt.show()
train_iters = range(len(train_accs))
draw_train_process('training',train_iters,train_loss,train_accs,'training loss','training acc')

在这里插入图片描述

#(可以删掉)
from torchvision import datasets, transforms,utils
dataiter = iter(test_loader)
images, labels,one_hot = dataiter.next()

# print images
test_img = utils.make_grid(images)
test_img = test_img.numpy().transpose(1,2,0)
std = [0.5,0.5,0.5]
mean =  [0.5,0.5,0.5]
test_img = test_img*std+0.5
plt.imshow(test_img)
plt.show()
print('GroundTruth: ', ' '.join('%d' % labels[j] for j in range(20)))

在这里插入图片描述
GroundTruth: 0 1 0 1 1 1 1 1 1 0 1 0 0 1 0 0 0 0 0 0

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值