Dataset
from torch.utils.data import Dataset
import pandas as pd
import os
class myDataset(Dataset):
def __init__(self, file_dir):
self.filepaths = os.listdir(file_dir)
dfs = []
for filepath in self.filepaths:
dfs.append(pd.read_csv(os.path.join(file_dir,filepath),header=None))
self.data = pd.concat(dfs, ignore_index=True)
self.dlength = len(self.data)
def __getitem__(self, idx):
line = self.data.iloc[idx].to_numpy()
return line
def __len__(self):
return self.dlength
这里注意Dataset要大写,之前小写把自己坑了。file_dir填文件夹名,这段就是在init时用pandas把文件夹下的所有csv文件拼接起来存放在self.data里。而调用getitem就是将data的某一行返回。这个重载的函数应该会被dataloader调用来取数据,打包成batch之类的等等。
题外话:现在觉得,纯数据干嘛用pandas,直接用numpy读就好了。需要利用dataframe的时候再用pandas。
Dataloader
dataloader就用pytorch.utils.data里的Dataloader就行了。
train_dataset = myDataset(file_dir=train_dir)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size, shuffle=True)
读取数据
for i, line in enumerate(train_loader, start=1):
X = line[:,:-1]
target = line[:,-1]
没什么好说的,都是些小bug,跟着bug提示搜一搜就能解决。
感觉得手撸交叉熵了