tflearn重载预训练模型时的bug及修复

tflearn重载预训练模型时的bug及修复

最近接触了tflearn框架,是一个集合了tensorflow的一些功能较为高层的API。模型搭建和训练也比tensorflow的语句精炼,但就是容易出bug后找不到原因。比如最近在重载预训练模型时就出现了一个莫名其妙的bug。google找不到类似的经历,进包自带的函数里调试才发现了问题所在,记录下来也许能帮助到有用的同志。

问题描述

最近在做的是从一个生物信号S里重建图像,思路就是先用AutoEncoder提图像特征ft,再把生物信号S往ft做回归,再从ft上采样,重建图像。用此种方法作为后续研究的baseline。
无奈手里与生物信号成对的图像数量很少,最多也只有不到2000,自己设计训练vae或者ae都效果不佳。于是往github上找模型,找到一个在MSCOCO+ Flickr 30k. 总共约 200k图像上进行预训练的autoencoder模型,地址。测试时可以接受jpg或者png图像,程序会reshape到256*256.train_encoder.py里的模型为:

def build_model():
    # logging.info('building model')
    img_prep = ImagePreprocessing()
    img_prep.add_featurewise_zero_center()
    img_prep.add_featurewise_stdnorm()

    encoder = input_data(shape=(None, IMAGE_INPUT_SIZE[0], IMAGE_INPUT_SIZE[1],
                                3), data_preprocessing=img_prep)
    encoder = conv_2d(encoder, 16, 7, activation='relu')    # incoming, nb_filter, filter_size, strides=1, padding='same', activation='linear',
    encoder = dropout(encoder, 0.25)  # you can have noisy input instead
    encoder = max_pool_2d(encoder, 2)
    encoder = conv_2d(encoder, 16, 7, activation='relu')
    encoder = max_pool_2d(encoder, 2)
    encoder = conv_2d(encoder, 8, 7, activation='relu')
    encoder = max_pool_2d(encoder, 2)

    decoder = conv_2d(encoder, 8, 7, activation='relu')
    decoder = upsample_2d(decoder, 2)
    decoder = conv_2d(decoder, 16, 7, activation='relu')
    decoder = upsample_2d(decoder, 2)
    decoder = conv_2d(decoder, 16, 7, activation='relu')
    decoder = upsample_2d(decoder, 2)
    decoder = conv_2d(decoder, 3, 7)

    return regression(decoder, optimizer='adadelta',
                      loss='binary_crossentropy', learning_rate=0.005)

evaluate.py里导入train_encoder.py里的build_model函数,并从对应的checkpoints里导入模型的权重参数

net = train_encoder.bulid_model()
model = tflearn.DNN(net)
checkpoint_path = './checkpoints/model.h5'
model.load(checkpoint_path)

我们也可以只重载encoder模块,这比后面的只重载decoder模块要容易:

encoder = input_data(shape=(None, IMAGE_INPUT_SIZE[0], IMAGE_INPUT_SIZE[1],
                            3))# , data_preprocessing=img_prep, name='image_input')
encoder = conv_2d(encoder, 16, 7, activation='relu')    # incoming, nb_filter, filter_size, strides=1, padding='same', activation='linear',
encoder = dropout(encoder, 0.25)  # you can have noisy input instead
encoder = max_pool_2d(encoder, 2)
encoder = conv_2d(encoder, 16, 7, activation='relu')
encoder = max_pool_2d(encoder, 2)
encoder = conv_2d(encoder, 8, 7, activation='relu')
encoder_model = tflearn.DNN(encoder, session=model.session)

在我们需要从自己的ft里,生成图像时,我们只需要重载decoder。只重载decoder模块时,问题就出现了。

z = input_data(shape=(None,32,32,8), name='z')
decoder = conv_2d(z, 8, 7, activation='relu', scope='Conv2D_3', reuse=True)
decoder = upsample_2d(decoder, 2, )
decoder = conv_2d(decoder, 16, 7, activation='relu', scope='Conv2D_4', reuse=True)
decoder = upsample_2d(decoder, 2, )
decoder = conv_2d(decoder, 16, 7, activation='relu', scope='Conv2D_5', reuse=True)
decoder = upsample_2d(decoder, 2, )
decoder = conv_2d(decoder, 3, 7, scope='Conv2D_6', reuse=True)
generator_model = tflearn.DNN(decoder, session=model.session)
recons_img = generator_model.predict({'z':ft})

ft就是我从生物信号S回归出的特征。作为生成模型的输入。但运行时出现了error

KeyError: <tf.Tensor 'InputData/X:0' shape=(?, 256, 256, 3) dtype=float32>

