pytorch自定义TensorDataset

文章介绍了如何在PyTorch中使用TensorDataset来构建数据集,并展示了如何通过继承Dataset类来自定义数据集,以添加额外的属性。通过这种方式,可以更好地管理和操作训练数据和标签,同时支持数据加载器(DataLoader)进行批量训练。示例代码包括了创建自定义数据集以及使用DataLoader进行迭代的过程。
摘要由CSDN通过智能技术生成

在pytorch中如果仅仅是训练数据和标签,完全可以使用TensorDataset进行构造

            train_set = torch.utils.data.TensorDataset(torch.FloatTensor(self.train_X), torch.from_numpy(self.train_label).long())
            train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=0,shuffle=True)

但是这种做法有一个问题,就是无法对数据集进行自定义一些属性,所以我想自定义这个TensorDataset

import torch
from torch.utils.data import Dataset

class CustomTensorDataset(Dataset):
    def __init__(self, dataset):
        [data_X, data_y] = dataset
        X_tensor, y_tensor = torch.tensor(data_X), torch.tensor(data_y)
        #X_tensor, y_tensor = Tensor(data_X), Tensor(data_y)
        tensors = (X_tensor, y_tensor)
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.train_labels = y_tensor.long()
        self.test_labels = y_tensor.long()
        

    def __getitem__(self, index):
        x = self.tensors[0][index]

        y = self.tensors[1][index]

        return x, y

    def __len__(self):
        return self.tensors[0].size(0)
    
train_dataset = CustomTensorDataset([X,y])
test_dataset = CustomTensorDataset([X,y])
print(train_dataset)

可以看到,这里的train_dataset是有属性的,例如tensors,train_labels,test_labels都是可以访问的

dataloader进行iter

import numpy as np 
from torch.utils.data import TensorDataset,DataLoader
torch.manual_seed(1)
x = np.arange(100).reshape(20,5)
y = np.arange(20)

ds = TensorDataset(torch.from_numpy(x),torch.from_numpy(y))


dl= DataLoader(ds,batch_size=4,shuffle=True)
for idx,(x,y) in enumerate(dl):
    print(idx)
    print(x)
    print(y)
    print(x.shape)
    print(y.shape)
    break

结果如下
在这里插入图片描述

import numpy as np 
from torch.utils.data import TensorDataset,DataLoader
torch.manual_seed(1)
x = np.arange(100).reshape(20,5)
y = np.arange(20)

ds = TensorDataset(torch.from_numpy(x),torch.from_numpy(y))


dl= DataLoader(ds,batch_size=4,shuffle=True)
iter_dl = iter(dl)

xx,yy = next(iter_dl)
print(xx)
print(yy)
print(xx.shape)
print(yy.shape)

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值