在构建DataLoader时,需要传入参数dataset,它应当是数据集类的一个实例(而不是类本身);这个类可以是自定义的数据集类,比如上图中的myDataset
在训练过程中对DataLoader进行迭代取batch时,会自动调用数据集对象的__getitem__()方法,按索引取出每个样本
class myDataset(Dataset):
    """Dataset pairing each CSV row with the text-file line at the same index.

    Args:
        csv_file: Path of the CSV file, loaded with ``pandas.read_csv``.
        txt_file: Path of a text file; all of its lines are read into a list
            (trailing newlines are kept).
        root_dir: Stored on the instance as ``root_dir``; not otherwise used
            by this class.
        other_file: Accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self, csv_file, txt_file, root_dir, other_file):
        self.csv_data = pd.read_csv(csv_file)
        with open(txt_file, 'r') as f:
            self.txt_data = f.readlines()
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of rows in the CSV data."""
        return len(self.csv_data)

    def __getitem__(self, idx):
        """Return ``(csv_row, txt_line)`` for the sample at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the CSV rows or the text lines
                (the two files are not checked for equal length).
        """
        # DataFrame[idx] would look up a *column* labelled idx; .iloc selects
        # the row by integer position, which is what a sample index means.
        # NOTE(review): the row is a pandas Series; it may need converting
        # (e.g. to a tensor) before DataLoader's default collation can batch
        # it — confirm against the training code.
        return self.csv_data.iloc[idx], self.txt_data[idx]
# DataLoader needs a dataset *instance*: it calls len() on it and indexes it,
# neither of which works on the bare class object.
# The file paths below are placeholders — replace them with the real inputs.
dataset = myDataset("data.csv", "data.txt", "./data", None)
dataiter = DataLoader(dataset, batch_size=32, shuffle=True)
__getitem__()方法理解
如果在类中定义了__getitem__()方法,那么它的实例对象(假设为P)就可以通过P[key]的方式取值。当对实例对象执行P[key]运算时,就会调用类中的__getitem__()方法。
class DataTest:
    """Demo class showing that subscripting an instance calls ``__getitem__``.

    Args:
        id: Value stored as ``self.id`` and used as the first key of ``self.d``.
        address: Value stored as ``self.address`` and used as the second key
            of ``self.d``.
    """

    def __init__(self, id, address):
        self.id, self.address = id, address
        # Build the mapping key by key, in the same order (id, then address).
        lookup = {}
        lookup[id] = 1
        lookup[address] = "172.0.0.1"
        self.d = lookup

    def __getitem__(self, key):
        """Return the fixed string ``"hello"`` regardless of ``key``."""
        greeting = "hello"
        return greeting
# Calling the dunder method directly and using subscript syntax both go
# through DataTest.__getitem__, so both print the same value.
data = DataTest(1, "172.0.0.1")
via_method = data.__getitem__(2)
via_subscript = data[1]
print(via_method)  # hello
print(via_subscript)  # hello
如果类把某个属性定义为序列,可以使用__getitem__()输出序列属性中的某个元素.
#__getitem__
#如果类把某个属性定义为序列,可以使用__getitem__()输出序列属性中的某个元素.
class FruitShop:
    """Container that exposes the items of its ``fruits`` sequence by index.

    Only ``__getitem__`` is defined; Python's iteration fallback calls it with
    0, 1, 2, ... until ``IndexError`` is raised, so instances also work in
    ``for`` loops.

    Args:
        fruits: Optional initial sequence of items. Defaults to a new empty
            list. The ``fruits`` attribute may also be assigned later.
    """

    def __init__(self, fruits=None):
        # Always define the attribute so a fresh shop indexes and iterates
        # cleanly (IndexError / empty loop) instead of raising AttributeError.
        self.fruits = [] if fruits is None else fruits

    def __getitem__(self, i):
        """Return ``self.fruits[i]``."""
        return self.fruits[i]
if __name__ == "__main__":
    # Demo: an object that defines __getitem__ supports both indexing and
    # for-loop iteration over its underlying sequence.
    shop = FruitShop()
    print(shop)  # default object repr, e.g. <__main__.FruitShop object at 0x...>
    shop.fruits = ["apple", "banana"]
    second = shop[1]
    print(second)  # banana
    for fruit in shop:
        print(fruit)  # apple, then banana