Meta Learning ——MAML代码

Meta Learning ——MAML代码

  • 下面是元学习数据构造部分的代码
import pickle
import numpy as np
import cv2


class DataLoader(object):
    """Sample N-way, (k_shot + k_query)-sample meta-learning tasks from a pickled index.

    The pickle is indexed by class id; each entry is a dict with parallel
    ``'image'`` (image file paths, read with ``cv2.imread``) and ``'label'``
    lists. Label rows are stored in width-2 arrays here — presumably matching the
    ``[None, 2]`` label placeholders of the model; TODO confirm against the data.
    """

    def __init__(self, pkl_file, batch_sz, n_way, k_shot, k_query, resize=224):
        self._batch = []  # pre-sampled tasks, filled by _create_batch
        self._pkl_file = pkl_file
        self._batch_sz = batch_sz  # number of tasks pre-sampled into self._batch
        self._n_way = n_way  # classes per task
        self._k = k_shot + k_query  # samples drawn per class per task
        self._k_shot = k_shot  # support samples per class
        self._k_query = k_query  # query samples per class (for evaluation)
        self._sz = self._n_way * self._k  # total samples per task
        self.support_sz = self._n_way * self._k_shot  # support samples per task
        self.query_sz = self._n_way * self._k_query  # query samples per task
        self._resize = resize  # images are resized to resize x resize
        self._data = self._load_pkl(self._pkl_file)
        self._cls_num = len(self._data)  # number of classes
        self._create_batch(self._batch_sz)

    def _load_pkl(self, pkl_file):
        """Load and return the pickled per-class index stored at ``pkl_file``.

        NOTE: ``pickle.load`` can execute arbitrary code — only use trusted files.
        """
        with open(pkl_file, 'rb') as f:
            data = pickle.load(f)
        return data

    @staticmethod
    def _get_section(data, index):
        """Return ``(images, labels)`` lists picked from one class entry at ``index``.

        :param data: one class entry with ``'image'`` and ``'label'`` lists
        :param index: iterable of positions into those lists
        :return: tuple ``(image_list, label_list)``
        """
        image = [data['image'][idx] for idx in index]
        label = [data['label'][idx] for idx in index]
        return image, label

    def _create_batch(self, batch_sz):
        """Pre-sample ``batch_sz`` tasks into ``self._batch``.

        Each task is a list of ``n_way`` sections, one per randomly chosen
        class, each holding ``k_shot + k_query`` distinct samples of that class.

        :param batch_sz: number of tasks to sample
        """
        self._batch = []
        for _ in range(batch_sz):
            # 1. pick n_way distinct classes
            selected_cls = np.random.choice(self._cls_num, self._n_way, False)
            np.random.shuffle(selected_cls)
            task = []
            for cls in selected_cls:
                # 2. pick k_shot + k_query distinct samples of this class
                selected_imgs_idx = np.random.choice(len(self._data[cls]['image']), self._k, False)
                np.random.shuffle(selected_imgs_idx)
                task.append(self._get_section(self._data[cls], selected_imgs_idx))
            self._batch.append(task)

    def _load_img(self, path):
        """Read the image at ``path``, resize it, and scale pixels to [-1, 1].

        :raises OSError: if OpenCV cannot read the file
        """
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None
            raise OSError('failed to read image: {}'.format(path))
        img = cv2.resize(img, (self._resize, self._resize))
        # 8-bit pixels in [0, 255] -> [-1, 1]
        return img / 255.0 * 2 - 1

    def _get_task(self, task):
        """Load every image and label of one task.

        Rows are class-major: row ``c * (k_shot + k_query) + j`` is the j-th
        sample drawn for the task's c-th class.

        :param task: list of ``(image_list, label_list)`` sections
        :return: ``(image, label)`` arrays of shape ``(sz, resize, resize, 3)`` and ``(sz, 2)``
        """
        image = np.zeros((self._sz, self._resize, self._resize, 3))
        label = np.zeros((self._sz, 2))

        paths = [task[i][0][j] for i in range(self._n_way) for j in range(self._k)]
        # a section is (image_list, label_list), so labels live at index 1
        labels = [task[i][1][j] for i in range(self._n_way) for j in range(self._k)]

        for idx, (path, lbl) in enumerate(zip(paths, labels)):
            image[idx] = self._load_img(path)
            label[idx] = lbl

        return image, label

    def get_task_batch(self, task_num, cursor):
        """Load ``task_num`` consecutive pre-sampled tasks, wrapping around the pool.

        Within each task the first ``k_shot`` samples of every class form the
        support set and the remaining ``k_query`` form the query set. Both outputs
        are grouped task by task: support rows
        ``[t * support_sz, (t + 1) * support_sz)`` and query rows
        ``[t * query_sz, (t + 1) * query_sz)`` belong to task ``t``.

        :param task_num: number of tasks in the returned batch
        :param cursor: index of the first task to read from the pre-sampled pool
        :return: ``(next_cursor, (support_images, support_labels), (query_images, query_labels))``
        """
        r = self._resize
        support_img = np.zeros((self.support_sz * task_num, r, r, 3))
        support_lbl = np.zeros((self.support_sz * task_num, 2))
        query_img = np.zeros((self.query_sz * task_num, r, r, 3))
        query_lbl = np.zeros((self.query_sz * task_num, 2))

        for t in range(task_num):
            if cursor >= self._batch_sz:
                cursor = 0
            image, label = self._get_task(self._batch[cursor])  # load each task once
            # split per class: (n_way, k_shot + k_query, ...) -> support | query
            image = image.reshape(self._n_way, self._k, r, r, 3)
            label = label.reshape(self._n_way, self._k, 2)
            s = slice(self.support_sz * t, self.support_sz * (t + 1))
            q = slice(self.query_sz * t, self.query_sz * (t + 1))
            support_img[s] = image[:, :self._k_shot].reshape(-1, r, r, 3)
            support_lbl[s] = label[:, :self._k_shot].reshape(-1, 2)
            query_img[q] = image[:, self._k_shot:].reshape(-1, r, r, 3)
            query_lbl[q] = label[:, self._k_shot:].reshape(-1, 2)
            cursor += 1

        return cursor, (support_img, support_lbl), (query_img, query_lbl)
  • 下面是元学习训练框架的基础代码,可以按需求补充和修改。
