win10 caffe python Faster-RCNN训练自己数据集(转)

一、制作数据集

1. 关于训练的图片

不论你是网上找的图片或者你用别人的数据集,记住一点你的图片不能太小,width和height最好不要小于150。需要是jpeg的图片。

2.制作xml文件

1)LabelImg

如果你的数据集比较小的话,你可以考虑用LabelImg手工打框https://github.com/tzutalin/labelImg。关于labelimg的具体使用方法我在这就不详细说明了,大家可以去网上找一下。labelimg生成的xml直接就能给frcnn训练使用。

2)自己制作xml

如果你的数据集比较小的话,你还可以考虑用上面的方法手工打框。如果你的数据集有1w+你就可以考虑自动生成xml文件。网上有些资料基本用的是matlab坐标生成xml。我给出一段python的生成xml的代码

[python] view plain copy
 
 
print?
  1. <span style="font-size:14px;">  
  2. def write_xml(bbox,w,h,iter):  
  3.     ''''' 
  4.     bbox为你保存的当前图片的类别的信息和对应坐标的dict 
  5.     w,h为你当前保存图片的width和height 
  6.     iter为你图片的序号 
  7.     '''  
  8.     root=Element("annotation")  
  9.     folder=SubElement(root,"folder")#1  
  10.     folder.text="JPEGImages"  
  11.     filename=SubElement(root,"filename")#1  
  12.     filename.text=iter  
  13.     path=SubElement(root,"path")#1  
  14.     path.text='D:\\py-faster-rcnn\\data\\VOCdevkit2007\\VOC2007\\JPEGImages'+'\\'+iter+'.jpg'#把这个路径改为你的路径就行  
  15.     source=SubElement(root,"source")#1  
  16.     database=SubElement(source,"database")#2  
  17.     database.text="Unknown"  
  18.     size=SubElement(root,"size")#1  
  19.     width=SubElement(size,"width")#2  
  20.     height=SubElement(size,"height")#2  
  21.     depth=SubElement(size,"depth")#2  
  22.     width.text=str(w)  
  23.     height.text=str(h)  
  24.     depth.text='3'  
  25.     segmented=SubElement(root,"segmented")#1  
  26.     segmented.text='0'  
  27.     for i in bbox:  
  28.         object=SubElement(root,"object")#1  
  29.         name=SubElement(object,"name")#2  
  30.         name.text=i['cls']  
  31.         pose=SubElement(object,"pose")#2  
  32.         pose.text="Unspecified"  
  33.         truncated=SubElement(object,"truncated")#2  
  34.         truncated.text='0'  
  35.         difficult=SubElement(object,"difficult")#2  
  36.         difficult.text='0'  
  37.         bndbox=SubElement(object,"bndbox")#2  
  38.         xmin=SubElement(bndbox,"xmin")#3  
  39.         ymin=SubElement(bndbox,"ymin")#3  
  40.         xmax=SubElement(bndbox,"xmax")#3  
  41.         ymax=SubElement(bndbox,"ymax")#3  
  42.         xmin.text=str(i['xmin'])  
  43.         ymin.text=str(i['ymin'])  
  44.         xmax.text=str(i['xmax'])  
  45.         ymax.text=str(i['ymax'])  
  46.     xml=tostring(root,pretty_print=True)  
  47.     file=open('D:/py-faster-rcnn/data/VOCdevkit2007/VOC2007/Annotations/'+iter+'.xml','w+')#这里的路径也改为你自己的路径  
  48.     file.write(xml)</span>  
