前言
视频链接
配合食用 - pytorch图像分类篇:3.搭建AlexNet并训练花分类数据集
model.py
dropout
flatten
train.py
加载数据集实用!(顶)
查看数据集
test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.next()
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# print labels
print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# show images
imshow(utils.make_grid(test_image))
predict.py