4. /lib/networks/factory.py
调用函数链接:
- 被调用网络的代码解读链接:[VGGnet_train.py](https://blog.csdn.net/u014256231/article/details/79697581)
代码解读:
# --------------------------------------------------------
# SubCNN_TF
# Copyright (c) 2016 CVGL Stanford
# Licensed under The MIT License [see LICENSE for details]
# Written by Yu Xiang
# --------------------------------------------------------
"""Factory for obtaining network objects by name."""
# Registry mapping names to network objects. Nothing visible in this file
# populates it: the registrations further down are commented out.
__sets = {}
import networks.VGGnet_train
import networks.VGGnet_test
import pdb
import tensorflow as tf
# Registrations kept for reference (disabled). get_network() constructs the
# requested network on each call instead of looking it up in __sets.
#__sets['VGGnet_train'] = networks.VGGnet_train()
#__sets['VGGnet_test'] = networks.VGGnet_test()
def get_network(name):
    """Build and return the network selected by ``name``.

    The second ``'_'``-separated token of ``name`` selects the variant:
    ``'test'`` returns ``networks.VGGnet_test()`` and ``'train'`` returns
    ``networks.VGGnet_train()``. Only that token is inspected; the text
    before the first underscore is ignored.

    Args:
        name: Network name such as ``'VGGnet_train'`` or ``'VGGnet_test'``.

    Returns:
        A newly constructed network object; a fresh one on every call.

    Raises:
        KeyError: If ``name`` has no second token, or the token is neither
            ``'test'`` nor ``'train'``.
    """
    tokens = name.split('_')
    # A name without an underscore has no mode token. Report it as an
    # unknown network instead of letting tokens[1] raise IndexError.
    mode = tokens[1] if len(tokens) > 1 else None
    if mode == 'test':
        return networks.VGGnet_test()
    if mode == 'train':
        # VGGnet_train is defined in /lib/networks/VGGnet_train.py; walkthrough:
        # https://blog.csdn.net/u014256231/article/details/79697581
        return networks.VGGnet_train()
    raise KeyError('Unknown network: {}'.format(name))
def list_networks():
    """Return the keys of the module-level network registry ``__sets``.

    NOTE(review): the only registrations visible in this file are commented
    out, so the registry is presumably empty here -- confirm whether other
    code populates it.
    """
    registry = __sets
    return registry.keys()