def write_xml(bbox,w,h,iter):
    '''
    bbox为你保存的当前图片的类别的信息和对应坐标的dict
    w,h为你当前保存图片的width和height
    iter为你图片的序号
    '''
    root=Element("annotation")
    folder=SubElement(root,"folder")#1
    folder.text="JPEGImages"
    filename=SubElement(root,"filename")#1
    filename.text=iter
    path=SubElement(root,"path")#1
    path.text='D:\\py-faster-rcnn\\data\\VOCdevkit2007\\VOC2007\\JPEGImages'+'\\'+iter+'.jpg'#把这个路径改为你的路径就行
    source=SubElement(root,"source")#1
    database=SubElement(source,"database")#2
    database.text="Unknown"
    size=SubElement(root,"size")#1
    width=SubElement(size,"width")#2
    height=SubElement(size,"height")#2
    depth=SubElement(size,"depth")#2
    width.text=str(w)
    height.text=str(h)
    depth.text='3'
    segmented=SubElement(root,"segmented")#1
    segmented.text='0'
    for i in bbox:
        object=SubElement(root,"object")#1
        name=SubElement(object,"name")#2
        name.text=i['cls']
        pose=SubElement(object,"pose")#2
        pose.text="Unspecified"
        truncated=SubElement(object,"truncated")#2
        truncated.text='0'
        difficult=SubElement(object,"difficult")#2
        difficult.text='0'
        bndbox=SubElement(object,"bndbox")#2
        xmin=SubElement(bndbox,"xmin")#3
        ymin=SubElement(bndbox,"ymin")#3
        xmax=SubElement(bndbox,"xmax")#3
        ymax=SubElement(bndbox,"ymax")#3
        xmin.text=str(i['xmin'])
        ymin.text=str(i['ymin'])
        xmax.text=str(i['xmax'])
        ymax.text=str(i['ymax'])
    xml=tostring(root,pretty_print=True)
    file=open('D:/py-faster-rcnn/data/VOCdevkit2007/VOC2007/Annotations/'+iter+'.xml','w+')#这里的路径也改为你自己的路径
    file.write(xml)


 

3.制作训练、测试、验证集

这个网上可以参考的资料比较多,我直接copy一个小咸鱼的用matlab的代码

我建议train和trainval的部分占得比例可以更大一点

