整个项目的思路
训练模型
- 使用pytorch对resnet18 进行迁移学习,实现对自己的数据进行图像分类。需要将最后一个全连接层中的输出节点数目修改,因为我的数据集中包含有5中图像,所以这里的输出节点数目修改成5
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 使用GPU
# 加载预训练的resnet18模型
net = torchvision.models.resnet18(pretrained=True)
# 冻结原网络参数,仅训练最后新替换的全连接层
for param in net.parameters():
param.requires_grad = False
num_ftrs = net.fc.in_features # 原网络最后一层的输入维度
net.fc = nn.Linear(num_ftrs, 5) # 替换新的连接层,输出改为5,预测5个类别
net = net.to(DEVICE)
- 然后resnet18模型训练好之后,保存训练过程中准确率最高的模型。
# 保存模型参数net.state_dict()
torch.save(net.state_dict(), 'net_dict.pt')
# 保存完整模型
torch.save(net, 'net.pt')
- 然后可以随便找一张image,用保存的训练好的模型进行预测
Flask—python服务器
Flask和Django都是web框架。可以将模型发布在服务器(这里使用的是本地服务器)。在对应的URL中实现对模型的调用。
app = flask.Flask(__name__)
...
...
@app.route("/predict", methods=["POST"])
def predict():
...
...
出现下图,说明flask服务开启成功
向浏览器中输入该网址,然后可以在终端向服务器,以POST方式传过去待识别的图像。并接收从服务器传过来的识别结果。这里的终端暂时使用的是anaconda虚拟环境中的python.exe,来执行.py文件中的代码。后续将Android作为终端。
遇到的问题及解决方法
错误一:
在启动flask服务程序的下段代码中,
preds = F.softmax(model(image), dim=1)
results = torch.topk(preds.cpu().data, k=3, dim=1)
# Loop over the results and add them to the list of returned predictions
for prob, label in zip(results[0][0], results[1][0]):
print(label) # tensor(162)
label_name = idx2label[label]
r = {"label": label_name, "probability": float(prob)}
data['predictions'].append(r)
这个地方报错:
KeyError: tensor(162)
错误原因: label_name = idx2label[label]
idx2label是一个字典{key0: value0, key1: value1, key2: value2…}, 比如{0: ‘cardboard’, 1: ‘glass’, 2: ‘metal’, 3: ‘paper’, 4: ‘plastic’}。可是输出label,发现label并不是一个数,而是tensor。 所以需要将tensor转换为数值。
修改:将label_name = idx2label[label] ---->label_name = idx2label[int(label)]
错误二:
在用anaconda虚拟环境中的python.exe执行simple_request.py文件时,动态对函数参数赋值传入待预测图像的文件路径时,报错
image = open(image_path, 'rb').read() OSError: [Errno 22] Invalid argument: "'e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg'"
修改
将>python E:/PROJECT/PycharmProjects/pt/simple_request.py --file='e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg
中的文件路径改为>python E:/PROJECT/PycharmProjects/pt/simple_request.py --file=e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg
即去掉图片路径中的单引号,光是这个小小的错误让我头疼了一整天。。
完整代码有时间会传到github上的。。