图像分类识别入门训练模型(PyTorch)

文章介绍了使用PyTorch实现的ResNet模型在Flower数据集上的训练过程,包括数据加载、网络加载、训练与结果保存,以及模型预测功能。作者分享了如何自定义数据集、加载预训练权重以及进行图像预测的代码片段。
摘要由CSDN通过智能技术生成

ResNet_Flower

数据集路径:https://download.csdn.net/download/Ji_HON/88590044

本人常用该工程测试GPU Pytorch环境的搭建

包含了一个训练预测模型的完整流程

train.py:

1、数据集加载(自定义加载数据集的方法,并分为训练集和测试集)

dataset =torchvision.datasets.ImageFolder(root='G:/LJH/DATASETS/flower_photos',transform=train_transform)
train_loader =DataLoader(train_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。
valid_loader =DataLoader(valid_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。

2、网络加载(加载网络并加载预训练权重)

model = resnet50()
model.load_state_dict(torch.load('weigths/resnet50.pth'))

3、网络训练与结果保存

for epoch in range(1, 9):
    train(model,  DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, valid_loader)
torch.save(model, 'weigths/ResNetFlowermodel-epoch8.pth')

其中包含了训练准确率的可视化显示

flower_predict.py:

1、预测图片的读取

img = Image.open("test.jpg")

2、模型的加载与预测

model=torch.load('weigths/ResNetFlowermodel-epoch8.pth',map_location='cpu')
#model.to(DEVICE)
flowers=['雏菊','蒲公英','玫瑰','向日葵','郁金香']
with torch.no_grad():
    output = torch.squeeze(model(img))
    print(output)
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
    print(flowers[predict_cla])
plt.show()

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值