[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;">%%    
  2. %该代码根据已生成的xml,制作VOC2007数据集中的trainval.txt;train.txt;test.txt和val.txt    
  3. %trainval占总数据集的50%,test占总数据集的50%;train占trainval的50%,val占trainval的50%;    
  4. %上面所占百分比可根据自己的数据集修改,如果数据集比较少,test和val可少一些    
  5. %%    
  6. %注意修改下面四个值    
  7. xmlfilepath='E:\Annotations';    
  8. txtsavepath='E:\ImageSets\Main\';    
  9. trainval_percent=0.5;%trainval占整个数据集的百分比,剩下部分就是test所占百分比    
  10. train_percent=0.5;%train占trainval的百分比,剩下部分就是val所占百分比    
  11.     
  12.     
  13. %%    
  14. xmlfile=dir(xmlfilepath);    
  15. numOfxml=length(xmlfile)-2;%减去.和..  总的数据集大小    
  16.     
  17.     
  18. trainval=sort(randperm(numOfxml,floor(numOfxml*trainval_percent)));    
  19. test=sort(setdiff(1:numOfxml,trainval));    
  20.     
  21.     
  22. trainvalsize=length(trainval);%trainval的大小    
  23. train=sort(trainval(randperm(trainvalsize,floor(trainvalsize*train_percent))));    
  24. val=sort(setdiff(trainval,train));    
  25.     
  26.     
  27. ftrainval=fopen([txtsavepath 'trainval.txt'],'w');    
  28. ftest=fopen([txtsavepath 'test.txt'],'w');    
  29. ftrain=fopen([txtsavepath 'train.txt'],'w');    
  30. fval=fopen([txtsavepath 'val.txt'],'w');    
  31.     
  32.     
  33. for i=1:numOfxml    
  34.     if ismember(i,trainval)    
  35.         fprintf(ftrainval,'%s\n',xmlfile(i+2).name(1:end-4));    
  36.         if ismember(i,train)    
  37.             fprintf(ftrain,'%s\n',xmlfile(i+2).name(1:end-4));    
  38.         else    
  39.             fprintf(fval,'%s\n',xmlfile(i+2).name(1:end-4));    
  40.         end    
  41.     else    
  42.         fprintf(ftest,'%s\n',xmlfile(i+2).name(1:end-4));    
  43.     end    
  44. end    
  45. fclose(ftrainval);    
  46. fclose(ftrain);    
  47. fclose(fval);    
  48. fclose(ftest);</span>  
%%  
%该代码根据已生成的xml,制作VOC2007数据集中的trainval.txt;train.txt;test.txt和val.txt  
%trainval占总数据集的50%,test占总数据集的50%;train占trainval的50%,val占trainval的50%;  
%上面所占百分比可根据自己的数据集修改,如果数据集比较少,test和val可少一些  
%%  
%注意修改下面四个值  
xmlfilepath='E:\Annotations';  
txtsavepath='E:\ImageSets\Main\';  
trainval_percent=0.5;%trainval占整个数据集的百分比,剩下部分就是test所占百分比  
train_percent=0.5;%train占trainval的百分比,剩下部分就是val所占百分比  
  
  
%%  
xmlfile=dir(xmlfilepath);  
numOfxml=length(xmlfile)-2;%减去.和..  总的数据集大小  
  
  
trainval=sort(randperm(numOfxml,floor(numOfxml*trainval_percent)));  
test=sort(setdiff(1:numOfxml,trainval));  
  
  
trainvalsize=length(trainval);%trainval的大小  
train=sort(trainval(randperm(trainvalsize,floor(trainvalsize*train_percent))));  
val=sort(setdiff(trainval,train));  
  
  
ftrainval=fopen([txtsavepath 'trainval.txt'],'w');  
ftest=fopen([txtsavepath 'test.txt'],'w');  
ftrain=fopen([txtsavepath 'train.txt'],'w');  
fval=fopen([txtsavepath 'val.txt'],'w');  
  
  
for i=1:numOfxml  
    if ismember(i,trainval)  
        fprintf(ftrainval,'%s\n',xmlfile(i+2).name(1:end-4));  
        if ismember(i,train)  
            fprintf(ftrain,'%s\n',xmlfile(i+2).name(1:end-4));  
        else  
            fprintf(fval,'%s\n',xmlfile(i+2).name(1:end-4));  
        end  
    else  
        fprintf(ftest,'%s\n',xmlfile(i+2).name(1:end-4));  
    end  
end  
fclose(ftrainval);  
fclose(ftrain);  
fclose(fval);  
fclose(ftest);


 

4.文件保存路径

jpg,txt,xml分别保存到data\VOCdevkit2007\VOC2007\下的JPEGImages、ImageSets\Main、Annotations文件夹

二、根据自己的数据集修改文件

1.模型配置文件

我用end2end的方式训练,这里我用vgg_cnn_m_1024为例说明。所以我们先打开models\pascal_voc\VGG_CNN_M_1024\faster_rcnn_end2end\train.prototxt,有4处需要修改

 

[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;">layer {  
  2.   name: 'input-data'  
  3.   type: 'Python'  
  4.   top: 'data'  
  5.   top: 'im_info'  
  6.   top: 'gt_boxes'  
  7.   python_param {  
  8.     module: 'roi_data_layer.layer'  
  9.     layer: 'RoIDataLayer'  
  10.     param_str: "'num_classes': 3" #这里改为你训练类别数+1  
  11.   }  
  12. }</span>  
layer {
  name: 'input-data'
  type: 'Python'
  top: 'data'
  top: 'im_info'
  top: 'gt_boxes'
  python_param {
    module: 'roi_data_layer.layer'
    layer: 'RoIDataLayer'
    param_str: "'num_classes': 3" #这里改为你训练类别数+1
  }
}

[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;">layer {  
  2.   name: 'roi-data'  
  3.   type: 'Python'  
  4.   bottom: 'rpn_rois'  
  5.   bottom: 'gt_boxes'  
  6.   top: 'rois'  
  7.   top: 'labels'  
  8.   top: 'bbox_targets'  
  9.   top: 'bbox_inside_weights'  
  10.   top: 'bbox_outside_weights'  
  11.   python_param {  
  12.     module: 'rpn.proposal_target_layer'  
  13.     layer: 'ProposalTargetLayer'  
  14.     param_str: "'num_classes': 3" #这里改为你训练类别数+1  
  15.   }  
  16. }</span>  
layer {
  name: 'roi-data'
  type: 'Python'
  bottom: 'rpn_rois'
  bottom: 'gt_boxes'
  top: 'rois'
  top: 'labels'
  top: 'bbox_targets'
  top: 'bbox_inside_weights'
  top: 'bbox_outside_weights'
  python_param {
    module: 'rpn.proposal_target_layer'
    layer: 'ProposalTargetLayer'
    param_str: "'num_classes': 3" #这里改为你训练类别数+1
  }
}
[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;">layer {  
  2.   name: "cls_score"  
  3.   type: "InnerProduct"  
  4.   bottom: "fc7"  
  5.   top: "cls_score"  
  6.   param {  
  7.     lr_mult: 1  
  8.   }  
  9.   param {  
  10.     lr_mult: 2  
  11.   }  
  12.   inner_product_param {  
  13.     num_output: 3  #这里改为你训练类别数+1  
  14.     weight_filler {  
  15.       type: "gaussian"  
  16.       std: 0.01  
  17.     }  
  18.     bias_filler {  
  19.       type: "constant"  
  20.       value: 0  
  21.     }  
  22.   }  
  23. }  
  24. layer {  
  25.   name: "bbox_pred"  
  26.   type: "InnerProduct"  
  27.   bottom: "fc7"  
  28.   top: "bbox_pred"  
  29.   param {  
  30.     lr_mult: 1  
  31.   }  
  32.   param {  
  33.     lr_mult: 2  
  34.   }  
  35.   inner_product_param {  
  36.     num_output: 12  #这里改为你的(类别数+1)*4  
  37.     weight_filler {  
  38.       type: "gaussian"  
  39.       std: 0.001  
  40.     }  
  41.     bias_filler {  
  42.       type: "constant"  
  43.       value: 0  
  44.     }  
  45.   }  
  46. }</span>  
layer {
  name: "cls_score"
  type: "InnerProduct"
  bottom: "fc7"
  top: "cls_score"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 3  #这里改为你训练类别数+1
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "bbox_pred"
  type: "InnerProduct"
  bottom: "fc7"
  top: "bbox_pred"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 12  #这里改为你的(类别数+1)*4
    weight_filler {
      type: "gaussian"
      std: 0.001
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
然后我们修改models\pascal_voc\VGG_CNN_M_1024\faster_rcnn_end2end\test.prototxt。

[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;">layer {  
  2.   name: "relu7"  
  3.   type: "ReLU"  
  4.   bottom: "fc7"  
  5.   top: "fc7"  
  6. }  
  7. layer {  
  8.   name: "cls_score"  
  9.   type: "InnerProduct"  
  10.   bottom: "fc7"  
  11.   top: "cls_score"  
  12.   param {  
  13.     lr_mult: 1  
  14.     decay_mult: 1  
  15.   }  
  16.   param {  
  17.     lr_mult: 2  
  18.     decay_mult: 0  
  19.   }  
  20.   inner_product_param {  
  21.     num_output: 3 </span><span style="font-size:14px;"> #这里改为你训练类别数+1</span><span style="font-size:14px;">  
  22. </span><span style="font-size:14px;"></span>  
layer {
  name: "relu7"
  type: "ReLU"
  bottom: "fc7"
  top: "fc7"
}
layer {
  name: "cls_score"
  type: "InnerProduct"
  bottom: "fc7"
  top: "cls_score"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 3  #这里改为你训练类别数+1
[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;">    weight_filler {  
  2.       type: "gaussian"  
  3.       std: 0.01  
  4.     }  
  5.     bias_filler {  
  6.       type: "constant"  
  7.       value: 0  
  8.     }  
  9.   }  
  10. }  
  11. layer {  
  12.   name: "bbox_pred"  
  13.   type: "InnerProduct"  
  14.   bottom: "fc7"  
  15.   top: "bbox_pred"  
  16.   param {  
  17.     lr_mult: 1  
  18.     decay_mult: 1  
  19.   }  
  20.   param {  
  21.     lr_mult: 2  
  22.     decay_mult: 0  
  23.   }  
  24.   inner_product_param {  
  25.     num_output: 12 </span><span style="font-size:14px;"> #这里改为你的(类别数+1)*4</span><span style="font-size:14px;">  
  26. </span>  
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "bbox_pred"
  type: "InnerProduct"
  bottom: "fc7"
  top: "bbox_pred"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  inner_product_param {
    num_output: 12  #这里改为你的(类别数+1)*4
[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;">    weight_filler {  
  2.       type: "gaussian"  
  3.       std: 0.001  
  4.     }  
  5.     bias_filler {  
  6.       type: "constant"  
  7.       value: 0  
  8.     }  
  9.   }  
  10. }</span>  
    weight_filler {
      type: "gaussian"
      std: 0.001
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
 
  

另外在 solver里可以调训练的学习率等参数,在这篇文章里不做说明

==================以下修改lib中的文件==================

2.修改imdb.py

 

[python] view plain copy
 
 
print?
  1. <span style="font-size:14px;">    def append_flipped_images(self):    
  2.         num_images = self.num_images    
  3.         widths = [PIL.Image.open(self.image_path_at(i)).size[0]    
  4.                   for i in xrange(num_images)]    
  5.         for i in xrange(num_images):    
  6.             boxes = self.roidb[i]['boxes'].copy()    
  7.             oldx1 = boxes[:, 0].copy()    
  8.             oldx2 = boxes[:, 2].copy()    
  9.             boxes[:, 0] = widths[i] - oldx2 - 1    
  10.             boxes[:, 2] = widths[i] - oldx1 - 1  
  11.             for b in range(len(boxes)):  
  12.                 if boxes[b][2]< boxes[b][0]:  
  13.                     boxes[b][0] = 0           
  14.             assert (boxes[:, 2] >= boxes[:, 0]).all()    
  15.             entry = {'boxes' : boxes,    
  16.                      'gt_overlaps' : self.roidb[i]['gt_overlaps'],    
  17.                      'gt_classes' : self.roidb[i]['gt_classes'],    
  18.                      'flipped' : True}    
  19.             self.roidb.append(entry)    
  20.         self._image_index = self._image_index * 2 </span>  
    def append_flipped_images(self):  
        num_images = self.num_images  
        widths = [PIL.Image.open(self.image_path_at(i)).size[0]  
                  for i in xrange(num_images)]  
        for i in xrange(num_images):  
            boxes = self.roidb[i]['boxes'].copy()  
            oldx1 = boxes[:, 0].copy()  
            oldx2 = boxes[:, 2].copy()  
            boxes[:, 0] = widths[i] - oldx2 - 1  
            boxes[:, 2] = widths[i] - oldx1 - 1
            for b in range(len(boxes)):
                if boxes[b][2]< boxes[b][0]:
                    boxes[b][0] = 0			
            assert (boxes[:, 2] >= boxes[:, 0]).all()  
            entry = {'boxes' : boxes,  
                     'gt_overlaps' : self.roidb[i]['gt_overlaps'],  
                     'gt_classes' : self.roidb[i]['gt_classes'],  
                     'flipped' : True}  
            self.roidb.append(entry)  
        self._image_index = self._image_index * 2 
找到这个函数,并修改为如上

 

3、修改rpn层的5个文件

在如下目录下,将文件中param_str_全部改为param_str

 

4、修改config.py

将训练和测试的proposals改为gt

 

[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;"># Train using these proposals  
  2. __C.TRAIN.PROPOSAL_METHOD = 'gt'  
  3. # Test using these proposals  
  4. __C.TEST.PROPOSAL_METHOD = 'gt</span>  
# Train using these proposals
__C.TRAIN.PROPOSAL_METHOD = 'gt'
# Test using these proposals
__C.TEST.PROPOSAL_METHOD = 'gt

 

5、修改pascal_voc.py

因为我们使用VOC来训练,所以这个是我们主要修改的训练的文件。

 

[plain] view plain copy
 
 
print?
  1. <span style="font-size:14px;"> def __init__(self, image_set, year, devkit_path=None):  
  2.         imdb.__init__(self, 'voc_' + year + '_' + image_set)  
  3.         self._year = year  
  4.         self._image_set = image_set  
  5.         self._devkit_path = self._get_default_path() if devkit_path is None \  
  6.                             else devkit_path  
  7.         self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)  
  8.         self._classes = ('__background__', # always index 0  
  9.                             'cn-character','seal')  
  10.         self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))  
  11.         self._image_ext = '.jpg'  
  12.         self._image_index = self._load_image_set_index()  
  13.         # Default to roidb handler  
  14.         self._roidb_handler = self.selective_search_roidb  
  15.         self._salt = str(uuid.uuid4())  
  16.         self._comp_id = 'comp4'</span>  
 def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__', # always index 0
                            'cn-character','seal')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

在self.classes这里,'__background__'使我们的背景类,不要动他。下面的改为你自己标签的内容。

 

修改以下2段内容。否则你的test部分一定会出问题。

 

[python] view plain copy
 
 
print?
  1. def _get_voc_results_file_template(self):  
  2.        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt  
  3.        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'  
  4.        path = os.path.join(  
  5.            self._devkit_path,  
  6.            'VOC' + self._year,  
  7.         ImageSets,  
  8.            'Main',  
  9.            '{}' + '_test.txt')  
  10.        return path  
 def _get_voc_results_file_template(self):
        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        path = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
			ImageSets,
            'Main',
            '{}' + '_test.txt')
        return path
[python] view plain copy
 
 
print?
  1. def _write_voc_results_file(self, all_boxes):  
  2.        for cls_ind, cls in enumerate(self.classes):  
  3.            if cls == '__background__':  
  4.                continue  
  5.            print 'Writing {} VOC results file'.format(cls)  
  6.            filename = self._get_voc_results_file_template().format(cls)  
  7.            with open(filename, 'w+') as f:  
  8.                for im_ind, index in enumerate(self.image_index):  
  9.                    dets = all_boxes[cls_ind][im_ind]  
  10.                    if dets == []:  
  11.                        continue  
  12.                    # the VOCdevkit expects 1-based indices  
  13.                    for k in xrange(dets.shape[0]):  
  14.                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.  
  15.                                format(index, dets[k, -1],  
  16.                                       dets[k, 0] + 1, dets[k, 1] + 1,  
  17.                                       dets[k, 2] + 1, dets[k, 3] + 1))  
 def _write_voc_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print 'Writing {} VOC results file'.format(cls)
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, 'w+') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    if dets == []:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in xrange(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

 

三、end2end训练

1、删除缓存文件

每次训练前将data\cache 和 data\VOCdevkit2007\annotations_cache中的文件删除。

2、开始训练

在py-faster-rcnn的根目录下打开git bash输入

[plain] view plain copy
 
 
print?
  1. <span style="font-size:18px;">./experiments/scripts/faster_rcnn_end2end.sh 0 VGG_CNN_M_1024 pascal_voc</span>  
./experiments/scripts/faster_rcnn_end2end.sh 0 VGG_CNN_M_1024 pascal_voc

当然你可以去experiments\scripts\faster_rcnn_end2end.sh中调自己的训练的一些参数,也可以中VGG16、ZF模型去训练。我这里就用默认给的参数说明。

 

出现了这种东西的话,那就是训练成功了。用vgg1024的话还是很快的,还是要看你的配置,我用1080ti的话也就85min左右。我就没有让他训练结束了。

四、测试

 

训练完成之后,将output中的最终模型拷贝到data/faster_rcnn_models,修改tools下的demo.py,我是使用VGG_CNN_M_1024这个中型网络,不是默认的ZF,所以要改的地方挺多
1. 修改class

1
2
3
4
5
6
7
8
9
10
11
12
CLASSES = ('__background__',
'Blouse', 'Sweatpants', 'Cardigan', 'Button-Down',
'Cutoffs', 'Chinos', 'Top', 'Anorak', 'Kimono',
'Tank', 'Robe', 'Parka', 'Jodhpurs',
'Halter', 'Shorts', 'Caftan','Turtleneck',
'Leggings', 'Joggers', 'Hoodie', 'Culottes',
'Sweater', 'Flannel', 'Jeggings', 'Blazer',
'Onesie', 'Coat', 'Henley', 'Jacket',
'Trunks', 'Gauchos', 'Sweatshorts', 'Romper',
'Jersey', 'Bomber', 'Sarong', 'Dress','Jeans',
'Tee', 'Coverup', 'Capris', 'Kaftan','Peacoat',
'Poncho', 'Skirt', 'Jumpsuit')

 

2. 增加你自己训练的模型

1
2
3
4
5
NETS = {'vgg16': ('VGG16',
'VGG16_faster_rcnn_final.caffemodel'),
'zf': ('ZF',
'ZF_faster_rcnn_final.caffemodel'),
'myvgg1024':('VGG_CNN_M_1024','vgg_cnn_m_1024_faster_rcnn_iter_70000.caffemodel')}

 

3. 修改prototxt,如果你用的是ZF,就不用改了

1
2
prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
'faster_rcnn_end2end', 'test.prototxt')

 

 

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    args = parse_args()

    prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
                            'faster_rcnn_end2end', 'test.prototxt')
    caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
                              NETS[args.demo_net][1])

    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/script/'
                       'fetch_faster_rcnn_models.sh?').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
        cfg.GPU_ID = args.gpu_id
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    print '\n\nLoaded network {:s}'.format(caffemodel)

    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in xrange(2):
        _, _= im_detect(net, im)

    im_names = ['f1.jpg','f8.jpg','f7.jpg','f6.jpg','f5.jpg','f4.jpg','f3.jpg','f2.jpg',]
    for im_name in im_names:
        print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
        print 'Demo for data/demo/{}'.format(im_name)
        demo(net, im_name)

    plt.show()

 

 

在这个部分,将你要测试的图片写在im_names里,并把图片放在data\demo这个文件夹下。

4. 开始检测
执行 ./tools/demo.py –net myvgg1024
假如不想那么麻烦输入参数,可以在demo的parse_args()里修改默认参数
parser.add_argument(‘–net’, dest=’demo_net’, help=’Network to use [myvgg1024]’,
choices=NETS.keys(), default=’myvgg1024’)
这样只需要输入 ./tools/demo.py 就可以了

 



 

转载于:https://www.cnblogs.com/bile/p/9110954.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值