https://github.com/open-mmlab/mmclassification
目前官方还未提供推理(infer)代码,所以自己写了一个,但运行时报错:'LinearClsHead' object has no attribute 'simple_test',需要改动源码。
在 linear_head.py 的 LinearClsHead 类中增加一个方法即可:
def simple_test(self, x):
    """Return the raw classification scores for the input features.

    Intended to be added as a method of ``LinearClsHead`` so that the
    classifier can be called in inference mode.

    Args:
        x: Features to feed to the head's fully connected layer ``self.fc``.

    Returns:
        The output of ``self.fc(x)``; no softmax or other post-processing
        is applied.
    """
    cls_score = self.fc(x)
    return cls_score
然后编写前向推理脚本 infer.py:
# -*- coding: utf-8 -*-
# @Time : 2020/10/3 8:35 PM
# @Author : zxq
# @File : model_infer.py
# @Software: PyCharm
import os
import cv2
import mmcv
import numpy as np
import torch
from mmcv import Config
from mmcls.models import build_classifier
if __name__ == '__main__':
    # Stand-alone inference script: build the classifier described by the
    # config, load trained weights, and print the predicted class index for
    # every image found in ``data_path``.
    cfg = Config.fromfile('../../configs/imagenet/ciga_call_cfg.py')
    data_path = '/home/zxq/PycharmProjects/data/ciga_call/test2'
    weight_path = '../../work_dir/epoch_100.pth'

    model = build_classifier(cfg.model)
    model.eval()

    # The directory is created next to the test data prefix; nothing is
    # written into it by this script.
    save_path = os.path.join(os.path.dirname(cfg.data.test.data_prefix), 'test_result')
    mmcv.mkdir_or_exist(save_path)

    # Reuse the normalization statistics declared in the test pipeline.
    mean_value = None
    std_value = None
    to_rgb = True  # same default as mmcv.imnormalize
    for step_ in cfg.test_pipeline:
        # Compare by value: `is` on string literals only works by accident
        # of interning.
        if step_['type'] == 'Normalize':
            mean_value = np.array(step_['mean'])
            std_value = np.array(step_['std'])
            to_rgb = step_.get('to_rgb', True)
    if mean_value is None or std_value is None:
        raise ValueError("No 'Normalize' step found in cfg.test_pipeline")

    # Load the checkpoint once instead of re-reading it for every image.
    checkpoint = torch.load(weight_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

    # Sorted so the output order is deterministic across runs.
    img_name_list = sorted(os.listdir(data_path))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for img_name in img_name_list:
            img_dir = os.path.join(data_path, img_name)
            print(img_dir)
            img = cv2.imread(img_dir)
            if img is None:
                # cv2.imread returns None for anything it cannot decode
                # (sub-directories, non-image files, corrupt images).
                print('skipped, not a readable image: ' + img_dir)
                continue
            # 1, resize
            # NOTE(review): the size is hard-coded; confirm it matches the
            # resize/crop steps of cfg.test_pipeline.
            img_resized = mmcv.imresize(img, (256, 256))
            # 2, normalize
            img_normalized = mmcv.imnormalize(img_resized, mean_value, std_value, to_rgb)
            # 3, HWC -> CHW and convert to tensor
            input_data = torch.Tensor(np.transpose(img_normalized, [2, 0, 1]))
            # 4, add batch dim
            batch_data = torch.unsqueeze(input_data, 0)
            # 5, infer
            model_output = model(batch_data, return_loss=False).numpy()
            cls_output = np.argmax(model_output, axis=1)
            print(cls_output)