MAML-Pytorch代码学习分解

一、引言

在学习小样本学习之元学习实现方法中,遇见MAML(模型不可知元学习)算法,通过元学习入门必备:MAML(背景+论文解读+代码分析)_元学习算法代码-CSDN博客该博客的末尾转到https://zhuanlan.zhihu.com/p/343827171知乎可以看到代码的下载地址,本博客旨在分析学习该代码,记录MAML的实现过程。

其中数据集选用Omniglot数据集。

注意:请区分:

任务数task_num(batchsz)、样本类别数(1623 characters)、选取的样本类别数n_way、批次数(episodes)

二、MAML

Meta-Learning即“学习如何学习”,展示的是一种思想,例如某个学习率的大小可以不再是手动设置的超参数,而是通过模型学习得到。其中MAML指的是学习一组最优的初始化参数,该参数能够应对不同的新任务,在新任务上通过简单的几步训练即可获得很好的效果,因此MAML的核心就是通过某种训练方法,让模型的权重和偏置参数处于一种很容易在新任务上收敛的位置。

2.1 元学习数据集

MAML是元学习的一种,因此数据处理和元学习的数据处理是一致的。

元学习的数据集包括Training_Data训练资料和Testing_Data测试资料,其中Training_Data里面又包含Support_Set支持集和Query_Set查询集,Testing_Data也包含Support_Set支持集和Query_Set查询集。

其中Spt用来获取梯度信息,Qry用来计算损失。

Training_Data是为了让模型收敛在某个好的位置,但并不是在某个任务上(例如目标分类)准确率最高,而是在新的任务上(Testing_Data)能够很快的得到很好的效果。

浅绿色线和深绿色线表示两个task的Loss随着Model Parameter变化的曲线。

左图MAML力求通过Training_Data收敛到到phi所在的位置,使得其在Testing_Data中能够很快的通过spt收敛到theta1和theta2所在位置。其中L(\phi )=\sum_{n=1}^{N}l^{n}(\widehat{\theta }^{n}),n表示第n个任务,\widehat{\theta }^{n}表示第n个任务的参数,其中第n个任务的参数是由第n个任务的Spt得来,损失由Qry计算。区别于l^{n}({\phi})表示的是第n个任务在同一个参数phi。

与Pre-training的区别在于,其力求在每个task上表现综合最优,因此模型容易收敛在由于phi位置,使得l1+l2和最小。

2.2 数据集处理

Omniglot数据集包含了1623组手写characters,每组character包含20张105*105的0单通道图像,其中同一个character的20张图像是20个不同的人手写的同一个字。因为其有大量的类别且每个类别的图像数量较少,因此该数据集适合小样本模型训练。

针对小样本学习,数据集batch格式通常为(n_way, k_shot, k_query),分别表示该batch的类别数、支持集样本数、查询集样本数。

三、代码逐段分析

下载解压后如图所示的文件结构,如果没有数据集,可以后边运行代码自动下载。

3.1 omniglot_train.py(onniglot训练、定义参数的主函数)

代码入口:

if __name__ == '__main__':

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--epoch', type=int, help='epoch number', default=40000)
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)
    argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument('--imgsz', type=int, help='imgsz', default=28)
    argparser.add_argument('--imgc', type=int, help='imgc', default=1)
    argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=32)
    argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)
    argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.4)
    argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
    argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)

    args = argparser.parse_args()

运行该文件代码从此处开始执行。

task_num表示一个batch中包含task_num个task,每个task内包含n_way个类别,每个类别包含k_spt个支持即样本、k_qry个查询集样本。

meta_lr为模型参数最终(实际)更新的(outer)lr,update_lr为模型参数在每个任务训练时的更新(inner)lr。

update_step为模型通过Training_Data训练时,模型更新的次数,也就是inner loop中,每个task喂给模型喂几次。

update_step_test就是对应Testing_Data时相同任务喂该数据的次数。

进入main主函数

config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

