个人数据集转变为pytorch数据集

图片数据集存储为该形式

数据集转换:下一篇讲ray框架下pytorch模型训练时调用该模块

import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib.image as mpimg


# 对所有图片生成path-label map.txt 这个程序可根据实际需要适当修改
def generate_map(root_dir, n):  # root_dir为D:/tmp/photo
    # 得到当前绝对路径
    current_path = os.path.abspath('data')
    # os.path.dirname()向前退一个路径
    father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")

    for idx in range(n):
        subdir = os.path.join(root_dir, '%d/' % idx)
        all_name = []
        for file_name in os.listdir(subdir):
            all_name.append(file_name)
        len_all_name = len(all_name)
        # 划分训练验证测试集
        split_1 = int(len_all_name * 0.6)
        split_2 = int(len_all_name * 0.8)
        train_name = all_name[:split_1]
        val_name = all_name[split_1:split_2]
        test_name = all_name[split_2:]

        # 将训练、验证、测试集的路径和标签写入不同的txt文件
        with open(os.path.join(root_dir, 'trainmap.txt'), 'w') as wfp1:
            for i in range(len(train_name)):
                abs_name = os.path.join(father_path, subdir, train_name[i])
                # linux_abs_name = abs_name.replace("\\", '/')
                wfp1.write('{file_dir} {label}\n'.format(file_dir=abs_name, label=idx))
        with open(os.path.join(root_dir, 'valmap.txt'), 'w') as wfp2:
            for i in range(len(val_name)):
                abs_name = os.path.join(father_path, subdir, val_name[i])
                # linux_abs_name = abs_name.replace("\\", '/')
                wfp2.write('{file_dir} {label}\n'.format(file_dir=abs_name, label=idx))
        with open(os.path.join(root_dir, 'testmap.txt'), 'w') as wfp3:
            for i in range(len(test_name)):
                abs_name = os.path.join(father_path, subdir, test_name[i])
                # linux_abs_name = abs_name.replace("\\", '/')
                wfp3.write('{file_dir} {label}\n'.format(file_dir=abs_name, label=idx))


# 实现MyDatasets类
class MyDatasets(Dataset):

    def __init__(self, dir, method):
        # 获取数据存放的dir
        # 例如d:/images/
        self.data_dir = dir
        # 用于存放(image,label) tuple的list,存放的数据例如(d:/image/1.png,4)
        self.image_target_list = []
        # 从dir--label的map文件中将所有的tuple对读取到image_target_list中
        # map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路径最好是绝对路径
        with open(os.path.join(dir, method), 'r') as fp:
            content = fp.readlines()
            # s.rstrip()删除字符串末尾指定字符(默认是字符)
            # 得到 [['d:/.../image_data/1/3.jpg', '1'], ...,]
            str_list = [s.rstrip().split() for s in content]
            # 将所有图片的dir--label对都放入列表,如果要执行多个epoch,可以在这里多复制几遍,然后统一shuffle比较好
            self.image_target_list = [(x[0], int(x[1])) for x in str_list]

    def __getitem__(self, index):
        image_label_pair = self.image_target_list[index]
        # 按path读取图片数据,并转换为图片格式例如[3,32,32]
        # 可以用别的代替
        img = mpimg.imread(image_label_pair[0])
        img = np.resize(img, (3, 32, 32))
        img = torch.from_numpy(img).float()
        return img, image_label_pair[1]

    def __len__(self):
        return len(self.image_target_list)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值