python利用pytorch实现图像识别分类,并搭建简单的服务器

python利用pytorch实现图像识别分类,并搭建简单的服务器


1、首先安装pytorch,打开控制台输入如下命令

pip install pytorch

2、代码如下

import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
import os
from flask import Flask, request
app = Flask(__name__)
from flask import jsonify
from werkzeug.utils import secure_filename
# 上传的图片保存路径
UPLOAD_PATH = os.path.join(os.path.dirname(__file__), 'images')

normalize = transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])  # 规范化
transforms = transforms.Compose([transforms.Resize((64, 64)),
                                     torchvision.transforms.ToTensor(),
                                     normalize
                                     ])
# 文件夹目录
base_dir = 'D:\\GraduateStudent\\Study\\python\\test\\imagedata'
# 获取当前目录下的所有文件
data_class = []


def imagerecognition(image_path):
    image = Image.open(image_path)
    image = transforms(image)
    model_ft = torchvision.models.resnet18()  # 需要使用训练时的相同模型
    in_features = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(nn.Linear(in_features, 36),
                                nn.Linear(36, 3))  # 此处也要与训练模型一致
    model = torch.load("D:\\GraduateStudent\\Study\\python\\test\\pythonProject\\best_model_yaopian.pth",
                       map_location=torch.device("cpu"))  # 选择训练后得到的模型文件
    image = torch.reshape(image, (1, 3, 64, 64))  # 修改待预测图片尺寸,需要与训练时一致
    model.eval()
    with torch.no_grad():
        output = model(image)
    # print(output)  # 输出预测结果
    # # print(int(output.argmax(1)))
    print("图片预测为:{}".format(data_class[int(output.argmax(1))]))  # 对结果进行处理,使直接显示出预测的植物种类
    return data_class[int(output.argmax(1))]


@app.route('/api/upload', methods=['POST'])
def upload_pic():
    # 来获取多个上传文件
    imgs = request.files.getlist("file_imgs")
    urls = []
    imageResult = ""
    # 上传文件夹如果不存在则创建
    if not os.path.exists(UPLOAD_PATH):
        os.mkdir(UPLOAD_PATH)
    # 循环读取上传的文件并保存
    for img in imgs:
        filename = secure_filename(img.filename)
        # print(filename)
        img.save(os.path.join(UPLOAD_PATH, filename))
        # print(UPLOAD_PATH)
        msg = UPLOAD_PATH+"/{}".format(filename)
        imagepath = os.path.abspath(msg)
        imageResult = imagerecognition(imagepath)
        urls.append(imagepath)
    respose = {
        "code": 200,
        "urls": urls,
        "imageResult": imageResult
    }
    return jsonify(respose)

if __name__ == "__main__":
    files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
    # 遍历文件列表,获取文件名
    for file in files:
        data_class.append(os.path.basename(file))
    # for filename in data_class:
    #     print(filename)
    app.run(host="192.168.43.13", port=5006, debug=False)


其中遇见报错的模块使用pip安装就行

pip install 模块名

服务器地址为:

http://192.168.43.13:5006/api/upload

可以使用postman进行测试
在这里插入图片描述
在这里插入图片描述
确保服务开启运行状态,点击send即可调用服务获取返回结果
在这里插入图片描述
pytorch模型训练的代码可以查看如下链接

https://blog.csdn.net/Low__Profile/article/details/136635940?spm=1001.2014.3001.5501

pytorch图像识别分类借鉴于

https://blog.csdn.net/Satenga/article/details/122341233

  • 7
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值