# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen and Zheqi He
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from model.config import cfg
import roi_data_layer.roidb as rdl_roidb
from roi_data_layer.layer import RoIDataLayer
from utils.timer import Timer
try:
import cPickle as pickle
except ImportError:
import pickle
import numpy as np
import os
import sys
import glob
import time
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
# Solver的封装类,包含和训练有关的属性和方法
class SolverWrapper(object):
"""
A wrapper class for the training process
"""
def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
self.net = network # network类的实例
self.imdb = imdb # imdb类的实例
self.roidb = roidb # roidb字典
self.valroidb = valroidb # 验证roidb字典
self.output_dir = output_dir # 模型保存路径
self.tbdir = tbdir # tensorboard保存路径
# Simply put '_val' at the end to save the summaries from the validation set
self.tbvaldir = tbdir + '_val' # 验证过程的tensorboard保存路径
if not os.path.exists(self.tbvaldir):
os.makedirs(self.tbvaldir)
self.pretrained_model = pretrained_model # 预训练权重的路径
# 保存快照,包括模型权重的ckpt文件和训练参数的pkl文件
def snapshot(self, sess, iter):
    """Save a training snapshot: model weights plus data-layer/RNG state.

    Two files are written into ``self.output_dir`` (created if missing):

    * ``<cfg.TRAIN.SNAPSHOT_PREFIX>_iter_<iter>.ckpt`` via ``self.saver``;
    * ``<cfg.TRAIN.SNAPSHOT_PREFIX>_iter_<iter>.pkl`` containing, as six
      consecutive pickles in this exact order: the numpy RNG state, the
      training data layer's ``_cur`` and ``_perm``, the validation data
      layer's ``_cur`` and ``_perm``, and ``iter``.

    Args:
        sess: TensorFlow session passed to ``self.saver.save``.
        iter: current iteration number (formatted with ``{:d}``) that is
            embedded in both file names and stored in the pickle file.

    Returns:
        Tuple ``(filename, nfilename)``: the checkpoint path and the pickle
        path.
    """
    if not os.path.exists(self.output_dir):
        os.makedirs(self.output_dir)

    # Shared stem for both files, e.g. 'shufflenetv2_faster_rcnn_iter_10000'
    # (the prefix comes from the yml configuration).
    stem = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)
    filename = os.path.join(self.output_dir, stem + '.ckpt')
    nfilename = os.path.join(self.output_dir, stem + '.pkl')

    # Store the model snapshot (weights).
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store the meta information needed to resume training: the full
    # numpy RNG state (not merely a seed), the current position in and the
    # shuffled index order of the training and validation databases, and
    # the iteration number.
    # NOTE(review): the restore path presumably unpickles these in the same
    # order -- keep the two in sync when changing this tuple.
    meta = (
        np.random.get_state(),
        self.data_layer._cur,
        self.data_layer._perm,
        self.data_layer_val._cur,
        self.data_layer_val._perm,
        iter,
    )
    with open(nfilename, 'wb') as fid:
        for item in meta:
            pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename
# 载入快照
def from_snapshot(self, sess, sfile, nfile):
print('Restoring model snapshots from {:s}'.format(sfile))
self.saver.restore(sess, sfile)
print('Restored.')
# Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
# tried my best to find the random states so that it can be recovered exactly
# However the Tensorflow state
Faster R-CNN 代码解读之 train_val.py
最新推荐文章于 2021-09-27 21:21:59 发布
本文详细解读了Faster R-CNN在训练过程中的核心代码train_val.py,涵盖了目标检测模型的训练流程,包括数据预处理、网络架构、损失计算和反向传播等关键步骤。
摘要由CSDN通过智能技术生成