调试一下, 怎么generator_model的inputs里面有两个tensor?一个的shape是(None, 32,32,8),还有一个就是shape=(None,256,256,3)是整个model的input_data层的shape。我以为是自己代码的编写有问题,但是参考了github上的tflearn教程里的一个例子,地址。发现里面的generator_model里也有两个tensor但是运行不会报错。郁闷了很久,最后还是跟着error Traceback进到产生错误的根源位置=>"/nfs/software/anaconda3/lib/python3.5/site-packages/tflearn/helpers/evaluator.py"

class Evaluator(object):

    def __init__(self, tensors, model=None, session=None):
        self.tensors = to_list(tensors)
        self.graph = self.tensors[0].graph
        self.model = model
        self.dprep_collection = tf.get_collection(tf.GraphKeys.DATA_PREP)
        self.inputs = tf.get_collection(tf.GraphKeys.INPUTS)

        with self.graph.as_default():
            self.session = tf.Session()
            if session: self.session = session
            self.saver = tf.train.Saver()
            if model: self.saver.restore(self.session, model)
            
    def predict(self, feed_dict):
        with self.graph.as_default():
            # Data Preprocessing
            dprep_dict = dict()
            for i in range(len(self.inputs)):
                # Support for custom inputs not using dprep/daug
                if len(self.dprep_collection) > i:
                    if self.dprep_collection[i] is not None:
                        dprep_dict[self.inputs[i]] = self.dprep_collection[i]
            # Apply pre-processing
            if len(dprep_dict) > 0:
                for k in dprep_dict:
                    feed_dict[k] = dprep_dict[k].apply(feed_dict[k])

            # Prediction for each tensor
            tflearn.is_training(False, self.session)
            prediction = []
            if len(self.tensors) == 1:
                return self.session.run(self.tensors[0], feed_dict=feed_dict)
            else:
                for output in self.tensors:
                    o_pred = self.session.run(output, feed_dict=feed_dict).tolist()
                    for i, val in enumerate(o_pred): # Reshape pred per sample
                        if len(self.tensors) > 1:
                            if not len(prediction) > i: prediction.append([])
                            prediction[i].append(val)
                return prediction

注意这里我代码里的feed_dict就是{'z':ft},明确指定了两个tensor中的一个,但是为什么它就没有指定我指定的那个tensor呢?发现feed_dictlen(dprep_dict) > 0时会有个改动,而这个dpre_dict是否长度为0,则与self.dprep_collection[i] is not None这个判断条件True还是False有关。调试时tflearn给的例子里,dprep_collection都是None元素,而我的代码里有一个元素跟ImagePreprocessing有关,仔细一看,model 的input_data里的一个参数是data_preprocessing=img_prep,把它注释掉以后,error就消失了!

最后推荐一个入门tflearn的项目,地址放在这里。examples里面有很多简单的项目代码,可以从中学到如何搭建模型,如何训练模型,如何finetune等等。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
# 工程内容 这个程序是基于tensorflow的tflearn库实现部分RCNN功能。 # 开发环境 windows10 + python3.5 + tensorflow1.2 + tflearn + cv2 + scikit-learn # 数据集 采用17flowers据集, 官网下载:http://www.robots.ox.ac.uk/~vgg/data/flowers/17/ # 程序说明 1、setup.py---初始化路径 2、config.py---配置 3、tools.py---进度条和显示带框图像工具 4、train_alexnet.py---大数据集预训练Alexnet网络,140个epoch左右,bitch_size为64 5、preprocessing_RCNN.py---图像的处理(选择性搜索、数据存取等) 6、selectivesearch.py---选择性搜索源码 7、fine_tune_RCNN.py---小数据集微调Alexnet 8、RCNN_output.py---训练SVM并测试RCNN(测试的候测试图片选择第7、16类中没有参与训练的,单朵的花效果好,因为训练用的都是单朵的) # 文件说明 1、train_list.txt---预训练数据,数据在17flowers文件夹中 2、fine_tune_list.txt---微调数据2flowers文件夹中 3、1.png---直接用选择性搜索的区域划分 4、2.png---通过RCNN后的区域划分 # 程序问题 1、由于数据集小的原因,在微调候并没有像论文一样按一个bitch32个正样本,128个负样本输入,感觉正样本过少; 2、还没有懂最后是怎么给区域打分的,所有非极大值抑制集合canny算子没有进行,待续; 3、对选择的区域是直接进行缩放的; 4、由于数据集合论文采用不一样,但是微调和训练SVM采用的IOU阈值一样,有待调参。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值