pytorch学习1:如何加载自己的训练数据

Pytorch中文文档已出(http://pytorch-cn.readthedocs.io/zh/latest/)。第一篇博客献给了pytorch,主要是为了整理自己的思路。

原来使用caffe,总是要编译,经历了无数的坑。当开始接触pytorch时,果断拔草caffe。

学习Pytorch最好有一些深度学习理论基础才更好开,废话不多说,进入主题。

1 先有个框框,再往里面填东西

当训练一个神经网络的时候,我们需要有数据,有模型,并且需要设置训练的参数。为了不乱,我们最好分别定义三个文件,分别是:数据准备和预处理traindataset.py+编写模型model.py+如何训练main.py(xx.py,xx自己可任意取名)。

今天我们只讲数据准备与预处理阶段:traindataset.py(怎样命名无所谓,as you like)。这个文件的作用是什么呢?

统一将图像(或矩阵)返回成torch能处理的[original_iamges.tensor,label.tensor]

我们先跳跃一下看中文介绍是如何导入数据:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

我们一般关注DataLoader四个参数:

dataset, batch_size, shuffle, num_workers=0

batch_size是你批处理数目,shuffle是否每个epoch都打乱,workers是载入数据的线程数(请查看中文文档对每个参数的解释)

我们具体看看“dataset”——加载数据的数据集。

这个dataset是 [original_iamges.tensor,label.tensor] 之类的,我们定义的“traindataset.py”就是产生这个dataset的。然后只需在main.py 文件import就可调用!

from traindataset import *

2 定义一个py文件产生我们自己的dataset

这个py文件一定要

1:能输入自己的数据路径 2:还得预处理吧,比如的裁剪啊~

step 1:先导入你肯定需要的库路径

import torch.utils.data 
import torch
from tochvision import transforms

torch.utils.data模块是子类化你的数据

transforms库对数据预处理

step 2:自定义dataset类(子类化你的数据)

class MyTrainData(torch.utils.data.Dataset)

这里继承了torch.utils.data.Dataset这个类,我们看看这个类在中文文档中介绍:

所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。当然还有个初始化__init__()

类:属性+方法,__init__()就是定义自己的属性

我们脸谱化py文件,再往里面加东西(以下为基础框架):

#encoding:utf-8
import torch.utils.data as data
import torch
from torchvision import transforms

class MyTrainData(torch.utils.data.Dataset) #子类化
  def __init__(self, root, transform=None, train=True): #第一步初始化各个变量

    self.root = root   
    self.train = train

  def __getitem__(self, idx): #第二步装载数据,返回[img,label],idx就是一张一张地读取
      # get item  获取  数据 

      img = imread(img_path) #img_path根据自己的数据自定义,灵活性很高
      img = torch.from_numpy(img).float() #需要转成float

      gt = imread(gt_path)  #读取gt,如果是分类问题,可以根据文件夹或命名赋值 0 1  
      gt = torch.from_numpy(gt).float()

      return img, gt #返回  一一对应

  def __len__(self):
    return len(self.imagenumber) #这个是必须返回的长度


现在往框框里面填

(1)是否transform如裁剪、归一化、旋转等?如果要transform则还需要区分test和train。比如我train需要 随机翻转,但是test则不需要操作 

(2)如何做到一张一张对应读取图片? 可以自定义这些函数

以下贴出完整代码:

#encoding:utf-8
import torch.utils.data as data
import torch

from scipy.ndimage import imread
import os
import os.path
import glob

from torchvision import transforms

def make_dataset(root, train=True): #读取自己的数据的函数

  dataset = []

  if train:
    dirgt = os.path.join(root, 'train_data/groundtruth') 
    dirimg = os.path.join(root, 'train_data/imgs')

    for fGT in glob.glob(os.path.join(dirgt, '*.jpg')):
    # for k in range(45)
      fName = os.path.basename(fGT)    
      fImg = 'train_ori'+fName[8:]
      dataset.append( [os.path.join(dirimg, fImg), os.path.join(dirgt, fName)] )

  return dataset

#自定義dataset的框架
class MyTrainData(data.Dataset):   #需要繼承data.Dataset

  def __init__(self, root, transform=None, train=True): #初始化文件路進或文件名
    self.train = train
    if self.train:
      self.train_set_path = make_dataset(root, train)

  def __getitem__(self, idx):
    if self.train:
      img_path, gt_path = self.train_set_path[idx]

      img = imread(img_path)
      img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32)
      img = (img - img.min()) / (img.max() - img.min())
      img = torch.from_numpy(img).float()

      gt = imread(gt_path)
      gt = np.atleast_3d(gt).transpose(2, 0, 1)
      gt = gt / 255.0
      gt = torch.from_numpy(gt).float()

      return img, gt  

  def __len__(self):

    return len(self.train_set_path)
   

这里的py文件需要在最后main.py文件中调用,所以root我并没有赋值,我会在main,py中赋值。

这里我并没有用到“transform”进行预处理,如果你想用的话,在__getitem__()下面,return img,gt前重新赋值

img = transforms.ToTensor(img)以及gt = transforms.ToTensor(gt)

这需要注意的是,查看中文文档transforms库有哪些变换,如果有需要涉及参数的如CenterCrop(size),需要先实参化,如

crop = transforms.CenterCrop(10);再使用:img = crop(img)

  • 16
    点赞
  • 76
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值