给图片分类标注的辅助程序

给图片分类标注的辅助程序

最近需要制作数据集,但感觉网上的标签软件功能太多反而用不上,有的用不上的功能还会耽搁本身的步骤

于是用Python,调用easygui写了一个简单的标注工具,可以进行简单的分类标记。也可以修改上一个标错了的选项。

分享出来互相借鉴,我为人人人人为我

只需要修改_find_img()函数,也就是寻找你的图片路径和marking()函数中csv文件访问代码(改成你的数据存储格式)就可以使用。

import os
import pandas as pd
import easygui as G
import cv2
import matplotlib.pyplot as plt

#####################################################################
#确定标记的文件名
outputfile = '1.csv'
#要标记的种类填入
mark_level = '猫狗分类'
#可视化显示标记进度
visualization = True
##########################################################

#这这可以添加多个分类以供选择

markdic1 = {'未知':0,
            '猫':1,
            '狗':2
            }
markdic2 = {'未知':0,
            '持续躺':1,
            '持续站':2
            }

class mark:
    def __init__(self,out_file_name,mark_level):
        if mark_level == '猫狗分类':
            self.markdic = markdic1
        if mark_level == '姿态':
            self.markdic = markdic2
        self.out_file_name = out_file_name
        self.mark_level = mark_level
        print('准备开始标记 '+mark_level)

    def marking(self):
        self._read_csv()
        for idx,state in enumerate(self.csv_data.loc[:,6]):
            if not state % 100 >= 0:
                imgpath = self._find_img(self.csv_data.loc[idx,0],self.csv_data.loc[idx,1])
                self._creat_temp_img(imgpath)
                ans = G.buttonbox(self._get_progress(idx),
                                  title=self.mark_level+'   '+imgpath,
                                  image='temp.jpg',
                                  choices= list(self.markdic.keys())+['上个错了'])
                if ans == '上个错了':
                    modify = G.buttonbox('上一个该是什么?',
                                      title=self.mark_level + '   修改错误',
                                      choices=list(self.markdic.keys()))
                    self.csv_data.loc[idx-1, 6] = self.markdic[modify]
                    print('修改上一个为 {}'.format(modify),end=',并且修改的')
                    self._save_csv()
                    ans = G.buttonbox(self._get_progress(idx),
                                      title=self.mark_level + '   ' + imgpath,
                                      image='temp.jpg',
                                      choices=list(self.markdic.keys()))
                if ans in self.markdic.keys():
                    print('将',imgpath,
                          '定义为:',ans)
                    self.csv_data.loc[idx, 6] = self.markdic[ans]
                    if idx%10 == 9:
                        self._save_csv()
                else:
                    break
        self._save_csv()
        os.remove('temp.jpg')

    def report(self,vis):
        newdic1 = dict(zip(self.markdic.values(),self.markdic.keys()))
        newdic2 = dict([(v,0) for (_,v) in self.markdic.items()])
        for temp in self.csv_data.loc[:,6]:
            if temp in list(newdic2.keys()):
                newdic2[temp] += 1
        newdic3 = {}
        for temp in newdic2.keys():
            if not vis:
                print(newdic1[temp],':',newdic2[temp])
            newdic3[newdic1[temp]] = newdic2[temp]
        if vis:
            self._visualization(newdic3)

    def _read_csv(self):
        self.csv_data = pd.read_csv(self.out_file_name,header=None)
        self.csv_data_len = len(self.csv_data)

    def _visualization(self,dic):
        print('可视化标注进度')
        val = list(k for (v,k) in dic.items())
        sca = range(len(dic))
        ind = list(v for (v,k) in dic.items())
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.ylim(0,1000)
        plt.bar(sca, val)
        plt.xticks(sca, ind)
        plt.title('各标签情况')
        plt.show()

    def _save_csv(self):
        self.csv_data.to_csv(self.out_file_name, mode='w', index=False, header=False)
        print('csv文件已保存')

    def _find_img(self,vid,img):
        return os.path.join('dataset',str(vid).zfill(6),str(vid).zfill(6)+'_'+str(img).zfill(6)+'.jpg')

    def _get_progress(self,idx):
        progress = '进度:========================================'.replace('=','>',int(40*(idx/self.csv_data_len)))
        return progress+'  {}/{}'.format(idx,self.csv_data_len)

    def _creat_temp_img(self,imgname):
        img = cv2.imread(imgname)
        img = cv2.resize(img,(int(384*2),int(216*2)))
        cv2.imwrite('temp.jpg',img)


if __name__ == '__main__':
    Mark = mark(out_file_name=outputfile,mark_level=mark_level)
    Mark.marking()
    Mark.report(vis=visualization)

需要的模块主要是easygui和cv2(这是opencv-python)模块,简单地pip安装提示需要安装的模块就行

附上效果图,因为数据集还没发布所以打码了
因为数据集还没发布所以打码了
以及标记数据的统计,因为数据集还没发布所以也打码了
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值