By now you have read through the data-input code and know how to modify it.
Next, we need to read the model code.
Contents
train.py contains the following functions:
_build_deeplab
_tower_loss
_average_gradients
_log_summaries
_train_deeplab_model
main
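_average_gradients comes into play when training is split across several clones (towers): each clone computes its own gradients, and these are averaged per variable before the update is applied. The sketch below illustrates only that idea and is not the code in train.py. It assumes every tower returns its (gradient, variable) pairs in the same variable order, and it uses plain floats and variable-name strings in place of TensorFlow tensors and variables. The function name average_gradients is hypothetical.

```python
# Hypothetical, framework-free sketch of the idea behind train.py's
# _average_gradients. Assumption: each tower yields (grad, var_name) pairs
# in the same variable order; floats stand in for gradient tensors.
from typing import List, Tuple

GradVar = Tuple[float, str]


def average_gradients(tower_grads: List[List[GradVar]]) -> List[GradVar]:
    """Average gradients per variable across all towers."""
    averaged = []
    # zip(*tower_grads) groups the i-th (grad, var) pair from every tower,
    # i.e. all gradients that belong to the same variable.
    for grad_and_vars in zip(*tower_grads):
        grads = [g for g, _ in grad_and_vars]
        var_name = grad_and_vars[0][1]  # every tower shares this variable
        averaged.append((sum(grads) / len(grads), var_name))
    return averaged


if __name__ == "__main__":
    tower_0 = [(1.0, "conv1/weights"), (4.0, "logits/biases")]
    tower_1 = [(3.0, "conv1/weights"), (2.0, "logits/biases")]
    print(average_gradients([tower_0, tower_1]))
```

Averaging (rather than summing) keeps the gradient scale independent of the number of clones, so the learning rate does not have to change when num_clones does.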
Call relationships:
train.py: main->_train_deeplab_model->_tower_loss->_build_deeplab
train.py: _train_deeplab_model->_average_gradients
train.py: _build_deeplab->_log_summaries
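The call chain above can be mimicked with hypothetical stub functions that only record the order in which they run. The stubs take no real arguments and contain no TensorFlow code; the num_clones loop reflects the assumption that _train_deeplab_model builds one tower loss per clone and then averages the per-tower gradients once.

```python
# Hypothetical stubs that only mirror the train.py call chain; none of the
# real arguments or TensorFlow logic is reproduced here.
CALLS = []


def _log_summaries():
    CALLS.append("_log_summaries")


def _build_deeplab():
    CALLS.append("_build_deeplab")
    _log_summaries()  # summaries are written from inside _build_deeplab


def _tower_loss():
    CALLS.append("_tower_loss")
    _build_deeplab()  # each tower builds the model and computes its loss


def _average_gradients():
    CALLS.append("_average_gradients")


def _train_deeplab_model(num_clones=2):
    CALLS.append("_train_deeplab_model")
    for _ in range(num_clones):  # one tower loss per clone (assumption)
        _tower_loss()
    _average_gradients()  # combine the per-tower gradients once


def main():
    CALLS.append("main")
    _train_deeplab_model()


if __name__ == "__main__":
    main()
    print(" -> ".join(CALLS))
```

With two clones, the printed trace shows the _tower_loss -> _build_deeplab -> _log_summaries segment twice, followed by a single _average_gradients call.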
model.py contains the following functions:
get_extra_layer_scopes
predict_labels_multi_scale
predict_labels
multi_scale_logits
extract_features
_get_logits
refine_by_decoder
get_branch_logits
How model.py is called (parentheses mark a call that crosses .py files):
1. (train.py _build_deeplab)->(model.py multi_scale_logits), which inside model.py continues as:
model.py: multi_scale_logits->_get_logits->extract_features + refine_by_decoder + get_branch_logits (_get_logits calls all three)
2. (train.py _train_deeplab_model)->(model.py get_extra_layer_scopes)
3. (train.py main)->(model.py get_extra_layer_scopes)
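multi_scale_logits runs _get_logits once per input scale and then merges the per-scale logits into a single map. The sketch below illustrates only that merge step, under simplifying assumptions: each logits map is a single-class 2-D grid of floats, nearest-neighbour resizing replaces the bilinear resize used in the real model, and the merge_method argument mirrors DeepLab's max/avg choice (defaulting to max here). The names resize_nearest and merge_multi_scale are hypothetical.

```python
# Simplified, hypothetical sketch of the merge step in multi_scale_logits:
# logits from each input scale are resized to one common size and then
# merged element-wise with max or mean. Nearest-neighbour resizing stands
# in for the bilinear resize; each grid represents one class channel.
from typing import Dict, List

Grid = List[List[float]]


def resize_nearest(grid: Grid, height: int, width: int) -> Grid:
    """Nearest-neighbour resize of a 2-D grid to (height, width)."""
    src_h, src_w = len(grid), len(grid[0])
    return [
        [grid[r * src_h // height][c * src_w // width] for c in range(width)]
        for r in range(height)
    ]


def merge_multi_scale(scales_to_logits: Dict[float, Grid], height: int,
                      width: int, merge_method: str = "max") -> Grid:
    """Resize every scale's logits to (height, width), then merge them."""
    resized = [resize_nearest(g, height, width)
               for g in scales_to_logits.values()]
    merged = []
    for r in range(height):
        row = []
        for c in range(width):
            values = [g[r][c] for g in resized]
            if merge_method == "max":
                row.append(max(values))
            else:  # "avg": mean over scales
                row.append(sum(values) / len(values))
        merged.append(row)
    return merged


if __name__ == "__main__":
    logits = {
        1.0: [[0.1, 0.9], [0.4, 0.2]],  # full-resolution branch
        0.5: [[0.5]],                  # half-resolution branch (1x1 here)
    }
    print(merge_multi_scale(logits, 2, 2))
    print(merge_multi_scale(logits, 2, 2, merge_method="avg"))
```

Max merging lets whichever scale is most confident at a pixel dominate, while averaging smooths the scales together; both keep the output at the common (largest) resolution.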
Data flow:
First, in train.py, the main function creates a dataset instance:
dataset = data_generator.Dataset(
dataset_name=FLAGS.dataset,
split_name=FLAGS.train_split,
dataset_dir=FLAGS.dataset_dir,
batch_size=clone_batch_size,
crop_size=[int(sz) for sz in FLAGS.train_crop_size],
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
min_scale_factor=FLAGS.min_scale_factor,
max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_facto