class类中__getitem__的作用及其使用方法

我们在看别人写的代码时,在类中经常会看到__getitem__方法,这个方法的作用是,可以将类中的数据 像数组一样读出,以下进行代码演示:

在类中创建__getitem__方法,并使用数组形式读取类中的数据:

class Test():
    def __init__(self):
        self.a=[1,2,3,4,5]
    def __getitem__(self,idx):
        return(self.a[idx])
data=Test()
print(data)
print(data[0])

输出结果为:

<__main__.Test object at 0x000002BBB3256DF0>
1

在类中不创建__getitem__方法,并使用数组形式读取类中的数据:

class Test():
    def __init__(self):
        self.a=[1,2,3,4,5]
#     def __getitem__(self,idx):
#         return(self.a[idx])
data=Test()
print(data)
print(data[0])

输出结果报错:

<__main__.Test object at 0x000002BBB45C2C70>
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [20], in <cell line: 8>()
      6 data=Test()
      7 print(data)
----> 8 print(data[0])

TypeError: 'Test' object is not subscriptable

在类中不创建__getitem__方法,并不使用数组形式读取类中的数据:

class Test():
    def __init__(self):
        self.a=[1,2,3,4,5]
#     def __getitem__(self,idx):
#         return(self.a[idx])
data=Test()
print(data)

输出结果:

<__main__.Test object at 0x000002BBB43AFC70>

总结:

 类中的__getitem__方法是为了将类中的数据可以用数组的形式读出,如果不使用数组的方法读类中的数据,那么就不需要在类中创建__getitem__方法

如何将__getitem__与dataloader结合使用

import torch
import numpy as np
from torch.utils.data import Dataset
 
# 创建MyDataset类
class MyDataset(Dataset):
    def __init__(self, x, y):
        self.data = torch.from_numpy(x).float()
        self.label = torch.LongTensor(y)
 
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx], idx
 
    def __len__(self):
        return len(self.data)
 
Train_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
Train_label = np.array([10, 11, 12, 13])
TrainDataset = MyDataset(Train_data, Train_label) # 创建实例对象
print('len:', len(TrainDataset))
 
# 创建DataLoader
loader = torch.utils.data.DataLoader(
    dataset=TrainDataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    drop_last=False)
 
# 按batchsize打印数据
for batch_idx, (data, label, index) in enumerate(loader):
    print('batch_idx:',batch_idx, '\ndata:',data, '\nlabel:',label, '\nindex:',index)
    print('---------')

输出结果:

len: 4
 
batch_idx: 0 
data: tensor([[1., 2., 3.],
        [4., 5., 6.]]) 
label: tensor([10, 11]) 
index: tensor([0, 1])
---------
batch_idx: 1 
data: tensor([[ 7.,  8.,  9.],
        [10., 11., 12.]]) 
label: tensor([12, 13]) 
index: tensor([2, 3])
---------

https://blog.csdn.net/weixin_43863869/article/details/125602643

  • 5
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
train_test_split函数是sklearn库中的函数,如果想要使用pytorch实现相同的功能,可以使用torch.utils.data中的SubsetRandomSampler类对数据集进行划分。 具体实现步骤如下: 1. 首先将数据集封装成一个Dataset对象,可以使用PyTorch提供的Dataset类或自定义一个Dataset类。 2. 定义一个SubsetRandomSampler对象,指定训练集和测试集的索引。 3. 使用DataLoader类将数据集和Sampler对象进行组合,实现数据的批量读取。 下面是一个示例代码: ```python import torch from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler import numpy as np # 定义数据集类 class MyDataset(Dataset): def __init__(self, X, y): self.X = torch.FloatTensor(X) self.y = torch.LongTensor(y) def __len__(self): return len(self.X) def __getitem__(self, idx): return self.X[idx], self.y[idx] # 划分训练集和测试集 ts = 0.2 random_state = 42 X = np.random.rand(100, 10) y = np.random.randint(0, 2, size=(100,)) num_train = int((1 - ts) * len(X)) indices = np.arange(len(X)) np.random.seed(random_state) np.random.shuffle(indices) train_indices, test_indices = indices[:num_train], indices[num_train:] # 构建数据集和Sampler对象 dataset = MyDataset(X, y) train_sampler = SubsetRandomSampler(train_indices) test_sampler = SubsetRandomSampler(test_indices) # 使用DataLoader读取数据 batch_size = 16 train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler) test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler) # 打印训练集和测试集的大小 print(len(train_sampler)) print(len(test_sampler)) ``` 在这个示例代码中,我们定义了一个MyDataset类来封装数据集,其中__getitem__方法返回一个数据样本及其对应的标签。然后,我们使用numpy库将原始数据集随机划分成训练集和测试集,并使用SubsetRandomSampler类对索引进行抽样。最后,我们使用DataLoader类将数据集和Sampler对象进行组合,实现批量读取数据。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

铁灵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值