# 基于Windows平台在C++中调用Pytorch模型并实现MFC集成（以MNIST手写体数字识别为例）——附完整代码和数据

49 篇文章 40 订阅
16 篇文章 8 订阅

Python版本：Python 3.6.1

Pytorch版本：1.2.0

Libtorch：1.3

### 1.算法—MNIST手写体数字识别

MNIST是非常有名的手写体数字识别数据集，非常适合用来实战讲解，因此我们也以此作为项目的开始，通过Pytorch建立算法来识别图像中的数字。MNIST由手写体数字的图片和相对应的标签组成，如下图所示：

Pytorch的torchvision库中已经自带了该数据集，可以直接下载和导入。

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import PIL

class Net(nn.Module):
    """CNN for MNIST digit classification.

    Architecture: two 5x5 conv layers (with Dropout2d on the second),
    then two fully-connected layers, log-softmax output.

    Input:  float tensor of shape (N, 1, 28, 28).
    Output: (N, 10) log-probabilities, suitable for F.nll_loss.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 28x28 -> conv(5x5) -> 24x24 -> maxpool(2) -> 12x12
        x = self.conv1(x)
        x = F.relu(F.max_pool2d(x, 2))
        # 12x12 -> conv(5x5) -> 8x8 -> maxpool(2) -> 4x4; 20 ch * 4 * 4 = 320
        x = self.conv2(x)
        x = F.relu(F.max_pool2d(self.conv2_drop(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Extra dropout is active only in training mode
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch.

    Args:
        args: namespace with at least ``log_interval`` (batches between log lines).
        model: module returning log-probabilities (paired with F.nll_loss).
        device: torch.device the batches are moved to.
        train_loader: DataLoader yielding (data, target) batches.
        optimizer: optimizer over ``model.parameters()``.
        epoch: current epoch number (for logging only).
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # Clear accumulated gradients; without this the gradients of every
        # previous batch would be summed into the update (missing in listing).
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

model.eval()
test_loss = 0
correct = 0
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(

def main():
    """Parse hyper-parameters, train the MNIST model, and save it to model.pth.

    The ``parser.add_argument`` calls and the DataLoader constructions were
    dropped in the published listing (only the ``help=`` strings survived);
    this reconstructs them from the canonical PyTorch MNIST example.
    """
    # 用于训练的超参数设置
    parser = argparse.ArgumentParser(description='PyTorch MNIST样例')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()  # use GPU when available
    print(use_cuda)
    torch.manual_seed(args.seed)  # fix the random seed so runs are reproducible

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}
    # Training data (downloaded to ./data on first run)
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    # Test data
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)

    # Save the whole module (not just state_dict) so predict/export scripts
    # can torch.load() it directly.
    torch.save(model, "model.pth")

if __name__ == '__main__':
    import time

    # Time the full training run end-to-end.
    start = time.time()
    main()
    end = time.time()
    running_time = end - start
    print('time cost : %.5f 秒' % running_time)

import torch
import cv2
import torch.nn.functional as F
from train import Net
from torchvision import datasets, transforms
from PIL import Image

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the full model object saved by train.py via torch.save(model, ...).
    # This line was missing from the listing: `model` was used before assignment.
    model = torch.load('model.pth')
    model = model.to(device)
    model.eval()  # inference mode (disables dropout)

    # Read the digit image as 8-bit grayscale; this load was also missing.
    # NOTE(review): assumes a 28x28 source image — resize first if it differs.
    img = Image.open('test.jpg').convert('L')
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    img = trans(img)
    img = img.unsqueeze(0)  # add batch dim: [batch_size, C, H, W], batch_size=1
    img = img.to(device)
    output = model(img)
    pred = output.max(1, keepdim=True)[1]  # predicted digit index
    pred = torch.squeeze(pred)
    print('检测结果为：%d' % (pred.cpu().numpy()))

### 2.1导出序列化模型（pth转pt）

import torch
import cv2
import torch.nn.functional as F
from train import Net
from torchvision import datasets, transforms
from PIL import Image

if __name__ == '__main__':
    device = torch.device('cpu')  # trace on CPU so the exported .pt runs anywhere
    # Load the full model saved by train.py; this line was missing from the
    # listing (`model` was used before assignment).
    model = torch.load('model.pth', map_location=device)
    model = model.to(device)
    model.eval()  # inference mode (disables dropout) before tracing

    # Any correctly-shaped example input works for tracing; use a real digit
    # image preprocessed exactly like training data.
    img = Image.open('test.jpg').convert('L')
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    img = trans(img)
    img = img.unsqueeze(0)  # [batch_size, C, H, W]
    img = img.to(device)
    # Record the model's graph by running it once on the example input,
    # then serialize the TorchScript module for the C++ side.
    traced_net = torch.jit.trace(model, img)
    traced_net.save("model.pt")

    print("模型序列化导出成功")

### 2.2 安装opencv

E:\toolplace\opencv4.1.1\opencv\build\include

E:\toolplace\opencv4.1.1\opencv\build\include\opencv2

E:\toolplace\opencv4.1.1\opencv\build\x64\vc14\lib

opencv_world411.lib

#include "stdafx.h"

#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>

using namespace cv;

int main()
{
// 读入一张图片
// 创建一个名为 "图片"窗口
namedWindow("图片");
// 在窗口中显示图片
imshow("图片", img);
// 等待6000 ms后窗口自动关闭
waitKey(6000);
return 0;
}

### 2.3 安装libtorch

E:\toolplace\libtorch\include

E:\toolplace\libtorch\lib

c10.lib
torch.lib
libprotobuf.lib
caffe2_detectron_ops.lib

### 2.4 C++中调用模型进行推理

#include "stdafx.h"
#undef UNICODE   //这个一定要加，否则会编译错误

#include <torch/script.h> // 引入libtorch头文件
#include <iostream>
#include <memory>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui_c.h>
#include <opencv2/highgui/highgui.hpp>

using namespace cv;

int main()
{
Mat image;

torch::jit::script::Module module;
try {
}
catch (const c10::Error& e) {
std::cerr << "无法加载model.pt模型\n";
return -1;
}

std::vector<int64_t> sizes = { 1, 1, image.rows, image.cols };
at::TensorOptions options(at::ScalarType::Byte);
at::Tensor tensor_image = torch::from_blob(image.data, at::IntList(sizes), options);//将opencv的图像数据转为Tensor张量数据
tensor_image = tensor_image.toType(at::kFloat);//转为浮点型张量数据
at::Tensor result = module.forward({ tensor_image }).toTensor();//推理

auto max_result = result.max(1, true);
auto max_index = std::get<1>(max_result).item<float>();
std::cerr << "检测结果为：";
std::cout << max_index << std::endl;

waitKey(6000);
}

### 2.5 MFC中集成调用


// stdafx.h : 标准系统包含文件的包含文件，
// 或是经常使用但不常更改的
// 特定于项目的包含文件

#undef UNICODE   //这个一定要加，否则会编译错误
#include <torch/script.h> // 引入libtorch头文件

#pragma once

#ifndef VC_EXTRALEAN
#define VC_EXTRALEAN            // 从 Windows 头中排除极少使用的资料
#endif

#include "targetver.h"

#define _ATL_CSTRING_EXPLICIT_CONSTRUCTORS      // 某些 CString 构造函数将是显式的

// 关闭 MFC 对某些常见但经常可放心忽略的警告消息的隐藏
#define _AFX_ALL_WARNINGS

#include <afxwin.h>         // MFC 核心组件和标准组件
#include <afxext.h>         // MFC 扩展

#include <afxdisp.h>        // MFC 自动化类

#ifndef _AFX_NO_OLE_SUPPORT
#include <afxdtctl.h>           // MFC 对 Internet Explorer 4 公共控件的支持
#endif
#ifndef _AFX_NO_AFXCMN_SUPPORT
#include <afxcmn.h>             // MFC 对 Windows 公共控件的支持
#endif // _AFX_NO_AFXCMN_SUPPORT

#include <afxcontrolbars.h>     // 功能区和控件条的 MFC 支持

#ifdef _UNICODE
#if defined _M_IX86
#pragma comment(linker,"/manifestdependency:\"type='win32' name='Microsoft.Windows.Common-Controls' version='6.0.0.0' processorArchitecture='x86' publicKeyToken='6595b64144ccf1df' language='*'\"")
#elif defined _M_X64
#pragma comment(linker,"/manifestdependency:\"type='win32' name='Microsoft.Windows.Common-Controls' version='6.0.0.0' processorArchitecture='amd64' publicKeyToken='6595b64144ccf1df' language='*'\"")
#else
#pragma comment(linker,"/manifestdependency:\"type='win32' name='Microsoft.Windows.Common-Controls' version='6.0.0.0' processorArchitecture='*' publicKeyToken='6595b64144ccf1df' language='*'\"")
#endif
#endif



#include <iostream>
#include <memory>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui_c.h>
#include <opencv2/highgui/highgui.hpp>

using namespace cv;

public:
Mat m_img;// currently selected image as an OpenCV matrix
BITMAPINFO *m_imgInfo;// DIB info (header + grayscale palette) used to draw m_img; allocation not shown in this excerpt — TODO confirm it happens before use

//选择图像
void CMnistMFCDlg::OnBnClickedButtonChoosepic()
{
// TODO: 在此添加控件通知处理程序代码
CFileDialog dlg(TRUE, _T("*.jpg"),
NULL,
_T("image Files(*.jpg;*.jpg)|*.jpg;*.jpg|All Files (*.*)|*.*||"),
NULL);

//打开文件对话框的标题名
dlg.m_ofn.lpstrTitle = _T("选择图像 ");

if (dlg.DoModal() != IDOK)
return;

CString mPath = dlg.GetPathName();

//修改颜色表等信息
//颜色表赋值
for (int i(0); i < 256; ++i)
{
m_imgInfo->bmiColors[i].rgbBlue = i;
m_imgInfo->bmiColors[i].rgbGreen = i;
m_imgInfo->bmiColors[i].rgbRed = i;
m_imgInfo->bmiColors[i].rgbReserved = 0;
}
//头文件信息(注意由实际显示情况可得出图像原点显示在空间左下角)
m_imgInfo->bmiHeader.biSizeImage = m_img.channels() * m_img.cols * m_img.rows;

//开始绘图
CDC *pDC;
pDC = GetDlgItem(IDC_SHOW)->GetDC();
CRect imgCtrlRect;
GetDlgItem(IDC_SHOW)->GetClientRect(&imgCtrlRect);

pDC->SetStretchBltMode(COLORONCOLOR);
::StretchDIBits(pDC->GetSafeHdc(),
0, 0,
imgCtrlRect.Width(), imgCtrlRect.Height(),
0, 0,
m_img.cols, m_img.rows,
m_img.data,
m_imgInfo,
DIB_RGB_COLORS,
SRCCOPY);

ReleaseDC(pDC);
}

void CMnistMFCDlg::OnBnClickedButtonDetect()
{
// TODO: 在此添加控件通知处理程序代码
torch::jit::script::Module module;
try {
}
catch (const c10::Error& e) {
MessageBox("无法加载model.pt模型");
return;
}

std::vector<int64_t> sizes = { 1, 1, m_img.rows, m_img.cols };  //依次为batchsize、通道数、图像高度、图像宽度
at::TensorOptions options(at::ScalarType::Byte);
at::Tensor tensor_image = torch::from_blob(m_img.data, at::IntList(sizes), options);//将opencv的图像数据转为Tensor张量数据
tensor_image = tensor_image.toType(at::kFloat);//转为浮点型张量数据
at::Tensor result = module.forward({ tensor_image }).toTensor();//推理

auto max_result = result.max(1, true);
auto max_index = std::get<1>(max_result).item<float>();
CString str;
str.Format("%.0f", max_index);
MessageBox(str);
}

• 20
点赞
• 100
收藏
觉得还不错? 一键收藏
• 打赏
• 24
评论

### “相关推荐”对你有帮助么？

• 非常没帮助
• 没帮助
• 一般
• 有帮助
• 非常有帮助

¥1 ¥2 ¥4 ¥6 ¥10 ¥20

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、付费专栏及课程。