pytorch自定义输入数据(复写Dataset)

数据加载

对于图片数据加载,使用ImageFolder读取训练集和测试集的文件夹,其中训练集文件夹下分别有多个子文件夹,文件夹的名字即为该类型图片的标签(如图)
图片文件存储的方式
获得图片数据train_dataset后,放入DataLoader迭代器,其中有三个自定义参数:batch_size 、 shuffle 、 num_workers,分别表示批次大小、数据是否打乱和使用的线程数。

	from torchvision import transforms, datasets

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

Dataset讲解

from torch.utils.data import Dataset

Dataset有三个主要的方法:__init__ (self ,)、__getitem__(self , idx) 、__len__(self)

1、通过__len__(self):函数获得数据集的大小,从而在放入DataLoader后能根据batch_size划分数据;

2、__getitem__(self , idx):其中idx为每个batch_size的下标,从而获得每条数据;

自定义Dataset

目前有关于电影信息用csv格式保存和电影海报图片,图片的文件名为对应电影的id,如图:

电影csv文件
电影海报文件
若想同时读取以上文本数据和图片数据,就得复写Dataset类:

class MovieDataset(Dataset):	
	# data_dir:csv表的路径 ;root_dir:图片的路径
    def __init__(self, data_dir, root_dir , transform=None):
        # self.data = pd.read_csv(csv_file) # csv总文件读取
        # 数据的位置
        self.features, self.targets_values, self.ratings = pickle.load(open(data_dir, mode='rb'))

        self.uid, self.user_gender, self.user_age, self.user_job = self.features.take(0, 1), self.features.take(2, 1), self.features.take(3,1), self.features.take(4, 1)

        self.movie_id, self.movie_categories, self.movie_titles, self.intro ,self.targets = self.features.take(1,1)	, self.features.take(6,1), self.features.take(5,1) , self.features.take(7,1) , self.targets_values

        self.root_dir = root_dir # 图片路径
        self.transform = transforms.Compose([
					     transforms.Resize((224, 224)),
					     transforms.RandomHorizontalFlip(),
					     transforms.ToTensor(),
					     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
					     ])# 图像规范化

    def __len__(self):
        return len(self.movie_id) 

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, str(self.movie_id[idx])+'.jpg') # 根据电影id获得对应图片位置
        image = Image.open(img_name) 
        image = image.convert('RGB')
        if self.transform:
            image = self.transform(image)

        user_data = self.get_user_data() # 获得用户数据
        movie_data = self.get_movie_data() # 获得电影数据

        # 使用列表存储输入数据x和真实结果y
        input=[]
        input.append(image)
        # (user_id , user_gender , user_age , user_job)
        input.append(user_data[idx])
        # (movie_id , categories , title , intro , intro_lengths)
        input.append(movie_data[idx])
        # (rating)
        input.append(self.targets[idx])

        return input

这里明显可以看出__getitem__(self , idx)中的idx为csv表中的索引,可以根据索引将不同类型的数据放入DataLoader里。

for step, X in enumerate(train_bar):
	batch_x,batch_y = X[:-1] , X[-1] #获得输入特征和真实结果

总结

复写Dataset可以根据自己的需求将不同的数据放入DataLoader里,复写的类中,需同时有__init__ (self ,)、__getitem__(self , idx) 、__len__(self)三个方法。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值