HSML代码笔记

主要创新点

不管是train还是test,对每个任务我们都使用HSML,而不是仅仅在train后,对test使用新的task的时候使用HSML。理解有误。

首先确定簇的个数,但是这个簇是动态变换的,如果新任务的分布和之前的簇的差别过大就新建一个簇,使用通用的meta-knowledge,否则就是将其分配给已经存在的一个簇。对于测试的时候也是这样。首先判断一个新任务和那个簇比较接近。在对其进行更新。

基础知识

FLAGS

在main.py中定义了FLAGS,FLAGS = flags.FLAGS,在其他的py文件中也可以直接使用,本文中的maml.py直接在import之后调用了FLAGS = flags.FLAGS

各个部分代码的含义

image_embedding.py

1.local3,local4分别代表什么

data_generator.py

返回所有batch的数据
制作support,query,对数据进行预处理。

data_generator = DataGenerator(FLAGS.update_batch_size + 15,
                                                   FLAGS.meta_batch_size) 
    if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']:
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            if FLAGS.datasource in ['miniimagenet', 'omniglot']:
                # return all_image_batches, all_label_batches
                image_tensor, label_tensor = data_generator.make_data_tensor()

maml.py

完成的主要工作

1.对模型的进行初始化,初始化loss,权重
2.使用了construct_model函数,这个函数完成了主要的任务

model.construct_model完成的主要工作是,
首先将数据通过lstm_tree.py转换成论文中所介绍的形式(Task Representation Learning),对图片数据按照文中说明的方式进行处理,以及分簇,task-adaptation也要在此进行,并且元学习的过程此部分完成。

def construct_model

1.将我们图片和标签通过进行一个类似编码的过程。

 # 提取图片特征。
                        input_task_emb = self.image_embed.model(tf.reshape(inputa,
                                                                           [-1, self.img_size, self.img_size,
                                                                            self.channels]))
                        one_hot_labela = tf.squeeze(
                            tf.one_hot(tf.to_int32(labela), depth=1, axis=-1))

                    input_task_emb = tf.concat((input_task_emb, one_hot_labela), axis=-1)

2.用文中提到的方式对我们的数据进行编码

task_embed_vec, task_emb_loss = self.lstmae.model(input_task_emb)

3.对输入分簇并且进行task adaptation,都是在lstm_tree.py中完成。

                _, meta_knowledge_h = self.tree.model(task_embed_vec)
                task_enhanced_emb_vec = tf.concat([task_embed_vec, meta_knowledge_h], axis=1)

4.对输入应用经过HSML处理得到的权重。对每个输入都是先经过HSML,在进行前向传播。

 # 对输入应用task-weights
                task_outputa = self.forward(inputa, task_weights, reuse=reuse)
                task_lossa = self.loss_func(task_outputa, labela)

5.开始计算loss等信息
6.参数更新
7.绘制loss曲线

不懂

  1. 位于maml.py中的143行左右
                with tf.variable_scope('task_specific_mapping', reuse=tf.AUTO_REUSE):
                    eta = []
                    for key in weights.keys():
                        weight_size = np.prod(weights[key].get_shape().as_list())
                        eta.append(tf.reshape(
                            tf.layers.dense(task_enhanced_emb_vec, weight_size, activation=tf.nn.sigmoid,
                                            name='eta_{}'.format(key)), tf.shape(weights[key])))
                    eta = dict(zip(weights.keys(), eta))

                    task_weights = dict(zip(weights.keys(), [weights[key] * eta[key] for key in weights.keys()]))

                task_outputbs, task_lossesb = [], []

lstm_tree.py

论文细节

Task Representation Learning

  • Recurrent Autoencoder Aggregator
  • Pooling Autoencoder Aggregator
    实验表明,RAA平均效果比PAA要好。
    1.首先在maml.py中进行调用
    input_task_emb是我们已经编码好的数据以及标签。
task_embed_vec, task_emb_loss = self.lstmae.model(input_task_emb)

具体的操作过程在task_embedding中LSTMAutoencoder类中。

Hierarchical Task Clustering

TreeLSTM.py是做什么的?

存在于maml.py中,首次在MAML.__init__中进行声明self.tree = TreeLSTM(input_dim=FLAGS.hidden_dim, tree_hidden_dim=FLAGS.hidden_dim)
接下来在maml.py中的_, meta_knowledge_h = self.tree.model(task_embed_vec)处进行调用,对输入进行分簇处理
在这里插入图片描述

Knowledge Adaptation

在这里插入图片描述

tf.slice 介绍

tf.train.string_input_producer

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值