Pytorch之数据加载以及处理
前言
汲取自pytorch-DATA LOADING AND PROCESSING TUTORIAL
这里主要介绍了数据集的处理,从类的构造角度阐述了如何自己打造各个函数,不过到最后还是给出了pytorch自带的包,给我们省了不少事
包的导入
Dataset class
torch.utils.data.Dataset 是一个表示数据集的抽象类
我们自定义的dataset需要继承Dataset并且重载以下的方法
- __len__ ,即可调用len(dataset)返回数据集的大小
- __getitem__ ,即可使用dataset[i]访问数据集
例子:
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {