修改 1：新增逐张图片预测的函数（每张图片输出一个预测结果），并修改输入与输出
代码位置: PaddleClas-release-2.5\tools\infer.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
from ppcls.utils import config
from ppcls.engine.engine import Engine
if __name__ == "__main__":
    # Parse "-c <config.yaml> -o key=value ..." and build the inference engine.
    args = config.parse_args()
    # Use a separate name so the imported `config` module is not shadowed.
    cfg = config.get_config(args.config, overrides=args.override, show=False)
    engine = Engine(cfg, mode="infer")
    # engine.infer()  # stock behaviour: predict Infer.infer_imgs from the config
    # Custom per-image prediction added in ppcls/engine/engine.py.
    engine.infer_one()
    # NOTE: the returned classification result is customised in
    # F:\syy\code\PaddleClas\ppcls\data\postprocess\threshoutput.py -- check it.
    # Example:
    # python3 tools/infer.py \
    #     -c ./ppcls/configs/PPLC_breast.yaml \
    #     -o Global.pretrained_model=output/breast_2class/3label_best/best_model
修改 2：增加代码功能
位置：
PaddleClas\ppcls\engine\engine.py
在 Engine 类中默认的 infer 函数后面，再加入新的 infer_one 函数
@paddle.no_grad()
def infer(self):  # default infer (stock PaddleClas implementation, omitted here)
    ...


@paddle.no_grad()
def infer_one(self):  # new: one prediction (binary classification result) per image
    ...
位置：
PaddleClas\ppcls\engine\engine.py
在 Engine 类中默认的 infer 函数后面，再加入新的 infer_one 函数
@paddle.no_grad()
def infer(self):
    """Stock PaddleClas ``infer`` (the default); its body is omitted in this note."""
    ...
@paddle.no_grad()
def infer_one(self, img_dir=None, label_file=None, output_file=None):
    r"""Predict every image listed in a label file, one result line per image.

    Unlike the stock ``infer`` (which predicts ``Infer.infer_imgs`` from the
    config), this reads ``<relative_image_path> <label>`` lines from
    ``label_file`` and writes ``<label> <pred> <score of pred>`` for every
    image to ``output_file``, so predictions can be compared with the ground
    truth afterwards.

    Example:
        python3 tools/infer.py \
            -c ./ppcls/configs/PULC/hebin_breast/breast_two_class.yaml \
            -o Global.pretrained_model=./output/maskadd15_train3/PPLCNet_x1_0/best_model
        (other checkpoints used with it: mask_train2, maskadd5_train2)

    Args:
        img_dir: Directory joined with the relative path taken from each
            line of ``label_file``. Defaults to the mask_merge_crop15 folder.
        label_file: Text file with one whitespace-separated
            ``<relative_image_path> <label>`` pair per line; blank lines are
            skipped. Defaults to mask_merge_crop5/test3.txt.
        output_file: Path the predictions are written to (overwritten). When
            more than one process runs, each rank writes
            ``<output_file>.rank<N>`` so ranks do not overwrite each other.

    Raises:
        ValueError: if a non-blank line of ``label_file`` has no label field.
    """
    import os

    assert self.mode == "infer" and self.eval_mode == "classification"

    # Defaults reproduce the previously hard-coded experiment paths.
    if img_dir is None:
        img_dir = "/home/syy/data/school_breast_class/mask_merge_crop15/"
    if label_file is None:
        # NOTE(review): the list comes from mask_merge_crop5 while images are
        # read from mask_merge_crop15 -- confirm both folders share the same
        # relative file names.
        label_file = "/home/syy/data/school_breast_class/mask_merge_crop5/test3.txt"
    if output_file is None:
        output_file = "/home/syy/code/PaddleClas/school_pre/add15_pred3.txt"

    total_trainer = dist.get_world_size()
    local_rank = dist.get_rank()
    if total_trainer > 1:
        # Every rank opens its output with mode 'w'; give each its own file.
        output_file = "{}.rank{}".format(output_file, local_rank)

    # Read (image path, ground-truth label) pairs; the file is closed afterwards.
    image_list = []
    label_list = []
    with open(label_file, encoding="utf8") as label_f:
        for line_no, line in enumerate(label_f, 1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing lines
            if len(fields) < 2:
                raise ValueError(
                    "{}:{}: expected '<image_path> <label>', got {!r}".format(
                        label_file, line_no, line))
            image_list.append(os.path.join(img_dir, fields[0]))
            label_list.append(fields[1])

    # Data split: each rank takes every total_trainer-th sample.
    image_list = image_list[local_rank::total_trainer]
    label_list = label_list[local_rank::total_trainer]
    assert len(label_list) == len(image_list)
    print(len(label_list), len(image_list))

    batch_size = self.config["Infer"]["batch_size"]
    self.model.eval()
    batch_data = []
    image_file_list = []
    batch_labels = []  # ground-truth labels aligned with image_file_list
    last_idx = len(image_list) - 1

    with open(output_file, encoding="utf8", mode="w") as fw:
        print("=============》 建立文件")
        for idx, (image_file, label) in enumerate(zip(image_list, label_list)):
            print(idx)
            with open(image_file, "rb") as img_f:
                x = img_f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            batch_labels.append(label)

            # Run the model once the batch is full or the last image is reached.
            if len(batch_data) >= batch_size or idx == last_idx:
                batch_tensor = paddle.to_tensor(batch_data)
                if self.amp and self.amp_eval:
                    with paddle.amp.auto_cast(
                            custom_black_list={
                                "flatten_contiguous_range", "greater_than"
                            },
                            level=self.amp_level):
                        out = self.model(batch_tensor)
                else:
                    out = self.model(batch_tensor)

                # Unwrap the model output formats the stock infer() handles.
                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "Student" in out:
                    out = out["Student"]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]

                # Presumably one dict per image, in input order, e.g.
                # {'scores': [...], 'pred': ..., 'file_name': ...}; "pred" is
                # assumed to come from the customised threshoutput.py -- confirm.
                result = self.postprocess_func(out, image_file_list)
                print("============>长度", len(result))
                if len(result) != len(batch_labels):
                    raise RuntimeError(
                        "postprocess returned {} results for {} images".format(
                            len(result), len(batch_labels)))

                # Pair each prediction with the label of ITS image (the old
                # label_list[idx] used the last image's label for the whole batch).
                for gt_label, pred in zip(batch_labels, result):
                    pred_label = pred["pred"]
                    score = pred["scores"][pred_label]
                    print(gt_label, pred_label, score)
                    fw.write("{!s} {!s} {!s}\n".format(gt_label, pred_label, score))

                batch_data.clear()
                image_file_list.clear()
                batch_labels.clear()