之前遇到的问题是,我自己定义了dataset的类,类似于下面的代码
class DealDataset(Dataset):
"""
下载数据、初始化数据,都可以在这里完成
"""
def __init__(self):
xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32) # 使用numpy读取数据
self.x_data = torch.from_numpy(xy[:, 0:-1])
self.y_data = torch.from_numpy(xy[:, [-1]])
self.len = xy.shape[0]
def __getitem__(self, index):
x_data=self.x_data[index]
y_data=self.y_data[index]
return {'x_data':x_data,'y_data':y_data}
def __len__(self):
return self.len
这样就在读取上非常迷惑,不知道用enumerate和tqdm要怎么读数据,搞清楚后在这里简要记录一下对应关系
1.return {'x_data':x_data,'y_data':y_data},
目前只会用enemerate读取
for i, data in enumerate(train_loader):
x_data, y_data= data['x_data'], data['y_data']
2.把return改变,改为return self.x_data[index],self.y_data[index]
这样tqdm读取
for x_data,y_data in tqdm(train_loader):
enumerate读取
for idx,data in enumerate(train_loader):
x_label=data[0]
y_label=data[1]
另外,除了放在batch那里,tqdm也可以放在epoch的循环那里
for epoch in tqdm(range(100)):