config定义了一个网络结构,('name', [out_ch, in_ch, kernelsz, kernelsz, stride, padding])。

其中

maml = Meta(args, config).to(device)

使得模型进入meta.py

3.2 meta.py

maml是Meta()的实例化。

class Meta(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args, config):
        """

        :param args:
        """
        super(Meta, self).__init__()

        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.task_num = args.task_num
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test


        self.net = Learner(config, args.imgc, args.imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)

其中

self.net = Learner(config, args.imgc, args.imgsz)

将config再次转到Learner类中,进入Learner()。

3.3 Learner.py(构建网络结构并实现前向传播)

class Learner(nn.Module):
    """

    """

    def __init__(self, config, imgc, imgsz):
        """

        :param config: network config file, type:list of (string, list)
        :param imgc: 1 or 3
        :param imgsz:  28 or 84
        """
        super(Learner, self).__init__()


        self.config = config

        # this dict contains all tensors needed to be optimized
        self.vars = nn.ParameterList() # vars即为所有的可训练参数w
        # running_mean and running_var
        self.vars_bn = nn.ParameterList() # vars_bn即为所有的可训练参数bn

        for i, (name, param) in enumerate(self.config):
            if name is 'conv2d':
                # [ch_out, ch_in, kernelsz, kernelsz]   # [输出通道数, 输入通道数, 卷积核大小x, 卷积核大小y]
                w = nn.Parameter(torch.ones(*param[:4]))  # python的解包操作,将param的前四个参数作为一个整体传入
                # gain=1 according to cbfin's implementation
                torch.nn.init.kaiming_normal_(w) # 用一个kaiming_normal的初始化方式对w进行初始化
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name is 'convt2d':
                # [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfin's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_in, ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[1])))

            elif name is 'linear':
                # [ch_out, ch_in]
                w = nn.Parameter(torch.ones(*param))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name is 'bn':
                # [ch_out]
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

                # must set requires_grad=False
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])


            elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
                          'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
                continue
            else:
                raise NotImplementedError

 在__init__()中,第一次见这种先把模型的所有可训练参数放在一对列表中。好处是自己可以决定每个参数的初始化方法,本代码中使用的是torch.nn.init.kaiming_normal_(w)初始化方式。

self.vars = nn.ParameterList() # vars即为所有的可训练参数w
self.vars_bn = nn.ParameterList() # vars_bn即为所有的可训练参数bn

其中

w = nn.Parameter(torch.ones(*param[:4]))

表示定义了out_ch * in_ch * kernelsz * kernelsz个可训练的权重参数。

另外用.extend()函数可以将新列表解开追加在原列表后,区别于attend不解包作为一个整体追加。

self.vars_bn.extend([running_mean, running_var])

Learner类还重写了exter_repr()方法,便于打印信息,具体不讨论。

    def extra_repr(self):

重点是forword()方法:

    def forward(self, x, vars=None, bn_training=True):
        """
        This function can be called by finetunning, however, in finetunning, we dont wish to update
        running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights.
        Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
        but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
        :param x: [b, 1, 28, 28]
        :param vars:
        :param bn_training: set False to not update
        :return: x, loss, likelihood, kld
        """

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        for name, param in self.config:
            if name is 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchrozied of forward_encoder and forward_decoder!
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
                # print(name, param, '\tout:', x.shape)
            elif name is 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchrozied of forward_encoder and forward_decoder!
                x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
                # print(name, param, '\tout:', x.shape)
            elif name is 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
                # print('forward:', idx, x.norm().item())
            elif name is 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1]
                x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                idx += 2
                bn_idx += 2

            elif name is 'flatten':
                # print(x.shape)
                x = x.view(x.size(0), -1)
            elif name is 'reshape':
                # [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name is 'relu':
                x = F.relu(x, inplace=param[0])
            elif name is 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name is 'tanh':
                x = F.tanh(x)
            elif name is 'sigmoid':
                x = torch.sigmoid(x)
            elif name is 'upsample':
                x = F.upsample_nearest(x, scale_factor=param[0])
            elif name is 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name is 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])

            else:
                raise NotImplementedError

        # make sure variable is used properly
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)


        return x

在建立网络的同时,完成了输入x在初始化后的网络中前向传播的过程。

另外类还重写了梯度清零zero_grad()方法,暂时不管。

    def zero_grad(self, vars=None):
        ……

现在回到上一次Learner类实例化的位置meta.py中。

3.4 meta.py

定义了一个将梯度通过L2范数(欧几里得范数)裁剪,防止梯度爆炸的函数。

其中

L_{2}=||g||_{2}=\sqrt{\sum g_{i}^{2}}

    def clip_grad_by_norm_(self, grad, max_norm):

重点是前向传播函数,与后向传播训练更新参数息息相关,由于目前在Omniglot_tain.py中仅仅实例化maml = Meta(args, config).to(device),前向传播将与maml,目前还未执行到xxx=maml(……),所以继续返回到Omniglot_tain.py函数中,往下执行。

3.5 Omniglot_tain.py

# db_train即训练database  是专门处理数据集的OmniglotNShot类的实例化
    db_train = OmniglotNShot('omniglot',
                       batchsz=args.task_num,
                       n_way=args.n_way,
                       k_shot=args.k_spt,
                       k_query=args.k_qry,
                       imgsz=args.imgsz)

进入OmniglotNshot.py函数。

3.6 OmniglotNshot.py(从本地数据集中构建加载数据集,好难懂!)

class OmniglotNShot:

    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
        """
        Different from mnistNShot, the
        :param root:
        :param batchsz: task num
        :param n_way:
        :param k_shot:
        :param k_qry:
        :param imgsz:
        """

        self.resize = imgsz
        if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
            # if root/data.npy does not exist, just download it
            self.x = Omniglot(root, download=True,
                              transform=transforms.Compose([lambda x: Image.open(x).convert('L'),
                                                            lambda x: x.resize((imgsz, imgsz)),
                                                            lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                                                            lambda x: np.transpose(x, [2, 0, 1]),
                                                            lambda x: x/255.])
                              )

            temp = dict()  # {label1:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
            # self.x实际上是一个实例,当出现for循环迭代对象时,如果对象没有实现 __iter__ __next__ 迭代器协议,
            # Python的解释器就会去寻找__getitem__ 来迭代对象,如果连__getitem__ 都没有定义,
            # 这解释器就会报对象不是迭代器的错误。
            # 因此此处self.x通过__getitem__()不断迭代返回(img, target)  index从0到最大,超过后自动跳出循环
            for (img, label) in self.x:
                if label in temp.keys():
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            for label, imgs in temp.items():  # labels info deserted , each label contains 20imgs
                self.x.append(np.array(imgs))

            # as different class may have different number of imgs
            self.x = np.array(self.x).astype(np.float64)  # [[20 imgs],..., 1623 classes in total]
            # each character contains 20 imgs
            print('data shape:', self.x.shape)  # [1623, 20, 84, 84, 1] # 1623个类别,每个类别20张图
            temp = []  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(root, 'omniglot.npy'), self.x)
            print('write into omniglot.npy.')
        else:
            # if data.npy exists, just load it.
            self.x = np.load(os.path.join(root, 'omniglot.npy'))
            print('load from omniglot.npy.')

        # [1623, 20, 84, 84, 1]
        # TODO: can not shuffle here, we must keep training and test set distinct!
        self.x_train, self.x_test = self.x[:1200], self.x[1200:] # 0~1199个类别作为训练集,1200~1622个类别为测试集, 训练集再分为支持集和查询集

        # self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        assert (k_shot + k_query) <=20

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
        print("DB: train", self.x_train.shape, "test", self.x_test.shape)

        # self.datasets["train"]就是self.x_train就是self.x的前1200个类别  # [1200, 20, 84, 84, 1]
        # self.datasets["test"]就是self.x_test就是self.x的后423的类别     # [423, 20, 84, 84, 1]
        # datasets_cache是一个包含着train和test两个名称的字典,分别作为当前epoch的数据集的key键值
        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "test": self.load_data_cache(self.datasets["test"])}

前半部分中,通过if判断本地是否存在omniglot.npy(是numpy格式的数据)文件来决定是否进行下载。第一次仅下载代码后,本地不存在该文件,接着由于

self.x = Omniglot(root, download=True,
                  transform=transforms.Compose([lambda x: Image.open(x).convert('L'),
                                                lambda x: x.resize((imgsz, imgsz)),
                                                lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                                                lambda x: np.transpose(x, [2, 0, 1]),
                                                lambda x: x/255.])
                  )

进入Omniglot类。

3.7 Omniglot.py(继承data.Dataset,可以通过__getitem__()魔法函数迭代返回一张张数据,主要用于下载数据集和预处理数据)

class Omniglot(data.Dataset):
    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    '''
    The items are (filename,category). The index of all the categories can be found in self.idx_classes
    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input
    - target_transform: how to transform the target
    - download: need to download the dataset
    '''

    def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if not self._check_exists():
            if download:
                self.download()
            else:
                raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

 检查路径是否存在,不存在则执行下面下载代码。暂不展示。

第一个重点是下面的两个方法。

self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
self.idx_classes = index_classes(self.all_items)
def find_classes(root_dir):
    retour = []
    for (root, dirs, files) in os.walk(root_dir):
        for f in files:
            if (f.endswith("png")):
                r = root.split('/')
                lr = len(r)
                retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
    print("== Found %d items " % len(retour))
    return retour

# 为每个类别创建一个索引值
# {'classA': 0, 'classB': 1, 'classC': 2}
def index_classes(items):
    idx = {}
    for i in items:
        if i[1] not in idx:
            idx[i[1]] = len(idx)
    print("== Found %d classes" % len(idx))
    return idx

find_classes()中,返回一个列表,列表中存放着每一张图片的(图片文件名, 图片类别, 图片在文件系统所在的实际目录路径)

os.walk() 是 Python 中的一个非常方便的函数,用于遍历目录和其子目录,并返回一个三元组(root, dirs, files)的生成器,按照目录树一层一层递进迭代,其中root为当前访问的路径(字符串)、dirs是一个列表,表示当前路径下的子文件夹的名字(不包含路径仅包含名字)、files是一个列表,表示当前路径下的文件名字(不包含路径仅包含名字)

index_classes()中,i[1]是图片所属的类别。通过此函数,生成一个字典,键key为i[1]即类别名,值value为i[1]首次出现的位次。绕过来绕过去就是为每个类别赋值一个数字来代表。

另外因为继承了data.Dataset类,重写了两个方法:

# 每次仅仅返回一张图片和标签
    def __getitem__(self, index):
        filename = self.all_items[index][0]
        img = str.join('/', [self.all_items[index][2], filename])

        target = self.idx_classes[self.all_items[index][1]] # 此时target即为数字索引了
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


    def __len__(self):
        return len(self.all_items)

这样用Omniglot类加载数据的时候,可以自动调用__getitem__()迭代返回(img, target)。

现在介绍完Omniglot.py回到OmniglotNShot.py。

3.8 OmniglotNShot.py

通过3.7的__getitem__()就可以解释3.6中的这段代码:(认真学习继承Dataset自动迭代的方法)

# self.x实际上是一个实例,当出现for循环迭代对象时,如果对象没有实现 __iter__ __next__ 迭代器协议,
# Python的解释器就会去寻找__getitem__ 来迭代对象,如果连__getitem__ 都没有定义,
# 这解释器就会报对象不是迭代器的错误。
# 因此此处self.x通过__getitem__()不断迭代返回(img, target)  index从0到最大,超过后自动跳出循环
for (img, label) in self.x:
    if label in temp.keys():
        temp[label].append(img)
    else:
        temp[label] = [img]

3.6中后续代码修改了self.x的数据格式(float64的np数组),最终得到一个超大列表,列表的shape为[1623, 20, 84, 84, 1],表示1623个类别,每个类别中20个样本,每个样本1通道,像素84*84。保存为omniglot.npy。

随后定义了

self.x_train, self.x_test用来划分数据集为Training_Data(前1200个类别)和Testing_Data(后423个类别)
self.indexes = {"train": 0, "test": 0} # 表示第几个batch,后边.next()函数中逐个batch返回数据用得到
self.datasets = {"train": self.x_train, "test": self.x_test} 

又定义了self.datasets_cache

# self.datasets["train"]就是self.x_train就是self.x的前1200个类别  # [1200, 20, 84, 84, 1]
# self.datasets["test"]就是self.x_test就是self.x的后423的类别     # [423, 20, 84, 84, 1]
# datasets_cache是一个包含着train和test两个名称的字典,分别作为当前epoch的数据集的key键值
self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),
                       "test": self.load_data_cache(self.datasets["test"])}

其中的self.load_data_cache()方法:

    def load_data_cache(self, data_pack):   # data_pack == [[20imgs] [20imgs] [……(1200or423个子列表)]]
                                            # data_pack.shape == [1200 or 423, 20, 84, 84, 1]
        # print(data_pack.shape)            # (1200, 20, 1, 28, 28) or (423, 20, 1, 28, 28)
        """
        Collects several batches data for N-shot learning
        :param data_pack: [cls_num, 20, 84, 84, 1]
        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
        """
        #  take 5 way 1 shot as example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []

        # print('preload next 50 caches of batchsz of batch.')
        for sample in range(10):  # num of episodes

            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for i in range(self.batchsz):  # one batch means one set
                # 若 batchsz == 32(n), 则取32(n)个任务,每个任务有n_way个类别,每个类别k张spt和q张qry图
                x_spt, y_spt, x_qry, y_qry = [], [], [], []
                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False) # False表示不放回
                #selected_cls = [12, 45, 689, 123, 456, ...(n个索引值)]
                for j, cur_class in enumerate(selected_cls):

                    selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
                    # 从20个数字中挑出(支持集+查询集)张数的 数字索引 个数, 前k张作为sqt, 剩下的作为qrt
                    # meta-training(spt) and meta-test(qry)
                    x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]]) # 通过键值取字典
                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
                    y_spt.append([j for _ in range(self.k_shot)])  # 注意: j 实际上是selected_img列表的索引值
                    y_qry.append([j for _ in range(self.k_query)])
                    '''
                    在这个上下文中,cur_class 是数据集中类别的实际索引,可能对于外部的标签系统来说没有实际意义。
                    j 是当前类别在 selected_cls 中的索引,它为模型提供了一个更简单的标签系统,
                    其中每个任务中的每个类别都标记为 0 到 n_way - 1 之间的整数。这简化了标签系统,
                    并允许模型在每个任务中更容易地识别和区分类别。
                    '''
                # 至此 x_spt, y_spt, x_qry, y_qry 四个列表中包含了一个batch的数据,总共有batchsz个这样的数据构成一个episode,共有10个episode
                # 下面将抽取的顺序打乱
                # shuffle inside a batch
                perm = np.random.permutation(self.n_way * self.k_shot)
                x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
                y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
                perm = np.random.permutation(self.n_way * self.k_query)
                x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
                y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]

                # append 使得x_spt==[sptsz, 1, 84, 84] => x_spts==[b, setsz, 1, 84, 84]
                x_spts.append(x_spt)
                y_spts.append(y_spt)
                x_qrys.append(x_qry)
                y_qrys.append(y_qry)


            # [b, setsz, 1, 84, 84]
            x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
            y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz)
            # [b, qrysz, 1, 84, 84]
            x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
            y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)

            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

        return data_cache  # data_cache里装着10个列表,每个列表包含x_spts, y_spts, x_qrys, y_qrys四个子列表,
                                            # 每个子列表包含对应子 子列表,例如x_spts内包含task_num(batchsz)个x_spt
                                            # x_spts==[b, setsz, 1, 84, 84], setsz=n_way * k_shot

