主要创新点
不管是train还是test,对每个任务我们都使用HSML,而不是仅仅在train后,对test使用新的task的时候使用HSML。理解有误。
首先确定簇的个数,但是这个簇是动态变换的,如果新任务的分布和之前的簇的差别过大就新建一个簇,使用通用的meta-knowledge,否则就是将其分配给已经存在的一个簇。对于测试的时候也是这样。首先判断一个新任务和那个簇比较接近。在对其进行更新。
基础知识
FLAGS
在main.py中定义了FLAGS,FLAGS = flags.FLAGS,在其他的py文件中也可以直接使用,本文中的maml.py直接在import之后调用了FLAGS = flags.FLAGS
各个部分代码的含义
image_embedding.py
1.local3,local4分别代表什么
data_generator.py
返回所有batch的数据
制作support,query,对数据进行预处理。
data_generator = DataGenerator(FLAGS.update_batch_size + 15,
FLAGS.meta_batch_size)
if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']:
tf_data_load = True
num_classes = data_generator.num_classes
if FLAGS.train: # only construct training model if needed
random.seed(5)
if FLAGS.datasource in ['miniimagenet', 'omniglot']:
# return all_image_batches, all_label_batches
image_tensor, label_tensor = data_generator.make_data_tensor()
maml.py
完成的主要工作
1.对模型的进行初始化,初始化loss,权重
2.使用了construct_model函数,这个函数完成了主要的任务
model.construct_model
完成的主要工作是,
首先将数据通过lstm_tree.py
转换成论文中所介绍的形式(Task Representation Learning),对图片数据按照文中说明的方式进行处理,以及分簇,task-adaptation也要在此进行,并且元学习的过程此部分完成。
def construct_model
1.将我们图片和标签通过进行一个类似编码的过程。
# 提取图片特征。
input_task_emb = self.image_embed.model(tf.reshape(inputa,
[-1, self.img_size, self.img_size,
self.channels]))
one_hot_labela = tf.squeeze(
tf.one_hot(tf.to_int32(labela), depth=1, axis=-1))
input_task_emb = tf.concat((input_task_emb, one_hot_labela), axis=-1)
2.用文中提到的方式对我们的数据进行编码
task_embed_vec, task_emb_loss = self.lstmae.model(input_task_emb)
3.对输入分簇并且进行task adaptation,都是在lstm_tree.py中完成。
_, meta_knowledge_h = self.tree.model(task_embed_vec)
task_enhanced_emb_vec = tf.concat([task_embed_vec, meta_knowledge_h], axis=1)
4.对输入应用经过HSML处理得到的权重。对每个输入都是先经过HSML,在进行前向传播。
# 对输入应用task-weights
task_outputa = self.forward(inputa, task_weights, reuse=reuse)
task_lossa = self.loss_func(task_outputa, labela)
5.开始计算loss等信息
6.参数更新
7.绘制loss曲线
不懂
- 位于maml.py中的143行左右
with tf.variable_scope('task_specific_mapping', reuse=tf.AUTO_REUSE):
eta = []
for key in weights.keys():
weight_size = np.prod(weights[key].get_shape().as_list())
eta.append(tf.reshape(
tf.layers.dense(task_enhanced_emb_vec, weight_size, activation=tf.nn.sigmoid,
name='eta_{}'.format(key)), tf.shape(weights[key])))
eta = dict(zip(weights.keys(), eta))
task_weights = dict(zip(weights.keys(), [weights[key] * eta[key] for key in weights.keys()]))
task_outputbs, task_lossesb = [], []
lstm_tree.py
论文细节
Task Representation Learning
- Recurrent Autoencoder Aggregator
- Pooling Autoencoder Aggregator
实验表明,RAA平均效果比PAA要好。
1.首先在maml.py中进行调用
input_task_emb是我们已经编码好的数据以及标签。
task_embed_vec, task_emb_loss = self.lstmae.model(input_task_emb)
具体的操作过程在task_embedding中LSTMAutoencoder类中。
Hierarchical Task Clustering
TreeLSTM.py是做什么的?
存在于maml.py中,首次在MAML.__init__中进行声明self.tree = TreeLSTM(input_dim=FLAGS.hidden_dim, tree_hidden_dim=FLAGS.hidden_dim)
接下来在maml.py中的_, meta_knowledge_h = self.tree.model(task_embed_vec)
处进行调用,对输入进行分簇处理