自定义模型
整合自定义模型需要实现一个ModelHelper
类。ModelHelper
包含数据输入管道和网络前向传播和损失函数的定义。使用自定义的ModelHelper
,网络可以无限制的使用FullPrecLearner
训练,或者使用其他如ChannelPrunedLearner
和UniformQuantTFLearner
学习器进行通道裁剪和量化。
要点概括
要完成训练需要提供下面两个组件:
- 数据输入管道
- 网络定义
ModelHelper
是抽象类AbstractModelHelper
的子类,它涉及来提供上面这些定义。在PocketFlow,已经提供了几个ModelHelper
类来描述数据集和模型结构的不同组合。要使用自定义模型,需要实现新的ModelHelper
类。此外,还需要一个执行脚本来调用这个新定义的ModelHelper
类。
数据输入管道
首先,需要告诉PocketFlow如何解析数据文件。在examples/fmnist_dataset.py
定义一个名为FmnistDataset
类创建返回训练和测试子集的迭代器。
from dataset.abstract_dataset import AbstractDataset
FLAGS = tf.app.flags.FLAGS
#将图像和标签读入内存
def load_mnist(image_file,label_file):
#...
return image,labels
def parse_fn(image,label,is_train):
"""
输入:
进行相关预处理,resize,crop,等
输出:image:图像张量
label:one-hot标签张量
"""
return image,label
class FMnistDataset(AbstractDataset):
'''数据集管道类'''
def __init__(self,is_train):
super(FMnistDataset,self).__init__(is_train)
#...
if is_train:
self.batch_size = FLAGS.batch_size
image_file = os.path.join(data_dir, 'train images file name')
label_file = os.path.join(data_dir, "train lables file name")
else:
#测试集,同上
self.images, self.labels = load_mnist(iamge_file,label_file)
self.parse_fn = lambda x,y : parse_fn(x,y,is_train)
def build(self,enbl_trn_val_split=False):
"""
构建tf.data.Dataset()的迭代器
"""
# create a tf.data.Dataset() object from NumPy arrays
dataset = tf.data.Dataset.from_tensor_slices((self.images, self.labels))
dataset = dataset.map(self.parse_fn, num_parallel_calls=FLAGS.nb_threads)
# create iterators for training & validation subsets separately
if self.is_train and enbl_trn_val_split:
iterator_val = self.__make_iterator(dataset.take(FLAGS.nb_smpls_val))
iterator_trn = self.__make_iterator(dataset.skip(FLAGS.nb_smpls_val))
return iterator_trn, iterator_val
return self.__make_iterator(dataset)
def __make_iterator(self, dataset