# 由于时常需要了解当前训练样本的统计信息，因此特意写了这个脚本，方便以后的工作。这里以 py-faster-rcnn 工程为例，数据格式为 VOC；代码中各函数的功能如其命名，一看便知。
import os
import sys

try:
    import xml.etree.cElementTree as ET
except ImportError:
    # cElementTree no longer exists on newer Pythons; use the pure module.
    import xml.etree.ElementTree as ET

import numpy as np
from matplotlib import pyplot as plt
# Root of the py-faster-rcnn checkout whose VOC-format data is analysed.
PyFasterRCNNPath = '/home/dq/py-faster-rcnn'
# Class list with the background class at index 0.  The script used
# CLASSES without defining it, which raised NameError on import.  This
# default is the VOC 2007 class list; replace it with your dataset's
# classes (same order as used for training).
CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')
ClsNameSet = CLASSES[1:]  # object classes only (background removed)
# Boxes whose area is below BoxAreaTol are counted as "small"
# by CalSmallAreaSampleNum.
BoxLenTol = 30
BoxAreaTol = BoxLenTol ** 2
# NOTE(review): ImSize, fileIdLen and ImExpName are not read by any
# function in this script; kept for compatibility.
ImSize = [960, 540]
fileIdLen = 6
ImExpName = '.jpg'
AnotExpName = '.xml'  # extension of the annotation files
VOCPath = os.path.join(PyFasterRCNNPath, 'data/VOCdevkit2007/VOC2007')
AnotFolder = os.path.join(VOCPath, 'Annotations')
TrainValTestAssignFolder = os.path.join(VOCPath, 'ImageSets/Main')
# Split name -> file (inside TrainValTestAssignFolder) listing image ids.
TrainValTestFiles = {'train': 'train.txt', 'val': 'val.txt', 'test': 'test.txt'}
def GetAnnotBoxLoc(AnotPath):
    """Read every object's bounding box from one VOC-style XML annotation.

    Args:
        AnotPath: path (or readable file object) of the XML annotation.

    Returns:
        dict mapping object name -> list of [x1, y1, x2, y2] boxes, in the
        order the objects appear in the file.  Every coordinate is the XML
        value minus 1.
    """
    root = ET.parse(AnotPath).getroot()
    ObjBndBoxSet = {}
    for Object in root.findall('object'):
        ObjName = Object.find('name').text
        BndBox = Object.find('bndbox')
        # Some labelling tools write coordinates as "12.0"; int("12.0")
        # raises ValueError, so convert through float first.
        BndBoxLoc = [int(float(BndBox.find(Tag).text)) - 1
                     for Tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        # setdefault replaces the Python-2-only dict.has_key() check.
        ObjBndBoxSet.setdefault(ObjName, []).append(BndBoxLoc)
    return ObjBndBoxSet
def CalSampleNum(BoxSet, BoxNumSet):
    """Add one image's per-class box counts into BoxNumSet (in place).

    Args:
        BoxSet: dict class name -> list of boxes (as from GetAnnotBoxLoc).
        BoxNumSet: dict class name -> running box count.  Only classes that
            are already keys are updated; other classes are ignored.
    """
    for Key, Val in BoxSet.items():
        if Key in BoxNumSet:
            BoxNumSet[Key] += len(Val)
def CalSmallAreaSampleNum(BoxSet, SmallBoxNumSet, AreaTol=None):
    """Count, per class, the boxes whose pixel area is below a threshold.

    Args:
        BoxSet: dict class name -> list of [x1, y1, x2, y2] boxes whose
            corners are inclusive pixel indices (as from GetAnnotBoxLoc).
        SmallBoxNumSet: dict class name -> running count, updated in place.
            Classes that are not already keys are ignored.
        AreaTol: area threshold in pixels; None means use the module-level
            BoxAreaTol.
    """
    if AreaTol is None:
        AreaTol = BoxAreaTol
    for Key, Val in BoxSet.items():
        if Key not in SmallBoxNumSet:
            continue
        for Box in Val:
            # +1 on each side because x2/y2 are inclusive; this matches the
            # width/height that CalSampleWHLen records for the same box.
            BoxW = Box[2] - Box[0] + 1
            BoxH = Box[3] - Box[1] + 1
            if BoxW * BoxH < AreaTol:
                SmallBoxNumSet[Key] += 1
def CalSampleWHLen(BoxSet, SampleWHLenSet):
    """Append [width, height] of every box to SampleWHLenSet (in place).

    Width and height are x2 - x1 + 1 and y2 - y1 + 1, i.e. the corners are
    treated as inclusive pixel indices.

    Args:
        BoxSet: dict class name -> list of [x1, y1, x2, y2] boxes.
        SampleWHLenSet: dict class name -> list of [width, height] pairs.
            Only classes that are already keys are extended.
    """
    for Key, Val in BoxSet.items():
        if Key in SampleWHLenSet:
            SampleWHLenSet[Key].extend(
                [Box[2] - Box[0] + 1, Box[3] - Box[1] + 1] for Box in Val)
def GetSampleAreaMean(SampleWHLenSet):
    """Return the mean box area per class, printing each class's box count.

    Args:
        SampleWHLenSet: dict class name -> list of [width, height] pairs.

    Returns:
        dict class name -> mean area (float); 0.0 for a class with no boxes.
    """
    SampleAreaMean = {}
    for Key, BoxWHSet in SampleWHLenSet.items():
        BoxNum = len(BoxWHSet)
        TotalArea = sum(BoxWH[0] * BoxWH[1] for BoxWH in BoxWHSet)
        # Guard the empty class explicitly; adding an epsilon to the divisor
        # would bias every mean slightly low.
        SampleAreaMean[Key] = TotalArea * 1.0 / BoxNum if BoxNum else 0.0
        print('{}BoxNum={}'.format(Key, BoxNum))
    return SampleAreaMean
def _PlotHist(NumRows, SubPlotIdx, Data, BinNum, XLabelStr):
    """Draw one histogram in cell SubPlotIdx of a NumRows x 3 subplot grid."""
    plt.subplot(NumRows, 3, SubPlotIdx)
    plt.hist(Data, BinNum, alpha=0.5)
    plt.xlabel(XLabelStr)
    plt.ylabel("Count")


def PlotSampleWHLenHist(SampleWHLenSet, BinNum=20):
    """Plot histograms of box width, height and height/width ratio.

    The figure has one row per non-empty class plus a final row aggregating
    all classes.  Each row shows three panels: width, height and the H/W
    ratio.  Does nothing when no class has any box.

    Args:
        SampleWHLenSet: dict class name -> list of [width, height] pairs
            (as filled by CalSampleWHLen).
        BinNum: number of histogram bins per panel.
    """
    # A class without boxes cannot be sliced as an (N, 2) array; skip it.
    ClassWHSet = [(Key, np.asarray(Val).reshape(-1, 2))
                  for Key, Val in sorted(SampleWHLenSet.items())
                  if len(Val) > 0]
    if not ClassWHSet:
        return

    # Size the grid to the data: a fixed 3x3 grid only has room for two
    # classes plus the aggregate row.
    Panels = [("{} Box".format(Key), WH) for Key, WH in ClassWHSet]
    Panels.append(("All Box", np.vstack([WH for _, WH in ClassWHSet])))
    NumRows = len(Panels)

    plt.figure(1)
    plt.suptitle('Sample')
    for Row, (Prefix, WH) in enumerate(Panels):
        WidthLen = WH[:, 0]
        HeigthLen = WH[:, 1]
        HDivW = HeigthLen.astype('float32') / WidthLen.astype('float32')
        Base = Row * 3
        _PlotHist(NumRows, Base + 1, WidthLen, BinNum, Prefix + " Width Len")
        _PlotHist(NumRows, Base + 2, HeigthLen, BinNum, Prefix + " Heigth Len")
        # The plotted value is height / width, so label it HDivW.
        _PlotHist(NumRows, Base + 3, HDivW, BinNum, Prefix + " HDivW")
    plt.show()
##################main##########################
def GetTotalSampleInfoMain():
    """Print per-class box statistics over every annotation in AnotFolder.

    The printed figures are:
      - ImNum: the number of annotation files;
      - total boxes per class;
      - mean boxes per class per image;
      - the number of boxes below / not below the BoxAreaTol area threshold.
    Finally it calls GetSampleAreaMean, which prints each class's box count.
    Returns None.
    """
    TotalSampleNum = dict.fromkeys(ClsNameSet, 0)
    SmallBoxNumSet = dict.fromkeys(ClsNameSet, 0)
    # One fresh list per class; dict.fromkeys(ClsNameSet, []) would make
    # every class share the same list object.
    SampleWHLenSet = {ClassName: [] for ClassName in ClsNameSet}

    # Only annotation files count as images; any other file in the folder
    # would break the XML parser and inflate ImNum.  Sorted so the
    # processing order is reproducible.
    AnotFileSet = sorted(Name for Name in os.listdir(AnotFolder)
                         if Name.endswith(AnotExpName))
    AnotFileNum = len(AnotFileSet)
    for AnotName in AnotFileSet:
        AnotBoxSet = GetAnnotBoxLoc(os.path.join(AnotFolder, AnotName))
        CalSampleNum(AnotBoxSet, TotalSampleNum)
        CalSmallAreaSampleNum(AnotBoxSet, SmallBoxNumSet)
        CalSampleWHLen(AnotBoxSet, SampleWHLenSet)

    # Mean boxes per image; an empty folder gives 0 instead of raising
    # ZeroDivisionError.
    MeanSampleNum = {
        Key: (round(Val * 1.0 / AnotFileNum, 2) if AnotFileNum else 0)
        for Key, Val in TotalSampleNum.items()}
    BigAreaSampleNum = {Key: Val - SmallBoxNumSet[Key]
                        for Key, Val in TotalSampleNum.items()}

    print('ImNum=' + str(AnotFileNum))
    print('TotalSampleNum=' + str(TotalSampleNum))
    print('MeanSampleNum=' + str(MeanSampleNum))
    print('BoxAreaTol=' + str(BoxLenTol) + '*' + str(BoxLenTol))
    print('SmallAreaSampleNum=' + str(SmallBoxNumSet))
    print('BigAreaSampleNum=' + str(BigAreaSampleNum))
    # Optional: PlotSampleWHLenHist(SampleWHLenSet) shows size histograms.
    # Called for its per-class "BoxNum" printout; the mean areas it returns
    # are not used here.
    GetSampleAreaMean(SampleWHLenSet)
def GetTrainValTestSampleInfo(SampleNumSet, ImIdFilePath):
    """Accumulate per-class box counts for the images listed in a split file.

    Args:
        SampleNumSet: dict class name -> running box count, updated in place
            via CalSampleNum.
        ImIdFilePath: text file with one image id per line; each id is
            resolved to AnotFolder/<id><AnotExpName>.

    Returns:
        Number of image ids processed.  A "<file> ImageNum=<n>; " summary
        is written to stdout without a trailing newline.
    """
    ImNum = 0
    with open(ImIdFilePath, 'r') as FId:
        for LineStr in FId:
            ImId = LineStr.strip()
            # A blank line (e.g. a trailing empty line) is not an image id;
            # it would otherwise resolve to a non-existent ".xml" file.
            if not ImId:
                continue
            AnotFilePath = os.path.join(AnotFolder, ImId + AnotExpName)
            CalSampleNum(GetAnnotBoxLoc(AnotFilePath), SampleNumSet)
            ImNum += 1
    FileName = os.path.basename(ImIdFilePath)
    # Stay on the same line so the caller's next print follows this summary.
    sys.stdout.write(FileName + ' ImageNum=' + str(ImNum) + '; ')
    return ImNum
def GetTrainValTestSampleInfoMain():
    """Print per-class box counts for every split in TrainValTestFiles.

    Returns:
        dict split key -> {'ImNum': number of images in the split,
                           'SampleNumSet': dict class name -> box count}.
    """
    print('\n')
    TrainValTestSampleInfoSet = {}
    for Key, FileName in TrainValTestFiles.items():
        ImIdFilePath = os.path.join(TrainValTestAssignFolder, FileName)
        SampleNumSet = dict.fromkeys(ClsNameSet, 0)
        ImNum = GetTrainValTestSampleInfo(SampleNumSet, ImIdFilePath)
        # splitext drops the extension whatever its length.
        SplitName = os.path.splitext(FileName)[0]
        print(SplitName + 'SampleNumSet=' + str(SampleNumSet))
        print('\n')
        TrainValTestSampleInfoSet[Key] = {'ImNum': ImNum,
                                          'SampleNumSet': SampleNumSet}
    return TrainValTestSampleInfoSet
if __name__ == "__main__":
    # Statistics over every annotation file in AnotFolder.
    GetTotalSampleInfoMain()
    # Per-split statistics for the id lists in TrainValTestFiles; the
    # returned summary dict is not used further in this script.
    TrainValTestSampleInfoSet = GetTrainValTestSampleInfoMain()