给图片分类标注的辅助程序
最近需要制作数据集,但感觉网上的标签软件功能太多反而用不上,有的用不上的功能还会耽搁本身的步骤
于是用Python,调用easygui写了一个简单的标注工具,可以进行简单的分类标记。也可以修改上一个标错了的选项。
分享出来互相借鉴,我为人人人人为我
只需要修改_find_img()函数,也就是寻找你的图片路径和marking()函数中csv文件访问代码(改成你的数据存储格式)就可以使用。
import os
import pandas as pd
import easygui as G
import cv2
import matplotlib.pyplot as plt
#####################################################################
#确定标记的文件名
outputfile = '1.csv'
#要标记的种类填入
mark_level = '猫狗分类'
#可视化显示标记进度
visualization = True
##########################################################
#这这可以添加多个分类以供选择
markdic1 = {'未知':0,
'猫':1,
'狗':2
}
markdic2 = {'未知':0,
'持续躺':1,
'持续站':2
}
class mark:
def __init__(self,out_file_name,mark_level):
if mark_level == '猫狗分类':
self.markdic = markdic1
if mark_level == '姿态':
self.markdic = markdic2
self.out_file_name = out_file_name
self.mark_level = mark_level
print('准备开始标记 '+mark_level)
def marking(self):
self._read_csv()
for idx,state in enumerate(self.csv_data.loc[:,6]):
if not state % 100 >= 0:
imgpath = self._find_img(self.csv_data.loc[idx,0],self.csv_data.loc[idx,1])
self._creat_temp_img(imgpath)
ans = G.buttonbox(self._get_progress(idx),
title=self.mark_level+' '+imgpath,
image='temp.jpg',
choices= list(self.markdic.keys())+['上个错了'])
if ans == '上个错了':
modify = G.buttonbox('上一个该是什么?',
title=self.mark_level + ' 修改错误',
choices=list(self.markdic.keys()))
self.csv_data.loc[idx-1, 6] = self.markdic[modify]
print('修改上一个为 {}'.format(modify),end=',并且修改的')
self._save_csv()
ans = G.buttonbox(self._get_progress(idx),
title=self.mark_level + ' ' + imgpath,
image='temp.jpg',
choices=list(self.markdic.keys()))
if ans in self.markdic.keys():
print('将',imgpath,
'定义为:',ans)
self.csv_data.loc[idx, 6] = self.markdic[ans]
if idx%10 == 9:
self._save_csv()
else:
break
self._save_csv()
os.remove('temp.jpg')
def report(self,vis):
newdic1 = dict(zip(self.markdic.values(),self.markdic.keys()))
newdic2 = dict([(v,0) for (_,v) in self.markdic.items()])
for temp in self.csv_data.loc[:,6]:
if temp in list(newdic2.keys()):
newdic2[temp] += 1
newdic3 = {}
for temp in newdic2.keys():
if not vis:
print(newdic1[temp],':',newdic2[temp])
newdic3[newdic1[temp]] = newdic2[temp]
if vis:
self._visualization(newdic3)
def _read_csv(self):
self.csv_data = pd.read_csv(self.out_file_name,header=None)
self.csv_data_len = len(self.csv_data)
def _visualization(self,dic):
print('可视化标注进度')
val = list(k for (v,k) in dic.items())
sca = range(len(dic))
ind = list(v for (v,k) in dic.items())
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.ylim(0,1000)
plt.bar(sca, val)
plt.xticks(sca, ind)
plt.title('各标签情况')
plt.show()
def _save_csv(self):
self.csv_data.to_csv(self.out_file_name, mode='w', index=False, header=False)
print('csv文件已保存')
def _find_img(self,vid,img):
return os.path.join('dataset',str(vid).zfill(6),str(vid).zfill(6)+'_'+str(img).zfill(6)+'.jpg')
def _get_progress(self,idx):
progress = '进度:========================================'.replace('=','>',int(40*(idx/self.csv_data_len)))
return progress+' {}/{}'.format(idx,self.csv_data_len)
def _creat_temp_img(self,imgname):
img = cv2.imread(imgname)
img = cv2.resize(img,(int(384*2),int(216*2)))
cv2.imwrite('temp.jpg',img)
if __name__ == '__main__':
Mark = mark(out_file_name=outputfile,mark_level=mark_level)
Mark.marking()
Mark.report(vis=visualization)
需要的模块主要是easygui和cv2(这是opencv-python)模块,简单地pip安装提示需要安装的模块就行
附上效果图,因为数据集还没发布所以打码了
以及标记数据的统计,因为数据集还没发布所以也打码了