好不容易搞通的,记录一下以免忘记
1.准备工作
1.1 libtorch配置
类似配置opencv,添加包含目录和库目录
1.2 pt文件准备
本次测试的是基于vggnet的猫狗识别。
1.2.1 pth_to_pt
import torch
import torchvision.models as models
import os
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载模型
path_state_dict = os.path.join("..","data", "VGG_16_weight.pth")
model = models.vgg16()
num_ftrs = model.classifier._modules["6"].in_features
model.classifier._modules["6"] = nn.Linear(num_ftrs, 2)
model.to(device)
pretrained_state_dict