确保安装
- scikit-image
- numpy
Dataset和DataLoader都是Pytorch里面读取数据的工具。现在对这两种工具做一个概括和总结。
1.Dataset
一个例子:
# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
#创建子类
class subDataset(Dataset.Dataset):
#初始化,定义数据内容和标签
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label
#返回数据集大小
def __len__(self):
return len(self.Data)
#得到数据内容和标签
def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.IntTensor(self.Label[index])
return data, label
# 主函数
if __name__ == '__main__':
dataset = subDataset(Data, Label)
print(dataset)
print('dataset大小为:&