数据集路径:https://download.csdn.net/download/Ji_HON/88590044
本人常用该工程测试GPU Pytorch环境的搭建
包含了一个训练预测模型的完整流程
train.py:
1、数据集加载(自定义加载数据集的方法,并分为训练集和测试集)
dataset =torchvision.datasets.ImageFolder(root='G:/LJH/DATASETS/flower_photos',transform=train_transform)
train_loader =DataLoader(train_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。
valid_loader =DataLoader(valid_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。
2、网络加载(加载网络并加载预训练权重)
model = resnet50()
model.load_state_dict(torch.load('weigths/resnet50.pth'))
3、网络训练与结果保存
for epoch in range(1, 9):
train(model, DEVICE, train_loader, optimizer, epoch)
test(model, DEVICE, valid_loader)
torch.save(model, 'weigths/ResNetFlowermodel-epoch8.pth')
其中包含了训练准确率的可视化显示
flower_predict.py:
1、预测图片的读取
img = Image.open("test.jpg")
2、模型的加载与预测
model=torch.load('weigths/ResNetFlowermodel-epoch8.pth',map_location='cpu')
#model.to(DEVICE)
flowers=['雏菊','蒲公英','玫瑰','向日葵','郁金香']
with torch.no_grad():
output = torch.squeeze(model(img))
print(output)
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print(flowers[predict_cla])
plt.show()