pytorch 中 __getitem__ ()和DataLoader

 

 

在构建DataLoader时,需要传入参数dataset,这里可以是自己自定义数据集类,比如上图myDataset

在DataLoader 送入torch中进行训练时,会自动调用数据集类的__getitem__()方法

class myDataset(Dataset):
    def __init__(self, csv_file, txt_file, root_dir, other_file):
        self.csv_data = pd.read_csv(csv_file)
        with open(txt_file, 'r') as f:
            data_list = f.readlines()
        self.txt_data = data_list
        self.root_dir = root_dir
 
    def __len__(self):
        return len(self.csv_data)
 
    def __getitem__(self, idx):
        data = (self.csv_data[idx], self.txt_data[idx])
        return data
 
dataiter = DataLoader(myDataset, batch_size=32, shuffle=True)

 

 

__getitem__()方法理解

如果在类中定义了__getitem__()方法,那么他的实例对象(假设为P)就可以这样P[key]取值。当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。


class DataTest():
    def __init__(self,id,address):
        self.id = id
        self.address = address
        self.d = {self.id : 1,
                  self.address:"172.0.0.1"
                 }
    def __getitem__(self, key):
        return "hello"

data = DataTest(1, "172.0.0.1")
print(data.__getitem__(2))      # hello

print(data[1])  # hello

 

 

如果类把某个属性定义为序列,可以使用__getitem__()输出序列属性中的某个元素.


#__getitem__
#如果类把某个属性定义为序列,可以使用__getitem__()输出序列属性中的某个元素.
class FruitShop():
     def __getitem__(self,i):
         return self.fruits[i] #可迭代对象


if __name__ == "__main__":
    shop = FruitShop()
    print(shop)                            #__main__.FruitShop instance
    shop.fruits = ["apple", "banana"]
    print(shop[1])                           #banana
    for item in shop:
        print(item)   # appale banana

 

 

  • 8
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值