看了许多关于PyTorch的入门文章,大抵是从torchvision.datasets中自带的数据集进行训练,导致很难把PyTorch运用于自己的数据集上,真正地灵活运用PyTorch。
这里我采用从Kaggle上下载的猫狗数据集,利用自定义数据集训练自己的二分类神经网络。
解压后,一个文件里面有12500张图,猫狗各一半,文件名类似于这样:cat.0.jpg、dog.12499.jpg
因为只是练手,所以不用这么大的,仅仅采用子数据集。
利用Python的os库,将数据集进行拆分。分为train与test两个文件架,每个里面都有cats和dogs两个文档。train里面每种动物有1000张图,test里面每种动物有500张图。图片大概是这个样子(大小不一):
接下里开始编码:
# 导入库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# 设置超参数
BATCH_SIZE