Python创建用于分类的数据集(数据:图像 + 标签:文本)

转载自 https://blog.csdn.net/Teeyohuang/article/details/79587125

之前讲的例子,程序都是调用的datasets方法,下载的torchvision本身就提供的数据,那么如果想导入自己的数据应该怎么办呢?

本篇就讲解一下如何创建自己的数据集。

还有第二篇……Pytorch打怪路(三)Pytorch创建自己的数据集2

1.用于分类的数据集

以mnist数据集为例

这里的mnist数据集并不是torchvision里面的,而是我自己的以图片格式保存的数据集,因为我在测试STN时,希望自己再把这些手写体做一些形变,

所以就先把MNIST数据集转化成了jpg图片格式,然后做了一些形变,当然这不是重点。首先我们看一下我的数据集的情况:

如图所示,我的图片数据集确实是jpg图片

 

再看我的存储图片名和label信息的文本:

 

 

如图所示,我的mnist.txt文本每一行分为两部分,第一部分是具体路径+图片名.jpg

第二部分就是label信息,因为前面这部分图片都是0 ,所以他们的分类的label信息就是0

要创建你自己的 用于分类的 数据集,也要包含上述两个部分,1.图片数据集,2.文本信息(这个txt文件可以用python或者C++轻易创建,再此不详述)

2.代码

 

主要代码


 
 
  1. from PIL import Image
  2. import torch
  3. class MyDataset(torch.utils.data.Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
  4. def __init__(self,root, datatxt, transform=None, target_transform=None): #初始化一些需要传入的参数
  5. fh = open(root + datatxt, 'r') #按照传入的路径和txt文本参数,打开这个文本,并读取内容
  6. imgs = []                       #创建一个名为img的空列表,一会儿用来装东西
  7. for line in fh:                 #按行循环txt文本中的内容
  8. line = line.rstrip()     # 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python
  9. words = line.split() #通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
  10. imgs.append((words[ 0],int(words[ 1]))) #把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定
  11.                                         # 很显然,根据我刚才截图所示txt的内容,words[0]是图片信息,words[1]是lable
  12.         self.imgs = imgs
  13. self.transform = transform
  14. self.target_transform = target_transform
  15. def __getitem__(self, index):     #这个方法是必须要有的,用于按照索引读取每个元素的具体内容
  16. fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
  17. img = Image.open(root+fn).convert( 'RGB') #按照path读入图片from PIL import Image # 按照路径读取图片
  18. if self.transform is not None:
  19. img = self.transform(img) #是否进行transform
  20. return img,label   #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
  21. def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
  22. return len(self.imgs)
  23. #根据自己定义的那个勒MyDataset来创建数据集!注意是数据集!而不是loader迭代器
  24. train_data=MyDataset(txt=root+ 'train.txt', transform=transforms.ToTensor())
  25. test_data=MyDataset(txt=root+ 'test.txt', transform=transforms.ToTensor())

 
 
  1. #然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
  2. train_loader = DataLoader(dataset=train_data, batch_size= 64, shuffle= True)
  3. test_loader = DataLoader(dataset=test_data, batch_size= 64)

 

再补充一点代码,以便更好的理解 __getitem__这个方法

 

 


 
 
  1. for batch_index, data, target in test_loader:
  2. if use_cuda:
  3. data, target = data.cuda(), target.cuda()
  4. data, target = Variable(data, volatile= True), Variable(target)

这段代码是我从测试的部分中截取出来的,为什么直接能用for data, target In test_loader这样的语句呢?

其实这个语句还可以这么写:

for batch_index, batch in train_loader

        data, target = batch

这样就好理解了,因为这个迭代器每一次循环所得的batch里面装的东西,就是我在__getitem__方法最后return回来的

所以你想在训练或者测试的时候还得到其他信息的话,就去增加一些返回值即可,只要是能return出来的,就能在每个batch中读取到!

###############################################################################

有朋友可能想问,如果我的label信息不是数字而是图像呢?比如分割任务,它的label就是图像,这样的数据集的建立,也参考我的下一篇博文:

Pytorch打怪路(三)Pytorch创建自己的数据集2

  • 5
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值