TF版FasterRCNN:train_val.py代码解读笔记

本文详细解读了使用TensorFlow实现的Faster R-CNN模型训练与验证过程的代码,涵盖了关键步骤和核心算法的理解。
摘要由CSDN通过智能技术生成

个人代码阅读笔记。

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen and Zheqi He
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from model.config import cfg
import roi_data_layer.roidb as rdl_roidb
from roi_data_layer.layer import RoIDataLayer
from utils.timer import Timer
try:
  import cPickle as pickle
except ImportError:
  import pickle
import numpy as np
import os
import sys
import glob
import time

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
'''
这是训练的第二层入口,第一层是train_faster_rcnn.sh
'''
class SolverWrapper(object):
  """
    A wrapper class for the training process
	训练过程的封装
  """
  #实例初始化
  def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.imdb = imdb#image database
    self.roidb = roidb#roi database
    self.valroidb = valroidb#valid roi database
    self.output_dir = output_dir#输出路径
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'#?
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model#pre训练模型

  def snapshot(self, sess, iter):#保存参数snapshot,多久保存一次tf的会话
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
	#TRAIN.SNAPSHOT_PREFIX为名字前缀‘res101_faster_rcnn'’
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
	#可以看出来ckpt才是主要的权重,pkl是保存了一些原始信息和随机状态
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
	#当前的numpy.random状态(随机种子什么的
    st0 = np.random.get_state()
    # current position in the database
	#当前载入数据库
    cur = self.data_layer._cur
    # current shuffled indexes of the database
	#当前打乱的数据索引
    perm = self.data_layer._perm
    # current position in the validation database
	#当前数据在validation database中的位置
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
	#当前打乱的验证数据集索引
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:#打开文件
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sess, sfile, nfile):#载入快照
    print('Restoring mode
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值