针对Tensorflow版ENet,记录一下自己对代码的理解。
*非代码解读(因为水平不足),仅作为自己理解的备忘
**理解有误的地方,希望可以得到大牛的指点
一、文件夹内容(结构)
(图片截取自作者kwotsin的ENet主页)
所用ENet的文件如上图所示,由上向下分别是:
- checkpoint文件夹:作者训练camvid数据集后保存的ckpt
- dataset:作者使用的camvid数据集
- visualizations:被预测的原始动图(camvid)及预测后生成的语义分割动图,gif格式
- enet.py定义了ENet的网络结构
- get_class_weights.py计算类别权重脚本
- predict_segmentation.py单独预测图片用脚本
- preprocessing.py图片预处理脚本
- test.sh将测试时所用的参数写好,方便测试的小脚本
- test_enet.py测试数据集中测试集(test文件夹)的脚本
- train.sh将训练时所用的参数写好,方便训练的小脚本
- train_enet.py训练ENet的脚本
二、train_enet.py及相关脚本的理解
首先,对最主要的训练脚本进行研读。
1——11行,将所需的库及函数导入。
import tensorflow as tf
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging
from enet import ENet, ENet_arg_scope
from preprocessing import preprocess
from get_class_weights import ENet_weighing, median_frequency_balancing
import os
import time
import numpy as np
import matplotlib.pyplot as plt
slim = tf.contrib.slim
然后,使用tf.app.flags定义常用参数,这些定义好的参数可以在使用脚本时使用'脚本.py --参数名 参数值'的方式进行修改。而使用flags时,flags的定义方式为:flags.数据类型(变量值名称,变量值,变量描述)。
flags = tf.app.flags
#Directory arguments
# 数据集所在地址
flags.DEFINE_string('dataset_dir', './dataset', 'The dataset directory to find the train, validation and test images.')
# 保存训练数据地址
flags.DEFINE_string('logdir', './log/original', 'The log directory to save your checkpoint and event files.')
# 训练结束后是否保存一个eval_batch的图片
flags.DEFINE_boolean('save_images', True, 'Whether or not to save your images.')
# 训练过程是否将训练集及验证集合并进行训练
flags.DEFINE_boolean('combine_dataset', False, 'If True, combines the validation with the train dataset.')
#Training arguments
# 类别数
flags.DEFINE_integer('num_classes', 12, 'The number of classes to predict.')
# 训练的batch_size
flags.DEFINE_integer('batch_size', 10, 'The batch_size for training.')
# 每次验证的eval_batch_size
flags.DEFINE_integer('eval_batch_size', 25, 'The batch size used for validation.')
# 训练用图片的高
flags.DEFINE_integer('image_height', 360, "The input height of the images.")
# 训练用图片的宽
flags.DEFINE_integer('image_width', 480, "The input width of the images.")
# 训练的总轮数
flags.DEFINE_integer('num_epochs', 300, "The number of epochs to train your model.")
# 多少轮后学习率开始衰减
flags.DEFINE_integer('num_epochs_before_decay', 100, 'The number of epochs before decaying your learning rate.')
# 权重衰减
flags.DEFINE_float('weight_decay', 2e-4, "The weight decay for ENet convolution layers.")
# 学习率衰减
flags.DEFINE_float('learning_rate_decay_factor', 1e-1, 'The learning rate decay factor.')
# 初始化学习率
flags.DEFINE_float('initial_learning_rate', 5e-4, 'The initial learning rate for your training.')
# 所用的类别权重计算方式(MFB及ENet两种)
flags.DEFINE_string('weighting', "MFB", 'Choice of Median Frequency Balancing or the custom ENet class weights.')
#Architectural changes
# ENet中 初始化模块的个数(ENet论文中有描述)
flags.DEFINE_integer('num_initial_blocks', 1, 'The number of initial blocks to use in ENet.')
# ENet中 第二部分的个数(ENet论文中有描述)
flags.DEFINE_integer('stage_two_repeat', 2,