python利用pytorch实现图像识别分类,并搭建简单的服务器
1、首先安装pytorch,打开控制台输入如下命令
pip install pytorch
2、代码如下
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
import os
from flask import Flask, request
app = Flask(__name__)
from flask import jsonify
from werkzeug.utils import secure_filename
# 上传的图片保存路径
UPLOAD_PATH = os.path.join(os.path.dirname(__file__), 'images')
normalize = transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 规范化
transforms = transforms.Compose([transforms.Resize((64, 64)),
torchvision.transforms.ToTensor(),
normalize
])
# 文件夹目录
base_dir = 'D:\\GraduateStudent\\Study\\python\\test\\imagedata'
# 获取当前目录下的所有文件
data_class = []
def imagerecognition(image_path):
image = Image.open(image_path)
image = transforms(image)
model_ft = torchvision.models.resnet18() # 需要使用训练时的相同模型
in_features = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(in_features, 36),
nn.Linear(36, 3)) # 此处也要与训练模型一致
model = torch.load("D:\\GraduateStudent\\Study\\python\\test\\pythonProject\\best_model_yaopian.pth",
map_location=torch.device("cpu")) # 选择训练后得到的模型文件
image = torch.reshape(image, (1, 3, 64, 64)) # 修改待预测图片尺寸,需要与训练时一致
model.eval()
with torch.no_grad():
output = model(image)
# print(output) # 输出预测结果
# # print(int(output.argmax(1)))
print("图片预测为:{}".format(data_class[int(output.argmax(1))])) # 对结果进行处理,使直接显示出预测的植物种类
return data_class[int(output.argmax(1))]
@app.route('/api/upload', methods=['POST'])
def upload_pic():
# 来获取多个上传文件
imgs = request.files.getlist("file_imgs")
urls = []
imageResult = ""
# 上传文件夹如果不存在则创建
if not os.path.exists(UPLOAD_PATH):
os.mkdir(UPLOAD_PATH)
# 循环读取上传的文件并保存
for img in imgs:
filename = secure_filename(img.filename)
# print(filename)
img.save(os.path.join(UPLOAD_PATH, filename))
# print(UPLOAD_PATH)
msg = UPLOAD_PATH+"/{}".format(filename)
imagepath = os.path.abspath(msg)
imageResult = imagerecognition(imagepath)
urls.append(imagepath)
respose = {
"code": 200,
"urls": urls,
"imageResult": imageResult
}
return jsonify(respose)
if __name__ == "__main__":
files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
# 遍历文件列表,获取文件名
for file in files:
data_class.append(os.path.basename(file))
# for filename in data_class:
# print(filename)
app.run(host="192.168.43.13", port=5006, debug=False)
其中遇见报错的模块使用pip安装就行
pip install 模块名
服务器地址为:
http://192.168.43.13:5006/api/upload
可以使用postman进行测试
确保服务开启运行状态,点击send即可调用服务获取返回结果
pytorch模型训练的代码可以查看如下链接
https://blog.csdn.net/Low__Profile/article/details/136635940?spm=1001.2014.3001.5501
pytorch图像识别分类借鉴于
https://blog.csdn.net/Satenga/article/details/122341233