pytorch 自定义数据集载入(标签在csv文件里)

在跑别人的项目的过程中,遇到的第一个大障碍是自定义数据集加载。本文主要讲关于如何让PyTorch能读取自己的数据集,不涉及dataloader机制。

查阅了一些博客还有文章了解到,要让PyTorch能读取自己的数据集,只需要两步:

1. 制作图片数据的索引表

2. 构建Dataset子类

详细参考链接:https://zhuanlan.zhihu.com/p/52807406

然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。

困扰我的不是第二步而是第一步。。

想做一个人脸性别识别的项目,他的数据集是所有图片全部在一个image文件夹中,见图1,没有区分训练集train和测试集test文件夹,信息是存放在了两个csv文件中了,一个是train1.csv,另一个是test1.csv.

train1.csv文件中如下图2,是图片id和标签label,test1.csv文件中只有图片id,如下图3

                                                                                                                                                           图1

                                                                                                                                                                  图2

                                                                                                                                                                 图3

那么接下来就是制作图像索引列表啦,这里与上面知乎的文章里的情况不同,它里面使用的generate_txt函数是针对的已经分类文件夹的数据集,我要用的数据集的标签在csv文件里。

插一句:这里我又到b站去搜了一个视频,里面讲的很详细,也有与我这种类似的情况,是第四种。但是他写的太麻烦啦。。我看了代码之后懵了orz,而且认为有点舍近求远了?。。(参考链接在下面还请指正)

链接:https://www.bilibili.com/video/BV1354y1s7kQ?p=5

不就要整出个txt文件,里面存着图片路径和标签嘛!直接’手动‘写把。。

  • 制作图片数据的索引

读取图片路径,标签,保存到txt文件中,直接上代码。注意算好自己的路径,自己看好图片文件夹,csv文件在哪里

关于路径,这是我的项目文件结构

凑了一个直接从csv文件得到图片路径的函数方法
# data_dir = '../gender/' 项目根目录
# label_file = 'train1.csv'
# test_file = 'test1.csv'
#封装函数的话,就将以上内容作为参数从主文件里调用
def generate_txt(data_dir, label_file, test_file):
    # generate train.txt
    with open(os.path.join(data_dir, label_file), 'r') as f:#将data_dir和label_file路径拼
        lines =f.readlines()[0:]
        print('****************')
        print('input :', os.path.join(data_dir, label_file))
        print('start...')
        listText = open('../gender/train1.txt', 'a+') #创建并打开train1.txt文件,a+表示打开一个文件并追加内容
        for l in lines:
              tokens = l.rstrip().split(',') #这里注意,从csv里直接读进来是有,的,用split        
                                             #将其去掉并分割
              idx, label = tokens
              name = data_dir+'image/image/'+idx +'.jpg'+ ' ' +str(int(label))+ '\n'
              listText.write(name)
        listText.close()
        print('down!')
        print('****************')


    # generate test.txt 与train相比只是少了label,可以读行的时候直接将l1写入
    with open(os.path.join(data_dir, test_file), 'r') as f1:
        lines1 =f1.readlines()[0:] #表示从第1行,下标为0的数据行开始
        print('****************')
        print('input :', os.path.join(data_dir, test_file))
        print('start...')
        listText1 = open('../gender/test1.txt', 'a+') #创建并打开test1.txt文件,a+表示打开一个文件并追加内容
        for l1 in lines1:
              name1 = data_dir+'image/image/'+l1.rstrip() +'.jpg'+'\n' #rstrip()为了把右边的换行符删掉
              listText1.write(name1)
        listText1.close()
        print('down!')
        print('****************')

然后就生成了想要的txt文件啦

下一步就是如上面知乎中的文章里提到的,自定义dataset类了,

整个流程就是

1. 制作存储了图片的路径和标签信息的txt

2. 将这些信息转化为list,该list每一行元素对应一个样本

3. 通过getitem函数,读取数据和标签,并返回数据和标签

class MyDataset(Dataset):
    def __init__(self, txt, type, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        self.type = type
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()  # 分割成文件名和标签
            if self.type == "train":
                imgs.append((words[0], int(words[1])))
            else:
                imgs.append(words[0])
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        if self.type == "train":
            fn, label = self.imgs[index]
        else:
            fn = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        if self.type == "train":
            return img, label
        else:
            return img

    def __len__(self):
        return len(self.imgs)

关于getitem函数的详细解释在知乎文章里有,这里再引用一下

1、self.imgs 是一个list,也就是上面提到的list,self.imgs的一个元素是一个str,包含图片路径,图片标签,这些信息是从txt文件中读取

fn, label = self.imgs[index]

2、利用Image.open对图片进行读取,img类型为 Image ,mode=‘RGB’

img = Image.open(fn).convert('RGB')

3、对图片进行处理,这个transform里边可以实现 减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等等操作

img = self.transform(img)

收工了~

然后开始数据集载入咯

# 制作dataset
train_data = MyDataset(txt='../gender/train1.txt',type = "train", transform=transform_train)
test_data  = MyDataset(txt='../gender/test1.txt', type = "test", transform=transform_test)
#test集可以直接dataloader
test_loader = DataLoader(test_data, batch_size=100, shuffle=False)
#训练集
#分为验证集(0.2)和训练集(0.8)
batch_size = 60 
validation_split = 0.2
shuffle_dataset = True
random_seed= 42
dataset_size = len(train_data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))  #np.floor向下取整
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)#打乱顺序
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
##上面这些参数的设置我没仔细查。。跟dataloader有关吧?

train_loader      = torch.utils.data.DataLoader(train_data, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                                sampler=valid_sampler)

之后就是网络的训练啦。完结!

emmm,我好像总把问题思考复杂,应该与我没有抓住问题的本质有关吧?。。。好像这个问题在别人那里很简单就得到了解决,然而我解决这个数据集载入问题考虑了好久。。

如果能有小伙伴指正,我会感激不尽滴!

 

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值