第一个for循环了10次表示10个eposides,也相当于10个batch批次。(普通深度学习中,假如有1000条数据,设置的batch_size=10则每个batch内含有100条数据,此元学习中,每个batch(task)内有n_way=5个类别,每个类别有k_shot+k_query=15张图片,所以eposide与batch个数有所区别,可以理解为抽取eposides次,不一定把所有数据抽完,会有重复且有未抽到的)。

第二个for循环了task_num次,表示取task_num个任务,每个任务有n_way个类别。

第三个for循环遍历抽出来的类别,对每个类别再抽k张spt和q张qry图。其中x_spt和x_qry列表分别存放抽到的图片的np数据,y_spt和y_qry列表分别存放抽到的图片的所属类别,值得注意的是,此处的类别并非实际1623个类别之一,而是n_way个类别之一(0 1 2 3 4),因为输出logits就是n_way维的

 如此,返回了shape为[b=10, spt_size, 1, 84, 84]的data_cache

另外,实际训练时会通过.next()方法返回单批次数据,在这里提前展示一下。

    def next(self, mode='train'): # 默认返回训练模式下sqt和qrt
        """
        Gets next batch from the dataset with name.
        :param mode: The name of the splitting (one of "train", "val", "test")
        :return:
        """
        # update cache if indexes is larger cached num
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]] # self.datasets_cache[mode].shape == [b, setsz, 1, 84, 84]
        self.indexes[mode] += 1 # indexes表示第几个batch,其中1batch数据即一个episode个手机的数据,包括task_num个任务,每个任务n_way个类别

        return next_batch # 返回一个batch的数据  [x_spts, y_spts, x_qrys, y_qrys]

self.index[mode]表示该mode下第几个batch。[x_spts, y_spts, x_qrys, y_qrys],其中x_spts内包含task_num个任务,每个任务n_way个类别,每个类别setsz个图像。

数据的问题解决了,现在返回到Omniglot_train.py。

3.9 Omniglot_train.py

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

epoch循环。

每次调用next()返回一个batch的数据,注意并不是将整个data_cache列表均返回。一个data_cache内包含episodes个batch。

通过

accs = maml(x_spt, y_spt, x_qry, y_qry)

进入Meta类的前向传播计算。

3.10 meta.py 前向传播

注意:源代码中添加

y_spt = y_spt.long()
y_qry = y_qry.long()

否则计算交叉熵数据类型报错。

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        """

        :param x_spt:   [b, setsz, c_, h, w]
        :param y_spt:   [b, setsz]
        :param x_qry:   [b, querysz, c_, h, w]
        :param y_qry:   [b, querysz]
        :return:
        """
        y_spt = y_spt.long()
        y_qry = y_qry.long()
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step + 1)]


        for i in range(task_num):

            # 1. run the i-th task and compute loss for k=0
            logits = self.net(x_spt[i], vars=None, bn_training=True) # vars用于初始化网络参数, None表示用创建网络时的参数
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
            # fast_weights=parameters−learning_rate×gradients
            '''
            grad = [Δw1, Δw2, ..., Δwn] (模型参数的梯度列表)
            self.net.parameters() = [w1, w2, ..., wn] (模型参数列表)
            zip(grad, self.net.parameters())==[(Δw1, w1), (Δw2, w2), ..., (Δwn, wn)]
            map函数将一个函数(lambda)应用于一个序列p的每个元素。这里,它将一个 lambda 函数应用于 zip 函数生成的元组列表。
            转变成list后fast_weights包含了更新后的所有参数列表
            fast_weights = [
                (w1 - self.update_lr * Δw1),
                (w2 - self.update_lr * Δw2),
                ...
                (wn - self.update_lr * Δwn)
            ]
            '''
            # this is the loss and accuracy before first update
            with torch.no_grad():
                # [setsz, nway]
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q

                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            # this is the loss and accuracy after the first update
            with torch.no_grad():
                # [setsz, nway]
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                # [setsz]
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                # 2. compute grad on theta_pi
                grad = torch.autograd.grad(loss, fast_weights)
                # 3. theta_pi = theta_pi - train_lr * grad
                fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                # loss_q will be overwritten and just keep the loss_q on last update step.
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()  # convert to numpy
                    corrects[k + 1] = corrects[k + 1] + correct



        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = losses_q[-1] / task_num

        # optimize theta parameters
        self.meta_optim.zero_grad() # 梯度清零
        loss_q.backward()
        # print('meta update')
        # for p in self.net.parameters()[:5]:
        # 	print(torch.norm(p).item())
        self.meta_optim.step()


        accs = np.array(corrects) / (querysz * task_num)

        return accs