import tensorflow as tf
from tensorflow.contrib.layers.python import layers


class Maml(object):
    """MAML (model-agnostic meta-learning) graph for a small conv classifier, TF1 graph mode.

    Data is fed through four placeholders: support images/labels drive the
    inner-loop (fast weight) updates, query images/labels give the meta
    objective. For meta-training the fed batch is split into ``task_num``
    consecutive chunks of ``sz_support`` support rows and ``sz_query`` query
    rows, one chunk per task.
    """

    def __init__(self, imsize, task_num, sz_support, sz_query, learning_rate, is_training=True,
                 num_classes=2):
        self.imSize = imsize  # images are imsize x imsize x 3
        self.task_num = task_num
        self.sz_p = sz_support  # support samples per task
        self.sz_q = sz_query  # query samples per task
        self.learning_rate = learning_rate  # outer (meta) learning rate for Adam
        self.is_training = is_training  # stored; not used by the graph
        self.num_classes = num_classes  # width of label rows and of the output logits
        self.num_updates = 5  # inner-loop support updates during meta-training
        self.val_num_updates = 15  # inner-loop support updates at test time
        self.update_lr = 0.01  # inner-loop (fast weight) learning rate
        self.weights = self.construct_weights()
        self.build_model()

    def _tf_slice(self, image, label, i, sz):
        """Return rows ``[i * sz, (i + 1) * sz)`` of ``image`` and ``label``."""
        image_ = tf.slice(image, [i * sz, 0, 0, 0], [sz, -1, -1, -1])
        label_ = tf.slice(label, [i * sz, 0], [sz, -1])
        return image_, label_

    def _fast_update(self, weights, image, label):
        """Return a new weight dict after one gradient step on ``(image, label)``.

        Gradients are taken through ``weights`` itself, so when ``weights``
        descends from ``self.weights`` the meta-gradient flows back through
        this step.
        """
        loss = self._cal_loss(self._forward(image, weights), label)
        keys = list(weights.keys())
        grads = tf.gradients(loss, [weights[k] for k in keys])
        return {k: weights[k] - self.update_lr * g for k, g in zip(keys, grads)}

    def build_model(self):
        """Build placeholders, the meta-training loss, test-time outputs and the train op."""
        self._image = tf.compat.v1.placeholder(tf.float32, [None, self.imSize, self.imSize, 3], name='image')
        self._label = tf.compat.v1.placeholder(tf.float32, [None, self.num_classes], name='label')

        self._image_q = tf.compat.v1.placeholder(tf.float32, [None, self.imSize, self.imSize, 3], name='image_q')
        self._label_q = tf.compat.v1.placeholder(tf.float32, [None, self.num_classes], name='label_q')

        def meta_learning():
            """Return, per inner-update step, the query loss summed over all tasks."""
            loss_q = [0.0] * self.num_updates
            for i in range(self.task_num):
                image, label = self._tf_slice(self._image, self._label, i, self.sz_p)
                image_q, label_q = self._tf_slice(self._image_q, self._label_q, i, self.sz_q)
                fast_weights = self.weights
                for j in range(self.num_updates):
                    # adapt on the support set, then score the adapted weights on the query set
                    fast_weights = self._fast_update(fast_weights, image, label)
                    loss_q[j] += self._cal_loss(self._forward(image_q, fast_weights), label_q)
            return loss_q

        def meta_finetune():
            """Adapt on the whole fed support set; return query logits and errors per step."""
            output_q, error_q = [], []
            fast_weights = self.weights
            for _ in range(self.val_num_updates):
                fast_weights = self._fast_update(fast_weights, self._image, self._label)
                pred_q = self._forward(self._image_q, fast_weights)
                output_q.append(pred_q)
                error_q.append(self._eval(pred_q, self._label_q))
            return output_q, error_q

        loss_q = meta_learning()
        # meta objective: query loss after the last inner update, averaged over tasks
        self.loss = loss_q[-1] / self.task_num

        output_q, error_q = meta_finetune()
        self.pred_q = output_q[-1]
        self.error = error_q[-1]

        self.global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        gvs = optimizer.compute_gradients(self.loss)
        # pass global_step so it is incremented by every meta-update
        self.train_op = optimizer.apply_gradients(gvs, global_step=self.global_step)

    def construct_weights(self):
        """Create the model variables: one conv layer, a hidden FC layer and a logits layer."""
        weights = {}
        n_filters = 32
        # conv layer
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=tf.float32)
        weights['conv1'] = tf.get_variable('conv1', [5, 5, 3, n_filters], initializer=conv_initializer,
                                           dtype=tf.float32)
        weights['convb1'] = tf.Variable(tf.zeros([n_filters]), name='convb1')

        # Flattened conv feature size: the stride-2 SAME conv and the stride-2
        # SAME max-pool in _conv_block each map a side of length n to ceil(n / 2).
        side = (self.imSize + 1) // 2
        side = (side + 1) // 2
        feature_dim = side * side * n_filters

        # hidden fully connected layer
        weights['w1'] = tf.Variable(tf.truncated_normal([feature_dim, 128], stddev=0.01), name='w1')
        weights['b1'] = tf.Variable(tf.zeros([128]), name='b1')
        # output layer: logits matching the label width
        weights['w2'] = tf.Variable(tf.truncated_normal([128, self.num_classes], stddev=0.01), name='w2')
        weights['b2'] = tf.Variable(tf.zeros([self.num_classes]), name='b2')
        return weights

    def _conv_block(self, inp, conv_w, conv_b, stride, scope, max_pool2d=True):
        """Conv + bias + ReLU, optionally followed by a 2x2 / stride-2 max-pool."""
        with tf.variable_scope(scope):
            output = tf.nn.conv2d(inp, conv_w, stride, padding='SAME') + conv_b
            output = tf.nn.relu(output)
            if max_pool2d:
                # NHWC input: ksize and strides need one entry per dimension
                output = tf.nn.max_pool(output, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
        return output

    def _forward_conv(self, inp, weights, scope='conv'):
        """Apply the stride-2 conv block and flatten the result per example."""
        conv = self._conv_block(inp, weights['conv1'], weights['convb1'], [1, 2, 2, 1], scope)
        conv = layers.flatten(conv, scope="flatten")
        return conv

    def _forward_fc(self, inp, w, b, activation=tf.nn.relu, scope='fc'):
        """Fully connected layer ``inp @ w + b``; ``activation=None`` keeps it linear."""
        with tf.variable_scope(scope):
            fc = tf.matmul(inp, w) + b
            if activation is not None:
                fc = activation(fc)
        return fc

    def _forward(self, image, weights):
        """Return unnormalized class logits for ``image`` using the given weight dict."""
        feature = self._forward_conv(image, weights)
        hidden = self._forward_fc(feature, weights['w1'], weights['b1'])
        return self._forward_fc(hidden, weights['w2'], weights['b2'], activation=None, scope='fc_out')

    def train(self, support, query, num_train_steps):
        """Run ``num_train_steps`` meta-updates on a fixed support/query batch.

        The variables live only for this session, so the trained weights are
        discarded on return (add a ``tf.train.Saver`` to keep them).

        :param support: ``(images, labels)`` support arrays
        :param query: ``(images, labels)`` query arrays
        :param num_train_steps: number of optimizer steps
        :return: the meta-loss of the last step, or None if no step ran
        """
        image, label = support
        image_q, label_q = query
        feed_dict = {
            self._image: image,
            self._label: label,
            self._image_q: image_q,
            self._label_q: label_q
        }
        train_loss = None
        with tf.Session() as sess:
            # variables (including Adam slots) must be initialized before use
            sess.run(tf.global_variables_initializer())
            for _step in range(num_train_steps):
                train_loss, _ = sess.run([self.loss, self.train_op], feed_dict=feed_dict)
        return train_loss

    def _eval(self, pred, label):
        """Return the classification error rate of logits ``pred`` against ``label``.

        Assumes ``label`` rows are one-hot over ``num_classes`` — TODO confirm
        against the data pipeline.
        """
        correct = tf.equal(tf.argmax(pred, axis=1), tf.argmax(label, axis=1))
        return 1.0 - tf.reduce_mean(tf.cast(correct, tf.float32))

    def _cal_loss(self, pred, label):
        """Return the mean softmax cross-entropy between logits ``pred`` and ``label``.

        Assumes ``label`` rows are one-hot (or probability) vectors — TODO
        confirm. The loss is differentiated twice by the MAML meta-gradient.
        """
        return tf.reduce_mean(-tf.reduce_sum(label * tf.nn.log_softmax(pred), axis=1))
  • 完成本次实验后的一些经验和想法:
    • 一些参数设置:task_num=5, batch_sz=10000, n_way=5, k_shot=5, k_query=20/25, learning_rate=0.001
    • 适用于类别较多(本实验为 76 类)、每类样本数量不均衡或样本较少的场景,在这些情况下训练效果可能优于常规训练方式。
    • 时间成本较高,上线或落地有些困难。
    • 每次运行的结果不太稳定。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值