PocketFlow自定义模型和数据集

自定义模型

整合自定义模型需要实现一个ModelHelper类。ModelHelper包含数据输入管道和网络前向传播和损失函数的定义。使用自定义的ModelHelper,网络可以无限制的使用FullPrecLearner训练,或者使用其他如ChannelPrunedLearnerUniformQuantTFLearner学习器进行通道裁剪和量化。

要点概括

要完成训练需要提供下面两个组件:

  • 数据输入管道
  • 网络定义

ModelHelper是抽象类AbstractModelHelper的子类,它涉及来提供上面这些定义。在PocketFlow,已经提供了几个ModelHelper类来描述数据集和模型结构的不同组合。要使用自定义模型,需要实现新的ModelHelper类。此外,还需要一个执行脚本来调用这个新定义的ModelHelper类。

数据输入管道

首先,需要告诉PocketFlow如何解析数据文件。在examples/fmnist_dataset.py定义一个名为FmnistDataset类创建返回训练和测试子集的迭代器。

from dataset.abstract_dataset import AbstractDataset
FLAGS = tf.app.flags.FLAGS

#将图像和标签读入内存
def load_mnist(image_file,label_file):
    #...   
    return image,labels
def parse_fn(image,label,is_train):
    """
    输入:
    进行相关预处理,resize,crop,等
    输出:image:图像张量
    	 label:one-hot标签张量
    """
    return image,label

class FMnistDataset(AbstractDataset):
    '''数据集管道类'''
    def __init__(self,is_train):
        super(FMnistDataset,self).__init__(is_train)
        #...
        if is_train:
            self.batch_size = FLAGS.batch_size
            image_file = os.path.join(data_dir, 'train images file name')
            label_file = os.path.join(data_dir, "train lables file name")
         else:
            #测试集,同上
        self.images, self.labels = load_mnist(iamge_file,label_file)
        self.parse_fn = lambda x,y : parse_fn(x,y,is_train)
        
    def build(self,enbl_trn_val_split=False):
        """
        构建tf.data.Dataset()的迭代器
        """
          # create a tf.data.Dataset() object from NumPy arrays
    	dataset = tf.data.Dataset.from_tensor_slices((self.images, self.labels))
    	dataset = dataset.map(self.parse_fn, num_parallel_calls=FLAGS.nb_threads)

    	# create iterators for training & validation subsets separately
    	if self.is_train and enbl_trn_val_split:
      		iterator_val = self.__make_iterator(dataset.take(FLAGS.nb_smpls_val))
      		iterator_trn = self.__make_iterator(dataset.skip(FLAGS.nb_smpls_val))
      		return iterator_trn, iterator_val

    	return self.__make_iterator(dataset)
    
     def __make_iterator(self, dataset
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值