基于pytorch实现Resnet对本地数据集的训练

  本文需要具备python编辑器和pytorch深度学习框架的语句基础知识

目录

文章目录

前言

一、dataset.py文件

二、network.py文件

三、train.py

结果与总结


前言

       本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py、network.py、dataset.py以及train.py文件,功能是对本地的数据集进行分类。本文介绍逻辑是总分形式,即首先对总流程进行一个概括,然后分别介绍每个流程中的实现过程(代码+流程图+文字的介绍)。

       对于整个项目的流程首先是加载本地数据集,然后导入Resnet网络,最后进行网络训练。整体来说一个完整的小项目,难度并不高,需要有一定的pytorch语句以及深度学习的基础。

       mian.py文件是该项目的总文件,也是训练网络模型的运行文件,文本的介绍流程是随着该文件一 一对代码进行介绍。

  main.py代码如下所示:

from dataset import data_dataloader    #电脑本地写的读取数据的函数
from torch import nn                   #导入pytorch的nn模块
from torch import optim                #导入pytorch的optim模块
from network import Res_net            #电脑本地写的网络框架的函数
from train import train                #电脑本地写的训练函数

def main():
    # 以下是通过Data_dataloader函数输入为:数据的路径,数据模式,数据大小,batch的大小,有几线并用 (把dataset和Dataloader功能合在了一起)
    train_loader = data_dataloader(data_path='./data', mode='train', size=64, batch_size=24, num_workers=4)
    val_loader = data_dataloader(data_path='./data', mode='val', size=64, batch_size=24, num_workers=2)
    test_loader = data_dataloader(data_path='./data', mode='test', size=64, batch_size=24, num_workers=2)

    # 以下是超参数的定义
    lr = 1e-4           #学习率
    epochs = 10         #训练轮次

    model = Res_net(2)  # resnet网络
    optimizer = optim.Adam(model.parameters(), lr=lr)  # 优化器
    loss_function = nn.CrossEntropyLoss()  # 损失函数

    # 训练以及验证测试函数
    train(model=model, optimizer=optimizer, loss_function=loss_function, train_data=train_loader, val_data=val_loader,test_data= test_loader, epochs=epochs)

if __name__ == '__main__':
    main()

 main.py流程图如图1所示:

图 1 main.py 代码流程图

一、dataset.py文件

       main.py()前五行分别是导入相应的模块,其中dataset,network以及train是本地编写的文件。在mian()函数中的前几行代码中,我们使用dataset.py文件中的Data_dataloader函数导入训练集、验证集和测试集。Dataset文件是导入我们自己的本地数据库,其功能是得到所有的数据,将其变成pytorch能够识别的tensor数据,然后得到图片。

       dataset.py文件代码如下所示:

import torch
import os,glob
import random
import csv
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader

# 第一部分:通过三个步骤得到输出的tensor类型的数据
class Dataset_self(Dataset):                    #如果是nn.moduel 则是编写网络模型框架,这里需要继承的是dataset的数据,所以括号中的是Dataset
    #第一步:初始化
    def __init__(self,root,mode,resize,):       #root是文件根目录,mode是选择什么样的数据集,resize是图像重新调整大小
        super(Dataset_self, self).__init__()
        self.resize = resize
        self.root = root
        self.name_label = {}       #创建一个字典来保存每个文件的标签
        #首先得到标签相对于的字典(标签和名称一一对应)
        for name in sorted(os.listdir(os.path.join(root))):     #排序并且用列表的形式打开文件夹
            if not os.path.isdir(os.path.join(root,name)):      #不是文件夹就不需要读取
                continue
            self.name_label[name] = len(self.name_label.keys())  #每个文件的名字为name_Label字典中有多少对键值对的个数
        #print(self.name_label)
        self.image,self.label = self.make_csv('images.csv')       #编写一共函数来读取图片和标签的路径
        #在得到image和label的基础上对图片数据进行一共划分  (注意:如果需要交叉验证就不需要验证集,只划分为训练集和测试集)
        if mode == 'train':
            self.image ,self.label= self.image[:int(0.6*len(self.image))],self.label[:int(0.6*len(self.label))]
        if mode == 'val':
            self.image ,self.label= self.image[int(0.6*len(self.image)):int(0.8*len(self.image))],self.label[int(0.6*len(self.label)):int(0.8*len(self.label))]
        if mode == 'test':
            self.image ,self.label= self.image[int(0.8*len(self.image)):],self.label[int(0.8*len(self.label)):]
    # 获得图片和标签的函数
    def make_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):  #如果不存在汇总的目录就新建一个
            images = []
            for image in self.name_label.keys():                            # 让image到name_label中的每个文件中去读取图片
                images += glob.glob(os.path.join(self.root,image,'*jpg'))   #加* 贪婪搜索关于jpg的所有文件
            #print('长度为:{},第二张图片为:{}'.format(len(images),images[1]))
            random.shuffle(images)                                         #把images列表中的数据洗牌
            # images[0]: ./data\ants\382971067_0bfd33afe0.jpg
            with open(os.path.join(self.root,filename),mode='w',newline='') as f :  #创建文件
                writer = csv.writer(f)
                for image in im
  • 11
    点赞
  • 88
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值