Pytorch实现加载自己的数据集

我们经常会遇到这样的问题,就是如何使用自己的数据集,把标签和图片对应起来,然后转化成一个一个批次送进网络。其实在pytorch中已经为我们封装好了各种库,只需要我们添加相应的处理就好。

import torch.utils.data   #子类化数据
import torch
from tochvision import transforms   #数据处理

定义自己的dataset类:

class MyTrainData(torch.utils.data.Dataset)

该类包括初始化参数,index获取图片和标签,以及最后返回数据集的长度。

#encoding:utf-8
import torch.utils.data as data
import torch
from torchvision import transforms
 
class MyTrainData(torch.utils.data.Dataset) #子类化
  def __init__(self, root, transform=None, train=True): #第一步初始化各个变量
 
    self.root = root   
    self.train = train
 
  def __getitem__(self, idx): #第二步装载数据,返回[img,label],idx就是一张一张地读取
      # get item  获取  数据 
 
      img = imread(img_path) #img_path根据自己的数据自定义,灵活性很高
      img = torch.from_numpy(img).float() #需要转成float
 
      gt = imread(gt_path)  #读取gt,如果是分类问题,可以根据文件夹或命名赋值 0 1  
      gt = torch.from_numpy(gt).float()
 
      return img, gt #返回  一一对应
 
  def __len__(self):
    return len(self.imagenumber) #这个是必须返回的长度

transform类函数定义了各种数据的操作,比如旋转,剪切,缩放等等,我们可以根据自己的需要添加需要的函数。
1、完整代码

#encoding:utf-8
import torch.utils.data as data
import torch
 
from scipy.ndimage import imread
import os
import os.path
import glob
 
from torchvision import transforms
 
def make_dataset(root, train=True): #读取自己的数据的函数
 
  dataset = []
 
  if train:
    dirgt = os.path.join(root, 'train_data/groundtruth') 
    dirimg = os.path.join(root, 'train_data/imgs')
 
    for fGT in glob.glob(os.path.join(dirgt, '*.jpg')):
    # for k in range(45)
      fName = os.path.basename(fGT)    
      fImg = 'train_ori'+fName[8:]
      dataset.append( [os.path.join(dirimg, fImg), os.path.join(dirgt, fName)] )
 
  return dataset
 
#自定义dataset的框架
class MyTrainData(data.Dataset):   #需要繼承data.Dataset
 
  def __init__(self, root, transform=None, train=True): #初始化文件路進或文件名
    self.train = train
    if self.train:
      self.train_set_path = make_dataset(root, train)
 
  def __getitem__(self, idx):
    if self.train:
      img_path, gt_path = self.train_set_path[idx]
 
      img = imread(img_path)
      img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32)
      img = (img - img.min()) / (img.max() - img.min())
      img = torch.from_numpy(img).float()
 
      gt = imread(gt_path)
      gt = np.atleast_3d(gt).transpose(2, 0, 1)
      gt = gt / 255.0
      gt = torch.from_numpy(gt).float()
 
      return img, gt  
 
  def __len__(self):
 
    return len(self.train_set_path)

2、另一种方式

import torch.nn.functional as F
import torch
import torch 
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torch.optim as optim
import os

#torch.cuda.set_device(gpu_id)#使用GPU
learning_rate = 0.0001

#数据集的设置*****************************************************************************************************************
root =os.getcwd()+ '/data1/'#调用图像

#定义读取文件的格式
def default_loader(path):
    return Image.open(path).convert('RGB')

#首先继承上面的dataset类。然后在__init__()方法中得到图像的路径,然后将图像路径组成一个数组,这样在__getitim__()中就可以直接读取:
class MyDataset(Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
	def __init__(self,txt, transform=None,target_transform=None, loader=default_loader): #初始化一些需要传入的参数
		super(MyDataset,self).__init__()#对继承自父类的属性进行初始化
		fh = open(txt, 'r')#按照传入的路径和txt文本参数,打开这个文本,并读取内容
		imgs = []
		for line in fh: #迭代该列表#按行循环txt文本中的内
			line = line.strip('\n')
			line = line.rstrip('\n')# 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python
			words = line.split() #用split将该行分割成列表  split的默认参数是空格,所以不传递任何参数时分割空格
			imgs.append((words[0],int(words[1]))) #把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定 
                                                 # 很显然,根据我刚才截图所示txt的内容,words[0]是图片信息,words[1]是lable       
		self.imgs = imgs
		self.transform = transform
		self.target_transform = target_transform
		self.loader = loader        
        
	def __getitem__(self, index):#这个方法是必须要有的,用于按照索引读取每个元素的具体内容
		fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
		img = self.loader(fn) # 按照路径读取图片
		if self.transform is not None:
			img = self.transform(img) #数据标签转换为Tensor
		return img,label#return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
	def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
		return len(self.imgs)
 #根据自己定义的那个MyDataset来创建数据集!注意是数据集!而不是loader迭代器
#*********************************************数据集读取完毕********************************************************************
#图像的初始化操作
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((227,227)),
    transforms.ToTensor(),
])
text_transforms = transforms.Compose([
    transforms.RandomResizedCrop((227,227)),
    transforms.ToTensor(),
])

#数据集加载方式设置
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data = MyDataset(txt=root+'text.txt', transform=transforms.ToTensor())
#然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True,num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)
print('num_of_trainData:', len(train_data))
print('num_of_testData:', len(test_data))
  • 4
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值