主要看for循环中,首先将x_spts的第i个任务的数据取出来,喂给net,计算了原始var下的logits,与y_spts[i]计算cross_entropy,计算grad,计算fast_weight = theta - update_lr * grad,这里用到了map和lambda函数,其中p=zip(grad, self.net.parameters())。

然后with no_grad():

losses_q[0]+=将x_qry喂入原始参数时,输出结果的loss

corrects[0]+=将x_qry喂入原始参数时,输出结果的正确个数

然后with no_grad():

losses_q[1]+=将x_qry喂入通过x_spt数据更新的第一次参数时,输出结果的loss

corrects[1]+=将x_qry喂入通过x_spt数据更新的第一次参数时,输出结果的正确个数

然后for k in range(1, self.update_step):

losses_q[1+1]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的loss

corrects[1+1]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的正确个数

……

k==self.update_step-1时,

losses_q[update_step]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的loss

corrects[update_step]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的正确个数

故该两个列表都是update_step+1维,依次表示原参数loss/correct_num、第一次支持集更新后查询集上的loss/correct_num、第二次支持集更新后查询集上的loss/correct_num、……、第update_step次支持集更新后查询集上的loss/correct_num。

遍历完所有任务后,由于记录数据时都是+=,所以最后统一 / task_num,得到平均的损失和正确率。

self.meta_optim.zero_grad() # 梯度清零
loss_q.backward()
self.meta_optim.step()

通过这个完成一次loss_q反向传播,并更新参数。

完成task_num次循环后,回到Omniglot_train.py。

3.11 Omniglot_train.py

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000//args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append( test_acc )

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)

每50个epoch打印一下acc,主要看第一个和最后一个值,分别代表初始参数下正确率和当前更新后参数的查询集准确率。

每500个epoch,进行一次Testing_Data测试。其中

# split to single task each time
for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
    accs.append( test_acc )

将数据分离,因此测试任务每次应该只有一个任务,n_way个类别。

回到meta.py的finetunning()方法。

3.12 meta.py(Testing_Data微调)

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """

        :param x_spt:   [setsz, c_, h, w]
        :param y_spt:   [setsz]
        :param x_qry:   [querysz, c_, h, w]
        :param y_qry:   [querysz]
        :return:
        """
        y_spt = y_spt.long()
        y_qry = y_qry.long()
        assert len(x_spt.shape) == 4

        querysz = x_qry.size(0)

        corrects = [0 for _ in range(self.update_step_test + 1)]

        # in order to not ruin the state of running_mean/variance and bn_weight/bias
        # we finetunning on the copied model instead of self.net
        net = deepcopy(self.net)

        # 1. run the i-th task and compute loss for k=0
        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))

        # this is the loss and accuracy before first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        # this is the loss and accuracy after the first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, fast_weights, bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            # 1. run the i-th task and compute loss for k=1~K-1
            logits = net(x_spt, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt)
            # 2. compute grad on theta_pi
            grad = torch.autograd.grad(loss, fast_weights)
            # 3. theta_pi = theta_pi - train_lr * grad
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

            logits_q = net(x_qry, fast_weights, bn_training=True)
            # loss_q will be overwritten and just keep the loss_q on last update step.
            loss_q = F.cross_entropy(logits_q, y_qry)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry).sum().item()  # convert to numpy
                corrects[k + 1] = corrects[k + 1] + correct


        del net

        accs = np.array(corrects) / querysz

        return accs

