Meta Learning: MAML Code
import pickle
import numpy as np
import cv2
class DataLoader(object):
    def __init__(self, pkl_file, batch_sz, n_way, k_shot, k_query, resize=224):
        self._batch = []  # list of pre-sampled sets (tasks)
        self._pkl_file = pkl_file
        self._batch_sz = batch_sz  # number of sets to pre-sample
        self._n_way = n_way  # n-way
        self._k = k_shot + k_query
        self._k_shot = k_shot  # k-shot
        self._k_query = k_query  # for evaluation
        self._sz = self._n_way * self._k
        self.support_sz = self._n_way * self._k_shot  # number of support samples per set
        self.query_sz = self._n_way * self._k_query  # number of query samples per set
        self._resize = resize
        self._data = self._load_pkl(self._pkl_file)
        self._cls_num = len(self._data)  # number of classes
        self._create_batch(self._batch_sz)
    def _load_pkl(self, pkl_file):
        """
        Load the pickle file.
        :param pkl_file: path to the pickle file
        :return: loaded data
        """
        with open(pkl_file, 'rb') as f:
            data = pickle.load(f)
        return data
    @staticmethod
    def _get_section(data, index):
        """
        Fetch the images and labels at the given indices.
        :param data: per-class data dict with 'image' and 'label'
        :param index: indices into data
        :return: (image, label) lists
        """
        image = [data['image'][idx] for idx in index]
        label = [data['label'][idx] for idx in index]
        return image, label
    def _create_batch(self, batch_sz):
        """
        Pre-sample batch_sz sets (tasks) for meta-learning.
        :param batch_sz: number of sets
        :return:
        """
        self._batch = []
        for b in range(batch_sz):  # for each batch (set)
            # 1. randomly select n_way classes without replacement
            selected_cls = np.random.choice(self._cls_num, self._n_way, False)
            np.random.shuffle(selected_cls)
            sz = []  # n_way x (k_shot + k_query) samples
            for cls in selected_cls:
                # 2. select k_shot + k_query samples for each class
                selected_imgs_idx = np.random.choice(len(self._data[cls]['image']), self._k_shot + self._k_query, False)
                np.random.shuffle(selected_imgs_idx)
                section = self._get_section(self._data[cls], selected_imgs_idx)
                sz.append(section)
            self._batch.append(sz)  # append this set to the pre-sampled sets
    def _load_img(self, path):
        """
        Read an image, resize it and scale pixels to [-1, 1].
        :param path: image file path
        :return: preprocessed image array
        """
        img = cv2.imread(path)
        img = cv2.resize(img, (self._resize, self._resize))
        img = img / 255 * 2 - 1  # scale from [0, 255] to [-1, 1]
        return img
    def _get_task(self, task):
        """
        Load all images and labels of one task.
        :param task: current task, a list of (image paths, labels) sections, one per class
        :return: (image, label) arrays with n_way * (k_shot + k_query) rows
        """
        image = np.zeros((self._sz, self._resize, self._resize, 3))
        label = np.zeros((self._sz, 2))
        image_ = [task[i][0][j] for i in range(self._n_way) for j in range(self._k)]
        label_ = [task[i][1][j] for i in range(self._n_way) for j in range(self._k)]
        for i in range(self._sz):
            image[i] = self._load_img(image_[i])
            label[i] = label_[i]
        return image, label
    def get_task_batch(self, task_num, cursor):
        """
        Fetch a batch of tasks and split each task into its support and query set.
        :param task_num: number of tasks in one batch [int]
        :param cursor: task counter [int]
        :return: updated cursor, (support images, support labels), (query images, query labels)
        """
        support_image = np.zeros((self.support_sz * task_num, self._resize, self._resize, 3))
        support_label = np.zeros((self.support_sz * task_num, 2))
        query_image = np.zeros((self.query_sz * task_num, self._resize, self._resize, 3))
        query_label = np.zeros((self.query_sz * task_num, 2))
        # within each class block of k samples, the first k_shot are support, the rest are query
        sup_idx = [c * self._k + j for c in range(self._n_way) for j in range(self._k_shot)]
        qry_idx = [c * self._k + self._k_shot + j for c in range(self._n_way) for j in range(self._k_query)]
        for i in range(task_num):
            if cursor >= self._batch_sz:
                cursor = 0
            image, label = self._get_task(self._batch[cursor])  # load this task once
            support_image[self.support_sz * i:self.support_sz * (i + 1)] = image[sup_idx]
            support_label[self.support_sz * i:self.support_sz * (i + 1)] = label[sup_idx]
            query_image[self.query_sz * i:self.query_sz * (i + 1)] = image[qry_idx]
            query_label[self.query_sz * i:self.query_sz * (i + 1)] = label[qry_idx]
            cursor += 1
        support = (support_image, support_label)
        query = (query_image, query_label)
        return cursor, support, query
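- The loader above assumes a pickle layout that is never spelled out. Based on how `_load_pkl`, `_get_section` and `_get_task` index the data, it should look roughly like the sketch below; the file name, paths and label values are purely illustrative, and the label vectors are 2-dimensional only to match the `(sz, 2)` label arrays used by the loader.
import pickle
# hypothetical structure of the pickle file consumed by DataLoader
data = [
    {   # one entry per class; indexed as self._data[cls]
        'image': ['class0/img_000.jpg', 'class0/img_001.jpg'],  # file paths read by cv2.imread
        'label': [[1, 0], [1, 0]],                              # one 2-dim label vector per image
    },
    # ... one dict per class (76 classes in this experiment)
]
with open('dataset.pkl', 'wb') as f:
    pickle.dump(data, f)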
- Below is the basic code of the meta-learning training framework; supplement and modify it as needed.
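- For reference, the two update levels implemented below follow the standard MAML formulation: the inner loop adapts the shared weights theta to each task's support set, theta'_i = theta - alpha * grad_theta L_support(T_i, theta), repeated num_updates times with alpha = update_lr; the outer loop then updates the shared weights with Adam on the query loss of the adapted weights averaged over tasks, theta <- theta - beta * grad_theta (1/N) * sum_i L_query(T_i, theta'_i), where beta = learning_rate and N = task_num.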
import tensorflow as tf
from tensorflow.contrib.layers.python import layers
class Maml(object):
    def __init__(self, imsize, task_num, sz_support, sz_query, learning_rate, is_training=True):
        self.imSize = imsize
        self.task_num = task_num
        self.sz_p = sz_support  # number of support samples per task
        self.sz_q = sz_query  # number of query samples per task
        self.learning_rate = learning_rate
        self.is_training = is_training
        self.num_updates = 5  # number of support updates during training
        self.val_num_updates = 15  # number of support updates during testing
        self.update_lr = 0.01  # inner-loop learning rate
        self.weights = self.construct_weights()
        self.build_model()
    def _tf_slice(self, image, label, i, sz):
        """Slice out the i-th task's block of sz samples."""
        image_ = tf.slice(image, [i * sz, 0, 0, 0], [sz, -1, -1, -1])
        label_ = tf.slice(label, [i * sz, 0], [sz, -1])
        return image_, label_
    def build_model(self):
        """Build the computation graph."""
        # placeholders
        self._image = tf.compat.v1.placeholder(tf.float32, [None, self.imSize, self.imSize, 3], name='image')
        self._label = tf.compat.v1.placeholder(tf.float32, [None, 2], name='label')
        self._image_q = tf.compat.v1.placeholder(tf.float32, [None, self.imSize, self.imSize, 3], name='image_q')
        self._label_q = tf.compat.v1.placeholder(tf.float32, [None, 2], name='label_q')
        def meta_learning():
            """Meta-training branch."""
            loss_q = [0.0 for _ in range(self.num_updates)]
            # each task performs num_updates inner updates on its support set
            for i in range(self.task_num):
                image, label = self._tf_slice(self._image, self._label, i, self.sz_p)
                image_q, label_q = self._tf_slice(self._image_q, self._label_q, i, self.sz_q)
                pred = self._forward(image, self.weights)  # predict on the support set
                loss = self._cal_loss(pred, label)  # support loss
                grads = tf.gradients(loss, list(self.weights.values()))  # gradients w.r.t. the shared weights
                gradients = dict(zip(self.weights.keys(), grads))
                fast_weights = dict(zip(self.weights.keys(),
                                        [self.weights[key] - self.update_lr * gradients[key] for key in
                                         self.weights.keys()]))  # task-adapted (fast) weights
                pred_q = self._forward(image_q, fast_weights)  # predict on the query set with the adapted weights
                loss_q[0] += self._cal_loss(pred_q, label_q)  # query loss after the first update
                for j in range(self.num_updates - 1):
                    pred = self._forward(image, fast_weights)
                    loss = self._cal_loss(pred, label)
                    grads = tf.gradients(loss, list(fast_weights.values()))
                    gradients = dict(zip(fast_weights.keys(), grads))
                    fast_weights = dict(zip(fast_weights.keys(),
                                            [fast_weights[key] - self.update_lr * gradients[key] for key
                                             in fast_weights.keys()]))
                    pred_q = self._forward(image_q, fast_weights)
                    loss_q[j + 1] += self._cal_loss(pred_q, label_q)
            return loss_q
        def meta_finetune():
            """Meta-testing (fine-tune) branch."""
            output_q, error_q = [], []
            pred = self._forward(self._image, self.weights)
            loss = self._cal_loss(pred, self._label)
            grads = tf.gradients(loss, list(self.weights.values()))
            gradients = dict(zip(self.weights.keys(), grads))
            fast_weights = dict(zip(self.weights.keys(),
                                    [self.weights[key] - self.update_lr * gradients[key] for key in
                                     self.weights.keys()]))
            pred_q = self._forward(self._image_q, fast_weights)
            output_q.append(pred_q)
            error_q.append(self._eval(pred_q, self._label_q))
            for j in range(self.val_num_updates - 1):
                pred = self._forward(self._image, fast_weights)
                loss = self._cal_loss(pred, self._label)
                grads = tf.gradients(loss, list(fast_weights.values()))
                gradients = dict(zip(fast_weights.keys(), grads))
                fast_weights = dict(zip(fast_weights.keys(),
                                        [fast_weights[key] - self.update_lr * gradients[key] for key
                                         in fast_weights.keys()]))
                pred_q = self._forward(self._image_q, fast_weights)
                output_q.append(pred_q)
                error_q.append(self._eval(pred_q, self._label_q))
            return output_q, error_q
        loss_q = meta_learning()
        self.loss = loss_q[-1] / self.task_num  # average query loss after the last inner update
        output_q, error_q = meta_finetune()
        self.pred_q = output_q[-1]
        self.error = error_q[-1]
        self.global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        gvs = optimizer.compute_gradients(self.loss)
        self.train_op = optimizer.apply_gradients(gvs, global_step=self.global_step)
    def construct_weights(self):
        """Create the model weights as raw variables so they can be updated manually."""
        weights = {}
        # convolution layer
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)
        weights['conv1'] = tf.get_variable('conv1', [5, 5, 3, 32], initializer=conv_initializer, dtype=tf.float32)
        weights['convb1'] = tf.Variable(tf.zeros([32]), name='convb1')
        # fully connected layer; the [96, 128] sizes are placeholders and must match the
        # flattened conv feature size and the output dimension of the actual task
        weights['w1'] = tf.Variable(tf.truncated_normal([96, 128], stddev=0.01), name='w1')
        weights['b1'] = tf.Variable(tf.zeros([128]), name='b1')
        return weights
    def _conv_block(self, inp, conv_w, conv_b, stride, scope, max_pool2d=True):
        with tf.variable_scope(scope):
            output = tf.nn.conv2d(inp, conv_w, stride, padding='SAME') + conv_b
            output = tf.nn.relu(output)
            if max_pool2d:
                output = tf.nn.max_pool(output, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
        return output
    def _forward_conv(self, inp, weights, scope='conv'):
        conv = self._conv_block(inp, weights['conv1'], weights['convb1'], [1, 2, 2, 1], scope)
        conv = layers.flatten(conv, scope="flatten")
        return conv
    def _forward_fc(self, inp, w, b, activation=tf.nn.relu, scope='fc'):
        with tf.variable_scope(scope):
            fc = tf.matmul(inp, w) + b
            fc = activation(fc)
        return fc
    def _forward(self, image, weights):
        # forward pass with an explicit weight dict so that fast_weights can be swapped in
        feature = self._forward_conv(image, weights)
        fc = self._forward_fc(feature, weights['w1'], weights['b1'])
        return fc
    def train(self, support, query, num_train_steps):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # training loop
            for i in range(num_train_steps):
                image, label = support
                image_q, label_q = query
                feed_dict = {
                    self._image: image,
                    self._label: label,
                    self._image_q: image_q,
                    self._label_q: label_q
                }
                train_loss, _ = sess.run([self.loss, self.train_op], feed_dict=feed_dict)
    def _eval(self, pred, label):
        # evaluation metric; to be implemented for the concrete task
        pass
    def _cal_loss(self, pred, label):
        # loss function; to be implemented for the concrete task
        pass
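- The `_cal_loss` and `_eval` hooks above are left as stubs. Below is a minimal sketch of one possible choice, assuming the 2-dim labels are one-hot class vectors and the final fully connected layer outputs 2 raw logits (i.e. its output size matches the label dimension and no ReLU is applied to it); adapt both to the actual label format.
    def _cal_loss(self, pred, label):
        # mean softmax cross-entropy between logits `pred` and one-hot `label`
        return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=pred))
    def _eval(self, pred, label):
        # fraction of misclassified samples
        correct = tf.equal(tf.argmax(pred, axis=1), tf.argmax(label, axis=1))
        return 1.0 - tf.reduce_mean(tf.cast(correct, tf.float32))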
- Some experience and thoughts after finishing this experiment:
- Parameter settings used: task_num=5, batch_sz=10000, n_way=5, k_shot=5, k_query=20 or 25, learning_rate=0.001 (a wiring sketch using these values follows this list).
- The approach suits problems with many classes (76 classes in this experiment) where the per-class sample counts are small or imbalanced; in such cases the results can surpass conventional training.
- The time cost is high, which makes it hard to put into production.
- The results are somewhat unstable from run to run.
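- A minimal end-to-end wiring sketch of the two classes above using the parameter settings just listed; the pickle path and number of training steps are assumptions, and `_cal_loss` / `_eval` must be filled in first.
loader = DataLoader('dataset.pkl', batch_sz=10000, n_way=5, k_shot=5, k_query=20, resize=224)
model = Maml(imsize=224, task_num=5,
             sz_support=loader.support_sz,  # n_way * k_shot = 25 support samples per task
             sz_query=loader.query_sz,      # n_way * k_query = 100 query samples per task
             learning_rate=0.001)
cursor = 0
cursor, support, query = loader.get_task_batch(task_num=5, cursor=cursor)
model.train(support, query, num_train_steps=1000)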