Loading a Custom Image Classification Model with LibTorch 1.4 (VS 2019)

Environment:

  1. Windows 10
  2. Visual Studio 2019 (VC16)
  3. OpenCV 4.2 (built from source together with opencv_contrib)
  4. LibTorch 1.4.0 (CPU build)
  5. PyTorch 1.4.0 (the PyTorch and LibTorch versions should match; see the check below)
  6. Python 3.7
  7. CMake
  8. Dataset: flower_photos (5 classes)
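
To confirm that the training environment matches the LibTorch release, check the installed PyTorch version (a trivial sketch; the two version numbers should agree):

import torch
print(torch.__version__)  # should print 1.4.0, matching the LibTorch download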

1. Building the classification model

We build the custom model by transfer learning from the pretrained models that ship with torchvision. The code is as follows:

# -*- coding: utf-8 -*-
"""
Created on Tue Feb 25 12:58:25 2020

@author: zhou-wenqing

Image classification task
"""
from PIL import Image
import torch
from torchvision import models, transforms, datasets
from torch import nn
import torch.nn.functional as F
from torchsummary import summary
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 12
import time
import copy
import numpy as np
import argparse


#%% Transfer learning from the models bundled with torchvision. Two strategies:
# 1) Freeze parameters (the convolutional layers only extract features, using
#    ImageNet-pretrained weights; they no longer receive gradient updates)
# 2) Replace the final fully connected layer so its output size matches the
#    number of classes in the custom dataset

def create_model(model_name, # model name
                     num_classes, # number of classes
                     feature_extract:bool, # feature extraction only?
                     use_pretrained=True,  # load pretrained weights?
                     ):
    model_ft = None
    input_size = 0
    
    def set_parameter_requires_grad(model, feature_extracting):
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False
            
    if model_name == "resnet":
        """ 
        Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        """ 
        Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        """ 
        VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """ 
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """ 
        Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ 
        Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size


#%% 
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, plot=False):
    since = time.time()

    val_acc_history = []
    train_loss = []
    train_acc = []
    iters = []
    if plot:
        plt.ion()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for idx, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Compute the model outputs and the loss.
                    # For the inception model, training also produces an
                    #     auxiliary output with its own loss; the final loss
                    #     is the sum of the two. At test time only the final
                    #     output's loss is used.
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        
                        loss.backward()
                        optimizer.step()
                    
                # statistics
                # Accumulate the loss and the number of correct predictions over each batch
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                    
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            if plot:
                if phase=='train':
                    train_loss.append(epoch_loss)
                    train_acc.append(epoch_acc)
                    iters.append(epoch)
                    
                if (epoch+1) % 1 == 0:  # plotting
                    plt.cla()
                    
                    plt.subplot(121)
                    plt.plot(iters, train_loss, 'r', label='train loss')
                    
                    plt.title('Train Loss')
                    plt.xlabel('epochs')
                    plt.legend()
                    
                    plt.subplot(122)
                    plt.plot(iters, train_acc,'b', label='train acc')
                    plt.xlabel('epochs')
                    plt.title('Train Acc')
                    
                    plt.legend()
            
                plt.ioff()
                plt.show()
                
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, val_acc_history

if __name__=='__main__':
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='alexnet',
                        help='model name: alexnet, resnet, vgg, squeezenet, densenet, inception')
    parser.add_argument('--data', type=str, default=r"D:\Datasets\flower_photos", help='dataset root path')
    parser.add_argument('--epochs', type=int, default=2)
    # argparse's type=bool treats any non-empty string (even 'False') as True,
    # so parse boolean flags explicitly
    parser.add_argument('--extract', type=lambda s: s.lower() in ('true', '1'), default=True)
    parser.add_argument('--pretrained', type=lambda s: s.lower() in ('true', '1'), default=True)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-classes', type=int, default=5)
    parser.add_argument('--train-ratio', type=float, default=0.8,
                        help='fraction of the data used for training')
    
    opt = parser.parse_args()
    print(opt)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Initialize the model
    model_ft, input_size = create_model(model_name=opt.model, 
                                        num_classes=opt.num_classes, 
                                        feature_extract=opt.extract, 
                                        use_pretrained=opt.pretrained)
    #%% Loading image dataset
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop((input_size,input_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            # Note: random_split below reuses full_dataset, so both splits in
            # fact receive the 'train' transform; the 'val' pipeline is kept
            # for completeness.
            transforms.Resize((input_size,input_size)),
            # transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    full_dataset = datasets.ImageFolder(root=opt.data, 
                                     transform=data_transforms['train'])
    print('Total dataset size:', len(full_dataset))
    
    # Split the dataset
    train_size = int(opt.train_ratio * len(full_dataset))
    test_size = len(full_dataset) - train_size
    
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    print('Training set size:', len(train_dataset))
    print('Validation set size:', len(test_dataset))
    image_datasets = {'train':train_dataset,
                      'val':test_dataset}
    
    # Create training and validation dataloaders
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batch_size, shuffle=True, num_workers=0) for x in ['train', 'val']}

    # Move the model to the device
    model_ft = model_ft.to(device)
    # Print a summary of the instantiated model
    # (torchsummary defaults to device="cuda", so pass the actual device type)
    summary(model_ft, (3, input_size, input_size), device=device.type)
    # Collect the parameters to be optimized/updated.
    # For finetuning, all network parameters are updated;
    # for feature extraction, only those with requires_grad=True are.
    params_to_update = model_ft.parameters()
    print("Params to learn:")
    if opt.extract:
        params_to_update = []
        for name,param in model_ft.named_parameters():
            if param.requires_grad == True:
                params_to_update.append(param)
                print("\t",name)
    else:
        for name,param in model_ft.named_parameters():
            if param.requires_grad == True:
                print("\t",name)
    
    # Create the optimizer over the collected parameters.
    optimizer_ft = torch.optim.SGD(params_to_update, lr=0.001, momentum=0.9)
    
    # Set up the loss function
    criterion = nn.CrossEntropyLoss()
    
    # Train and evaluate
    model_ft, hist = train_model(model_ft, 
                                 dataloaders_dict, 
                                 criterion, 
                                 optimizer_ft, 
                                 num_epochs=opt.epochs, 
                                 is_inception=(opt.model=="inception"))
    torch.save(model_ft, 'custom_model.pt')
    ## Convert to a TorchScript module via tracing
    # Move the model to the CPU and switch to eval mode before tracing, so the
    # traced module can be loaded by the CPU build of LibTorch.
    model_ft = model_ft.to('cpu').eval()
    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(1, 3, input_size, input_size)

    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model_ft, example)
    traced_script_module.save('custom_traced_model.pt')
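
The script can then be invoked from the command line, for example (the filename train_classifier.py is a placeholder for whatever the file above is saved as):

python train_classifier.py --model alexnet --data D:\Datasets\flower_photos --epochs 10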

2. Converting to a torch.jit.ScriptModule

This step is already performed at the end of the script above, but it deserves emphasis. For details on converting a model to Torch Script, see the official tutorial: https://pytorch.org/tutorials/advanced/cpp_export.html#a-minimal-c-application
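
Before moving to C++, it is worth reloading the traced module in Python and checking that it reproduces the eager model's outputs. A minimal sketch, assuming the file names from the script above and a 224-input model (use 299 for inception):

import torch

# Reload both artifacts saved by the training script and compare outputs.
model = torch.load('custom_model.pt', map_location='cpu').eval()
traced = torch.jit.load('custom_traced_model.pt')
example = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(model(example), traced(example), atol=1e-5)
print('traced module matches the eager model')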

3. Writing the CMakeLists.txt file

The official PyTorch documentation is still thin here and does not cover loading images, and most online tutorials target Linux, where environment setup is simpler than on Windows. In particular, the OpenCV path must be specified manually in CMakeLists.txt. The complete file:

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

set(OpenCV_DIR "E:\\ScientificComputing\\opencv-4.2.0\\build\\install")

find_package(Torch REQUIRED)        # locate LibTorch
find_package(OpenCV REQUIRED)       # locate OpenCV

if(NOT Torch_FOUND)
    message(FATAL_ERROR "Pytorch Not Found!")
endif(NOT Torch_FOUND)

message(STATUS "Pytorch status:")
message(STATUS "    libraries: ${TORCH_LIBRARIES}")

message(STATUS "OpenCV library status:")
message(STATUS "    version: ${OpenCV_VERSION}")
message(STATUS "    libraries: ${OpenCV_LIBS}")
message(STATUS "    include path: ${OpenCV_INCLUDE_DIRS}")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
if (MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET example-app
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${TORCH_DLLS}
                     $<TARGET_FILE_DIR:example-app>)
endif (MSVC)

4. Writing example-app.cpp

The program does two things: 1) loads the custom script module; 2) runs inference on an image and prints the index of the highest-scoring class.

#include <torch/script.h> // One-stop header.
#include <ATen/ATen.h>
#include <iostream>
#include <memory>
#include <opencv2/opencv.hpp>
#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/highgui.hpp>

using namespace std;
using namespace cv;

int main(int argc, const char *argv[])
{
	if (argc != 3)  // Here we need 2 arguments
	{
		std::cerr << "usage: example-app <image-path> <path-to-exported-script-module>\n";
		return -1;
	}

	torch::jit::script::Module module;
	try
	{
		// Deserialize the ScriptModule from a file using torch::jit::load().
		module = torch::jit::load(argv[2]);
	}
	catch (const c10::Error &e)
	{
		std::cerr << "error loading the model\n";
		return -1;
	}
	std::cout << "Loading model succesfully...\n";
	//杈撳叆鍥惧儚
    auto image = cv::imread(argv[1], cv::ImreadModes::IMREAD_COLOR);
    if (image.empty())
    {
        std::cerr << "error loading the image\n";
        return -1;
    }
    cv::Mat image_transfomed;
    cv::resize(image, image_transfomed, cv::Size(224, 224));
    cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);

    // convert cv::Mat to at::Tensor (see https://pytorch.org/cppdocs/api/namespace_at.html#namespace-at)
    torch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols,3},torch::kByte);
    tensor_image = tensor_image.permute({2,0,1});
    tensor_image = tensor_image.toType(torch::kFloat);
    tensor_image = tensor_image.div(255);
    // Apply the same ImageNet mean/std normalization as the Python training
    // pipeline; without it the C++ predictions will not match Python's.
    tensor_image = tensor_image.unsqueeze(0);
    tensor_image[0][0].sub_(0.485).div_(0.229);
    tensor_image[0][1].sub_(0.456).div_(0.224);
    tensor_image[0][2].sub_(0.406).div_(0.225);
	// Execute the model and turn its output into a tensor.
	// at::Tensor output = module.forward(inputs).toTensor();
	at::Tensor output = module.forward({tensor_image}).toTensor();
	// cout << "output:" << output << endl;
	// std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
	
	auto max_result = output.max(1, true);
    // std::get<1> holds the argmax indices (a kLong tensor)
    auto max_index = std::get<1>(max_result).item<int64_t>();

	cout << "max index predicted: " << max_index << endl;
}
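
The program prints only the numeric index of the winning class. The index-to-name mapping is fixed by the sorted folder order that ImageFolder uses on the Python side; a small sketch for dumping it at training time (full_dataset is the ImageFolder instance from the script in section 1, and class_names.txt is a hypothetical output file):

# Save the index -> class-name mapping so the predicted index from the C++
# program can be translated back into a flower name (one name per line,
# written in index order).
with open('class_names.txt', 'w', encoding='utf-8') as f:
    for name in full_dataset.classes:
        f.write(name + '\n')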

5. Building and running

cd example-app
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH="F:\\libtorch" ..
cmake --build . --config Release

If everything went well, the compiled executable and a number of DLLs appear in the Release directory.
From the build directory, run:

./Release/example-app.exe "C:\Users\zhou-\Pictures\sunflower.jpg" custom_traced_model.pt

At this point the system may complain that some OpenCV DLLs cannot be found.
Placing the missing DLLs in the same directory as example-app.exe resolves this; alternatively, add a post-build copy step to CMakeLists.txt, as sketched below.
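A sketch of such a step, mirroring the TORCH_DLLS block already in CMakeLists.txt (the x64/vc16/bin subdirectory is an assumption; adjust the glob to wherever the OpenCV DLLs actually live in your install tree):

if (MSVC)
  file(GLOB OPENCV_DLLS "${OpenCV_DIR}/x64/vc16/bin/*.dll")
  add_custom_command(TARGET example-app
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${OPENCV_DLLS}
                     $<TARGET_FILE_DIR:example-app>)
endif (MSVC)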

Summary

  1. The current release of VS 2019 compiles and links LibTorch without problems.
  2. OpenCV and LibTorch should be built with the same toolchain (and the PyTorch/LibTorch versions should match) to avoid compatibility issues.
  3. CMakeLists.txt must be able to locate both the OpenCV and LibTorch libraries.
  4. If DLLs cannot be found at runtime, copy them manually into the executable's directory (in my case they were not found even after setting the system environment variables; I do not know why).
  5. MinGW currently does not support the Windows build of LibTorch. I had always used MinGW before, but for LibTorch only MSVC works (VS 2019 corresponds to VC16).