ppcls,飞桨分类改写增加自己的预测txt 格式 python3 tools/infer.py

标题修改1,增加一张图片预测一张的函数输出输入修改


代码位置: PaddleClas-release-2.5\tools\infer.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

from ppcls.utils import config
from ppcls.engine.engine import Engine

if __name__ == "__main__":
    args = config.parse_args()
    config = config.get_config(
        args.config, overrides=args.override, show=False)
    engine = Engine(config, mode="infer")
    # engine.infer()
    engine.infer_one()


    # F:\syy\code\PaddleClas\ppcls\data\postprocess\threshoutput.py  返回的分类结果在这里改,请注意

# python3 tools/infer.py \
#     -c ./ppcls/configs/PPLC_breast.yaml \
#     -o Global.pretrained_model=output/breast_2class/3label_best/best_model

修改2,增加代码功能

位置
PaddleClas\ppcls\engine\engine.py
在默认的函数后面再加入新的infer_one
@paddle.no_grad()
def infer_one(self): # 默认的 infer


@paddle.no_grad()
def infer_one(self): # 一个预测一个二分类结果

位置
PaddleClas\ppcls\engine\engine.py
在默认的函数后面再加入新的infer_one
@paddle.no_grad()
    def infer(self):  # 默认的  infer
			...
			...
			
@paddle.no_grad()
    def infer_one(self):  # 一个预测一个二分类结果

######################################
@paddle.no_grad()
    def infer_one(self):  # 一个预测一个二分类结果
        assert self.mode == "infer" and self.eval_mode == "classification"
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        # 重新建立一个 image_list
        #image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # add5_pred1.txt
        '''
        python3 tools/infer.py \
           -c ./ppcls/configs/PULC/hebin_breast/breast_two_class.yaml \
           -o Global.pretrained_model=./output/maskadd15_train3/PPLCNet_x1_0/best_model
        mask_train2  maskadd5_train2
        '''
        
        # 输出名字改动
        fw = open('/home/syy/code/PaddleClas/school_pre/add15_pred3.txt',encoding="utf8",mode='w')
        print("=============》 建立文件")
        
        
        # 输入 txt 修改,图片路径修改
        imgdir = "/home/syy/data/school_breast_class/mask_merge_crop15/"
        # imgdir = "/home/syy/data/school_breast_class/mask_merge_crop5/"
        
        f=open("/home/syy/data/school_breast_class/mask_merge_crop5/test3.txt")
        # f = open("/home/syy/data/school_breast_class/mask_crop/test2.txt")

        
        ll = f.readlines()
        image_list=[]
        label_list=[]
        for n in ll:
            image_list.append(imgdir +n.strip().split(" ")[0])
            label_list.append(n.strip().split(" ")[1])
        # ['/home/syy/data/school_breast_class/mask_crop/test/0/003_0_0.jpg',
        #  '/home/syy/data/school_breast_class/mask_crop/test/0/006_0_0.jpg',]


        # data split
        image_list = image_list[local_rank::total_trainer]
        label_list = label_list[local_rank::total_trainer]
        print(label_list)
        assert(len(label_list)==len(image_list))
        print(len(label_list),len(image_list))

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []

        for idx, image_file in enumerate(image_list):
            print(idx)
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
                
            "/home/syy/data/school_breast_class/mask_crop/test/0/003_0_0.jpg"
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)

                if self.amp and self.amp_eval:
                    with paddle.amp.auto_cast(
                            custom_black_list={
                                "flatten_contiguous_range", "greater_than"
                            },
                            level=self.amp_level):
                        out = self.model(batch_tensor)
                else:
                    out = self.model(batch_tensor)

                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "Student" in out:
                    out = out["Student"]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]
      
                result = self.postprocess_func(out, image_file_list)
                # print(result)  #多分类  输出是排序好的从大到小
                # [{'class_ids': [0, 4, 3, 2, 1], 'scores': [0.43975, 0.34794, 0.10294, 0.0656, 0.04377], 'file_name':

                # 保存score  图片名字
                # print(result)
                # [{'scores': [0.8015347], 'file_name': '/home/syy/data/school_breast_class/mask_crop/test/0/003_0_0.jpg'}
                print("============>长度",len(result))
                for i,line in enumerate(result):
                    predlabel = line["pred"]
                    print(label_list[idx], predlabel, line['scores'][line["pred"]])
                    fw.write(
                        str(label_list[idx]) + " " + str(predlabel) + " " + str(line['scores'][line["pred"]]) + "\n")

                batch_data.clear()
                image_file_list.clear()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值