【学习笔记】Tensorflow-ENet代码学习(一)

针对Tensorflow版ENet,记录一下自己对代码的理解。

*非代码解读(因为水平不足),仅作为自己理解的备忘

**理解有误的地方,希望可以得到大牛的指点


一、文件夹内容(结构)

                                                                                  (图片截取自作者kwotsin的ENet主页)

所用ENet的文件如上图所示,由上向下分别是:

  • checkpoint文件夹:作者训练camvid数据集后保存的ckpt
  • dataset:作者使用的camvid数据集
  • visualizations:被预测的原始动图(camvid)及预测后生成的语义分割动图,gif格式
  • enet.py定义了ENet的网络结构
  • get_class_weights.py计算类别权重脚本
  • predict_segmentation.py单独预测图片用脚本
  • preprocessing.py图片预处理脚本
  • test.sh将测试时所用的参数写好,方便测试的小脚本
  • test_enet.py测试数据集中测试集(test文件夹)的脚本
  • train.sh将训练时所用的参数写好,方便训练的小脚本
  • train_enet.py训练ENet的脚本

二、train_enet.py及相关脚本的理解

首先,对最主要的训练脚本进行研读。

1——11行,将所需的库及函数导入。

import tensorflow as tf
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging
from enet import ENet, ENet_arg_scope
from preprocessing import preprocess
from get_class_weights import ENet_weighing, median_frequency_balancing
import os
import time
import numpy as np
import matplotlib.pyplot as plt
slim = tf.contrib.slim

然后,使用tf.app.flags定义常用参数,这些定义好的参数可以在使用脚本时使用'脚本.py  --参数名 参数值'的方式进行修改。而使用flags时,flags的定义方式为:flags.数据类型(变量值名称,变量值,变量描述)

flags = tf.app.flags

#Directory arguments
# 数据集所在地址
flags.DEFINE_string('dataset_dir', './dataset', 'The dataset directory to find the train, validation and test images.')
# 保存训练数据地址
flags.DEFINE_string('logdir', './log/original', 'The log directory to save your checkpoint and event files.')
# 训练结束后是否保存一个eval_batch的图片
flags.DEFINE_boolean('save_images', True, 'Whether or not to save your images.')
# 训练过程是否将训练集及验证集合并进行训练
flags.DEFINE_boolean('combine_dataset', False, 'If True, combines the validation with the train dataset.')

#Training arguments
# 类别数
flags.DEFINE_integer('num_classes', 12, 'The number of classes to predict.')
# 训练的batch_size
flags.DEFINE_integer('batch_size', 10, 'The batch_size for training.')
# 每次验证的eval_batch_size
flags.DEFINE_integer('eval_batch_size', 25, 'The batch size used for validation.')
# 训练用图片的高
flags.DEFINE_integer('image_height', 360, "The input height of the images.")
# 训练用图片的宽
flags.DEFINE_integer('image_width', 480, "The input width of the images.")
# 训练的总轮数
flags.DEFINE_integer('num_epochs', 300, "The number of epochs to train your model.")
# 多少轮后学习率开始衰减
flags.DEFINE_integer('num_epochs_before_decay', 100, 'The number of epochs before decaying your learning rate.')
# 权重衰减
flags.DEFINE_float('weight_decay', 2e-4, "The weight decay for ENet convolution layers.")
# 学习率衰减
flags.DEFINE_float('learning_rate_decay_factor', 1e-1, 'The learning rate decay factor.')
# 初始化学习率
flags.DEFINE_float('initial_learning_rate', 5e-4, 'The initial learning rate for your training.')
# 所用的类别权重计算方式(MFB及ENet两种)
flags.DEFINE_string('weighting', "MFB", 'Choice of Median Frequency Balancing or the custom ENet class weights.')

#Architectural changes
# ENet中 初始化模块的个数(ENet论文中有描述)
flags.DEFINE_integer('num_initial_blocks', 1, 'The number of initial
  • 3
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 24
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值