# Personal code-reading notes (个人代码阅读笔记).
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen and Zheqi He
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from model.config import cfg
import roi_data_layer.roidb as rdl_roidb
from roi_data_layer.layer import RoIDataLayer
from utils.timer import Timer
try:
import cPickle as pickle
except ImportError:
import pickle
import numpy as np
import os
import sys
import glob
import time
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
'''
Second-level entry point for training; the first level is train_faster_rcnn.sh.
(这是训练的第二层入口,第一层是train_faster_rcnn.sh)
'''
class SolverWrapper(object):
"""
A wrapper class for the training process
训练过程的封装
"""
#实例初始化
def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    """Wrap the objects the training process needs.

    Args:
        sess: TensorFlow session. Not used by the constructor; kept so the
            signature stays unchanged for callers.
        network: the network to train; stored as ``self.net``.
        imdb: image database.
        roidb: RoI database used for training.
        valroidb: RoI database used for validation.
        output_dir: directory that ``snapshot`` writes checkpoints into
            (it is created there, not here).
        tbdir: summary directory for the training run.
        pretrained_model: optional pretrained model to start from
            (``None`` by default); stored as-is.

    Side effects:
        Creates the directory ``tbdir + '_val'`` if it does not exist.

    Raises:
        OSError: if ``tbdir + '_val'`` cannot be created, or exists but is
            not a directory.
    """
    self.net = network
    self.imdb = imdb            # image database
    self.roidb = roidb          # RoI database (training)
    self.valroidb = valroidb    # RoI database (validation)
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Validation summaries are kept apart from the training ones by simply
    # appending '_val' to the training summary directory name.
    self.tbvaldir = tbdir + '_val'
    # Create the directory without a separate exists() check, so a concurrent
    # creator cannot make makedirs fail; a non-directory at that path still
    # raises. (``exist_ok`` is avoided to stay compatible with Python 2.)
    try:
        os.makedirs(self.tbvaldir)
    except OSError:
        if not os.path.isdir(self.tbvaldir):
            raise
    self.pretrained_model = pretrained_model
def snapshot(self, sess, iter):
    """Save a training checkpoint plus the state needed to resume training.

    Two files are written into ``self.output_dir`` (created if missing):

    * ``<SNAPSHOT_PREFIX>_iter_<iter>.ckpt`` -- written by
      ``self.saver.save(sess, ...)`` (presumably a ``tf.train.Saver``);
      this holds the model weights.
    * ``<SNAPSHOT_PREFIX>_iter_<iter>.pkl`` -- a stream of separate pickles,
      in this order: the numpy global RNG state, the training data layer's
      ``_cur`` and ``_perm``, the validation data layer's ``_cur`` and
      ``_perm``, and ``iter`` itself.

    Args:
        sess: TensorFlow session passed to ``self.saver.save``.
        iter: current iteration number; must be formattable with ``{:d}``.
            (Shadows the ``iter`` builtin; the name is kept for callers that
            pass it by keyword.)

    Returns:
        tuple: ``(filename, nfilename)``, the paths of the .ckpt and .pkl files.
    """
    # Create the output directory; tolerate it already existing (including
    # another process creating it concurrently) but re-raise anything else.
    # ``exist_ok`` is avoided to stay compatible with Python 2.
    try:
        os.makedirs(self.output_dir)
    except OSError:
        if not os.path.isdir(self.output_dir):
            raise

    # Shared base name, e.g. 'res101_faster_rcnn_iter_1000'.
    base = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)

    # Model weights (the main checkpoint).
    filename = os.path.join(self.output_dir, base + '.ckpt')
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Meta information for resuming: numpy RNG state and the current
    # position / shuffled order of both data layers.
    nfilename = os.path.join(self.output_dir, base + '.pkl')
    state = (
        np.random.get_state(),      # numpy global RNG state
        self.data_layer._cur,       # current position in the training database
        self.data_layer._perm,      # current shuffled training indexes
        self.data_layer_val._cur,   # current position in the validation database
        self.data_layer_val._perm,  # current shuffled validation indexes
        iter,
    )
    with open(nfilename, 'wb') as fid:
        # Dump each item as its own pickle (not one tuple) so the on-disk
        # format matches earlier snapshots; the restore code presumably reads
        # them back with one pickle.load() per item.
        for item in state:
            pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename
def from_snapshot(self, sess, sfile, nfile):#载入快照
print('Restoring mode