这里还是打算采用上一篇《实战1——tensorflow训练MNIST网络,pytorch预测-CSDN博客》中用到的本地数据集mnist.npz。
pytorch训练网络
数据集需要转换为pytorch的格式。因此通过这一篇博文,我们还将学习到如何定义自己的数据集和测试集,或者复用已有的数据集。
先加载数据集:
#先加载MNIST数据集
with np.load(path_to_mnist, allow_pickle=True) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
定义一个自定义的数据集类型:
class MyDataset(Dataset):
def __init__(self, images, labels):
self.images = images
self.labels = labels
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = self.images[idx]
label = self.labels[idx]