微调和训练过程十分类似,区别在于微调仅输入单任务,并且初始参数是训练后所认为的最优参数。

由于微调是在原训练模型的基础上,并且要保证原最优参数不变,故需要通过

net = deepcopy(self.net)

复制一份模型。

另外还有就是update_step_test不同,一般大于update_step。因此一般新任务多训练几次效果会更好,只不过在找最优初始参数的训练时为了防止数据量小过拟合、提高训练速度而减少更新次数。

返回值accs列表中最后一个值表示模型参数通过update_step_test次该任务支持集的更新后,在该任务上的查询集上的准确率。

四、总结

经过5个py文件的12次相互调用后,该代码运行并解析完毕!

配置模型(独特的参数初始化方法)——下载数据——加载数据(加载数据较为复杂难懂,关键区分epoch、eposides、task_num=batchsz、n_way几个参数的实际意义)——训练(支持集更新、查询集计算loss和correct)——测试(单任务支持集更新、查询集计算loss和correct)。

以下是使用PyTorch实现的MAML学习的示例代码: ```python import torch import torch.nn as nn import torch.optim as optim class MAML(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(MAML, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.output_size = output_size self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, output_size) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x def clone(self, device=None): clone = MAML(self.input_size, self.hidden_size, self.output_size) if device is not None: clone.to(device) clone.load_state_dict(self.state_dict()) return clone class MetaLearner(nn.Module): def __init__(self, model, lr): super(MetaLearner, self).__init__() self.model = model self.optimizer = optim.Adam(self.model.parameters(), lr=lr) def forward(self, x): return self.model(x) def meta_update(self, task_gradients): for param, gradient in zip(self.model.parameters(), task_gradients): param.grad = gradient self.optimizer.step() self.optimizer.zero_grad() def train_task(model, data_loader, lr_inner, num_updates_inner): model.train() task_loss = 0.0 for i, (input, target) in enumerate(data_loader): input = input.to(device) target = target.to(device) clone = model.clone(device) meta_optimizer = MetaLearner(clone, lr_inner) for j in range(num_updates_inner): output = clone(input) loss = nn.functional.mse_loss(output, target) grad = torch.autograd.grad(loss, clone.parameters(), create_graph=True) fast_weights = [param - lr_inner * g for param, g in zip(clone.parameters(), grad)] clone.load_state_dict({name: param for name, param in zip(clone.state_dict(), fast_weights)}) output = clone(input) loss = nn.functional.mse_loss(output, target) task_loss += loss.item() grad = torch.autograd.grad(loss, model.parameters()) task_gradients = [-lr_inner * g for g in grad] meta_optimizer.meta_update(task_gradients) return task_loss / len(data_loader) # Example usage device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') input_size = 1 hidden_size = 20 output_size = 1 model = MAML(input_size, hidden_size, output_size) model.to(device) data_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.randn(100, input_size), torch.randn(100, output_size)), batch_size=10, shuffle=True) meta_optimizer = MetaLearner(model, lr=0.001) for i in range(100): task_loss = train_task(model, data_loader, lr_inner=0.01, num_updates_inner=5) print('Task loss:', task_loss) meta_optimizer.zero_grad() task_gradients = torch.autograd.grad(task_loss, model.parameters()) meta_optimizer.meta_update(task_gradients) ``` 在这个示例中,我们定义了两个类,MAML和MetaLearner。MAML是一个普通的神经网络,而MetaLearner包含了用于更新MAML的元优化器。在每个任务上,我们使用MAML的副本进行内部更新,然后使用元优化器来更新MAML的权重。在元学习的过程中,我们首先通过调用train_task函数来训练一个任务,然后通过调用meta_update函数来更新MAML的权重。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值