Pytorch的猫狗识别代码——基于上一个手写汉字识别的代码

前言

这篇关于猫狗识别的代码是基于上一篇的手写汉字识别代码改写的。
之前有尝试了PyTorch的入门教程实战这篇博文里提到的代码,就是陈云老师那个代码,然后学习了一遍,遇到的一些错误都整理在了Pytorch猫狗大战遇见的问题总结这篇博客里。但是这个代码的结果很奇怪,主要还残留了下面两个问题

  1. 就是loss没有在visdom中显示出来,最终验证集验证的模型精度也没有显示,所以也不知道模型的好坏
  2. 最终输出的result.csv文件中,狗的概率就没有大于0.5的,也就是预测结果都是猫,这就很迷了。

实际上作为小白的我,陈云老师的代码让我有点捉摸不透,所以就还是用我熟悉的手写汉字识别的代码框架吧,然后再结合陈云老师的代码的一些内容,来个移花接木吧!😀

关于数据

数据集的分布是这样的:
在这里插入图片描述
我是直接将数据集test1文件夹和训练集train文件夹放在和代码文件code.py同级的目录下。
在这里插入图片描述
这是训练集的数据,一共有25000张图片。
在这里插入图片描述
这是测试集的数据,一共有12500张图片。

1. 库的导入

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import argparse  # 提取命令行参数
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
import torchvision.transforms as transforms
from torchnet import meter
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
import torchvision
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")

有些库貌似并没有用上,但是我也懒得删了,反正也不会有问题的。

2. 命令行参数设置

parse = argparse.ArgumentParser(description='Params for training. ')
# 数据集根目录
parse.add_argument('--root', type=str, default='B:/DogVsCat_1', help='path to data set')
# 模式,5选1,当default为train时,就会运行下面的trian函数,就只进行模型训练
# train为训练模型,test为预测test1文件样本的结果,validation为验证模型的预测精度,inference为预测少量样本,inference_one为预测单个样本
parse.add_argument('--mode', type=str, default='train', choices=['train', 'validation', 'test', 'inference', 'inference_one'])
# checkpoint 路径,用来保存模型训练的状态
parse.add_argument('--log_path', type=str, default=os.path.abspath('.') + '/log.pth', help='dir of checkpoints')
# 第一次训练时,需要将restore设为False,等到代码运行一段时间生成log.pth文件后,停止训练,将其改为True,再继续运行即可
parse.add_argument('--restore', type=bool, default=True, help='whether to restore checkpoints')

parse.add_argument('--batch_size', type=int, default=10, help='size of mini-batch')
parse.add_argument('--image_size', type=int, default=64, help='resize image')
# 迭代次数
parse.add_argument('--epoch', type=int, default=5)
# test数据预测结果输出文件名
parse.add_argument('--file_name', type=str, default='result.csv')
args = parse.parse_args()  # 解析参数

这里同样是用了argparse模块,具体不清楚的可以取看基于Pytorch的手写汉字识别中的3.2部分,这里就不再赘述。
同样需要强调的是,在第一次运行的时候,需要将restore设为False,等到代码运行一段时间生成log.pth文件后,停止训练,将其改为True,再继续运行即可

3. 自定义数据集

# 自定义数据集
class MyDataset(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
        """
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)]  # 目录下每张图片的路径列表

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg
        # 将imgs列表按照图片名的序号排序
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('\\')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

        imgs_num = len(imgs)  # 获取数据集大小

        # shuffle imgs
        np.random.seed(100)
        imgs = np.random.permutation(imgs)  # 随机排列序列

        # 划分训练集和验证集,比例为7:3
        if self.test:
            self.imgs = imgs  # 测试集就不做处理
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]  # 70%的训练集做训练集
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]  # 30%的训练集做验证集
        # 数据转换操作
        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值