Hung-Yi Lee Homework [14]: Life Long Learning


1. Principles of Life Long Learning

  Life Long Learning means that a machine first learns task one, then learns task two, after which it has mastered both. If it keeps learning further tasks over time, it accumulates more and more skills; in the ideal case, the machine could eventually do everything.
  To make Life Long Learning work, several problems have to be solved: (1) how to retain old knowledge while learning new knowledge; (2) how to transfer knowledge when training on different tasks; (3) how to expand the model effectively, so that it fits the current situation without wasting compute.
  In practice, Life Long Learning easily runs into catastrophic forgetting. The common ways to mitigate it are the following:

  • Dynamic Expansion: simply allocate a batch of new parameters to learn the new task. But this makes the model grow without bound, so it usually has to be paired with some model-compression step.
  • Rehearsal: if the gradients on the new task can be kept as close as possible to the gradients on the old tasks, a large part of the old knowledge is preserved.
  • Regularization: add regularization terms so that parameters strongly tied to the old tasks are not updated too aggressively. This works because in most large neural networks only a subset of the parameters is decisive for the model, which leaves regularization-based methods room to operate.
      The EWC and MAS methods used in this assignment are both Regularization-based. Outside the Life Long Learning setting, a model trained on task A and then fine-tuned directly on task B can no longer solve task A (catastrophic forgetting has set in). If we instead add an L2 regularization term that keeps the parameters trained on task B from drifting too far from the task-A solution, we get the basic idea of Regularization.
      A plain L2 term, however, treats every parameter alike and ignores how important each one is to the task, which can bottleneck learning on task B. Regularization-based methods therefore estimate, for each parameter $\theta_i$, its importance $\Omega_i$ to task A, and the regularized loss becomes $L(\theta)=L_B(\theta)+\frac{\lambda}{2}\sum_i\Omega_i(\theta_i-\theta^*_{A,i})^2$, as sketched in the code below.
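
      As a concrete illustration, here is a minimal PyTorch sketch of this importance-weighted L2 penalty (not part of the homework code; `omega` and `star_params` are assumed to be dictionaries keyed by parameter name, holding the importances Ω_i and the task-A solution θ*_A,i):

import torch
import torch.nn as nn

def weighted_l2_penalty(model: nn.Module, omega: dict, star_params: dict, lam: float):
    # returns (λ/2) · Σ_i Ω_i (θ_i − θ*_{A,i})², the second term of L(θ)
    penalty = 0.0
    for name, p in model.named_parameters():
        penalty = penalty + (omega[name] * (p - star_params[name]) ** 2).sum()
    return lam / 2 * penalty

# usage: total_loss = task_B_loss + weighted_l2_penalty(model, omega, star_params, lam=0.5)

      EWC and MAS differ only in how they estimate $\Omega_i$: EWC uses squared gradients of the log-likelihood (an empirical Fisher estimate), while MAS uses the absolute gradients of the squared L2 norm of the model outputs.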

2. Assignment Description

  This assignment walks through EWC and MAS, two methods for the Life Long Learning problem, and applies them back to back on three datasets: first train on SVHN, then on MNIST, and finally on USPS, while producing accuracy plots as close as possible to the reference figure used for the final evaluation.

3. Implementation

  Since this assignment is about the lifelong-learning training procedure rather than about stacking models, the implementation uses one and the same model throughout.

hw14.py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms

import numpy as np
import os
import random
from copy import deepcopy
import json

from core_define import *
from preprocess import *

# Workaround for the Python error "SSL: CERTIFICATE_VERIFY_FAILED" raised while downloading the datasets
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# For plotting
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Save the model and optimizer to {path}.ckpt and {path}.opt
def save_model(model, optimizer, store_model_path):
    torch.save(model.state_dict(), f'{store_model_path}.ckpt')
    torch.save(optimizer.state_dict(), f'{store_model_path}.opt')
    return

# Load the model and optimizer
def load_model(model, optimizer, load_model_path):
    print(f'Load model from {load_model_path}')
    model.load_state_dict(torch.load(f'{load_model_path}.ckpt'))
    optimizer.load_state_dict(torch.load(f'{load_model_path}.opt'))
    return model, optimizer

# Build the model and the task dataloaders. The three datasets are downloaded on first use;
# the download can fail with "SSL: CERTIFICATE_VERIFY_FAILED", which the workaround above handles.
def build_model(data_path, batch_size, learning_rate):
    model = Model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    data = Data(data_path)
    datasets = data.get_datasets()
    tasks = []
    for dataset in datasets:
        tasks.append(Dataloader(dataset, batch_size))

    return model, optimizer, tasks

# Plain training loop (for the 6-layer fully-connected + ReLU model); each "epoch" here is
# one minibatch step drawn from the infinite training iterator
def normal_train(model, optimizer, task, total_epochs, summary_epochs):
    model.train()
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()
    losses = []
    loss = 0.0
    for epoch in range(summary_epochs):
        imgs, labels = next(task.train_iter)
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        ce_loss = criterion(outputs, labels)

        optimizer.zero_grad()
        ce_loss.backward()
        optimizer.step()

        loss += ce_loss.item()
        if (epoch + 1) % 50 == 0:
            loss = loss / 50
            print("\r", "train task {} [{}] loss: {:.3f}      ".format(task.name, (total_epochs + epoch + 1), loss),
                  end=" ")
            losses.append(loss)
            loss = 0.0

    return model, optimizer, losses

# Training with the EWC penalty
def ewc_train(model, optimizer, task, total_epochs, summary_epochs, ewc, lambda_ewc):
    model.train()
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()
    losses = []
    loss = 0.0
    for epoch in range(summary_epochs):
        imgs, labels = next(task.train_iter)
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        ce_loss = criterion(outputs, labels)
        total_loss = ce_loss
        ewc_loss = ewc.penalty(model)
        total_loss += lambda_ewc * ewc_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss += total_loss.item()
        if (epoch + 1) % 50 == 0:
            loss = loss / 50
            print("\r", "train task {} [{}] loss: {:.3f}      ".format(task.name, (total_epochs + epoch + 1), loss),
                  end=" ")
            losses.append(loss)
            loss = 0.0

    return model, optimizer, losses

# Training with the MAS penalty; alpha splits the penalty weight between the two most recent tasks
def mas_train(model, optimizer, task, total_epochs, summary_epochs, mas_tasks, lambda_mas, alpha=0.8):
    model.train()
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()
    losses = []
    loss = 0.0
    # Weight the penalties of the two most recent tasks: alpha for the latest,
    # (1 - alpha) for the one before it. Work on a reversed copy, once and
    # outside the training loop, so the ordering stays fixed across steps and
    # the caller's list is not mutated between calls.
    mas_tasks = list(reversed(mas_tasks))
    for epoch in range(summary_epochs):
        imgs, labels = next(task.train_iter)
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        ce_loss = criterion(outputs, labels)
        total_loss = ce_loss
        if len(mas_tasks) > 1:
            preprevious = 1 - alpha
            scalars = [alpha, preprevious]
            for mas, scalar in zip(mas_tasks[:2], scalars):
                mas_loss = mas.penalty(model)
                total_loss += lambda_mas * mas_loss * scalar
        elif len(mas_tasks) == 1:
            mas_loss = mas_tasks[0].penalty(model)
            total_loss += lambda_mas * mas_loss
        else:
            pass

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss += total_loss.item()
        if (epoch + 1) % 50 == 0:
            loss = loss / 50
            print("\r", "train task {} [{}] loss: {:.3f}      ".format(task.name, (total_epochs + epoch + 1), loss),
                  end=" ")
            losses.append(loss)
            loss = 0.0

    return model, optimizer, losses

# Validation: accuracy on the task's validation split
def val(model, task):
    model.eval()
    correct_cnt = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for imgs, labels in task.val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, pred_label = torch.max(outputs.data, 1)

            correct_cnt += (pred_label == labels.data).sum().item()

    return correct_cnt / task.val_dataset_size

# Main training routine: defines how normal_train, ewc_train and mas_train are combined
def train_process(model, optimizer, tasks, config):
    task_loss, acc = {}, {}
    for task_id, task in enumerate(tasks):
        print('\n')
        total_epochs = 0
        task_loss[task.name] = []
        acc[task.name] = []
        if config.mode == 'basic' or task_id == 0:
            while (total_epochs < config.num_epochs):
                model, optimizer, losses = normal_train(model, optimizer, task, total_epochs, config.summary_epochs)
                task_loss[task.name] += losses

                for subtask in range(task_id + 1):
                    acc[tasks[subtask].name].append(val(model, tasks[subtask]))

                total_epochs += config.summary_epochs
                if total_epochs % config.store_epochs == 0 or total_epochs >= config.num_epochs:
                    save_model(model, optimizer, config.store_model_path)

        if config.mode == 'ewc' and task_id > 0:
            old_dataloaders = []
            for old_task in range(task_id):
                old_dataloaders += [tasks[old_task].val_loader]
            ewc = EWC(model, old_dataloaders, device)
            while (total_epochs < config.num_epochs):
                model, optimizer, losses = ewc_train(model, optimizer, task, total_epochs, config.summary_epochs, ewc,
                                                     config.lifelong_coeff)
                task_loss[task.name] += losses

                for subtask in range(task_id + 1):
                    acc[tasks[subtask].name].append(val(model, tasks[subtask]))

                total_epochs += config.summary_epochs
                if total_epochs % config.store_epochs == 0 or total_epochs >= config.num_epochs:
                    save_model(model, optimizer, config.store_model_path)

        if config.mode == 'mas' and task_id > 0:
            old_dataloaders = []
            mas_tasks = []
            for old_task in range(task_id):
                old_dataloaders += [tasks[old_task].val_loader]
                mas = MAS(model, old_dataloaders, device)
                mas_tasks += [mas]
            while (total_epochs < config.num_epochs):
                model, optimizer, losses = mas_train(model, optimizer, task, total_epochs, config.summary_epochs,
                                                     mas_tasks, config.lifelong_coeff)
                task_loss[task.name] += losses

                for subtask in range(task_id + 1):
                    acc[tasks[subtask].name].append(val(model, tasks[subtask]))

                total_epochs += config.summary_epochs
                if total_epochs % config.store_epochs == 0 or total_epochs >= config.num_epochs:
                    save_model(model, optimizer, config.store_model_path)

        # pass is a no-op; this branch exists purely for readability and could be removed
        if config.mode == 'scp' and task_id > 0:
            pass
    return task_loss, acc


def plot_result(mode_list, task1, task2, task3):
    # Draw one accuracy curve per mode for each task. acc[task1] spans all three
    # training phases, acc[task2] the last two, acc[task3] only the last; with
    # equal epochs per task, task2's curve starts at x = len(acc[task3]) and
    # task3's at x = len(acc[task2]).
    count = 0
    for reg_name in mode_list:
        label = reg_name
        with open(f'./{reg_name}_acc.txt', 'r') as f:
            acc = json.load(f)
        if count == 0:
            color = 'red'
        elif count == 1:
            color = 'blue'
        else:
            color = 'purple'
        ax1 = plt.subplot(3, 1, 1)
        plt.plot(range(len(acc[task1])), acc[task1], color, label=label)
        ax1.set_ylabel(task1)
        ax2 = plt.subplot(3, 1, 2, sharex=ax1, sharey=ax1)
        plt.plot(range(len(acc[task3]), len(acc[task1])), acc[task2], color, label=label)
        ax2.set_ylabel(task2)
        ax3 = plt.subplot(3, 1, 3, sharex=ax1, sharey=ax1)
        ax3.set_ylabel(task3)
        plt.plot(range(len(acc[task2]), len(acc[task1])), acc[task3], color, label=label)
        count += 1
    plt.ylim((0.02, 1.02))
    plt.legend()
    plt.show()
    return

# Hyperparameters
class configurations(object):
    def __init__(self):
        self.batch_size = 256
        self.num_epochs = 10000
        self.store_epochs = 250
        self.summary_epochs = 250
        self.learning_rate = 0.0005
        self.load_model = False
        self.store_model_path = "./model"
        self.load_model_path = "./model"
        self.data_path = "./data"
        self.mode = None
        self.lifelong_coeff = 0.5

if __name__ == '__main__':
    mode_list = ['mas', 'ewc', 'basic']
    # Penalty coefficient (lambda) for each mode. Note that a coefficient of 0
    # disables the EWC/MAS penalty entirely and degenerates both to 'basic';
    # set positive values for the regularization to take effect.
    coeff_list = [0, 0, 0]

    config = configurations()
    count = 0
    for mode in mode_list:
        config.mode = mode
        config.lifelong_coeff = coeff_list[count]
        print("{} training".format(config.mode))
        model, optimizer, tasks = build_model(config.data_path, config.batch_size, config.learning_rate)
        print("Finish build model")
        if config.load_model:
            model, optimizer = load_model(model, optimizer, config.load_model_path)
        task_loss, acc = train_process(model, optimizer, tasks, config)
        with open(f'./{config.mode}_acc.txt', 'w') as f:
            json.dump(acc, f)
        count += 1

    # Reload the stored accuracy curves and plot them (the order fixes the colors)
    mode_list = ['ewc', 'mas', 'basic']
    plot_result(mode_list, 'SVHN', 'MNIST', 'USPS')

core_define.py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms

import numpy as np
import os
import random
from copy import deepcopy
import json

# Core model: six fully-connected layers with ReLU activations
class Model(nn.Module):

  def __init__(self):
    super(Model, self).__init__()
    self.fc1 = nn.Linear(3*32*32, 1024)
    self.fc2 = nn.Linear(1024, 512)
    self.fc3 = nn.Linear(512, 256)
    self.fc4 = nn.Linear(256, 128)
    self.fc5 = nn.Linear(128, 128)
    self.fc6 = nn.Linear(128, 10)
    self.relu = nn.ReLU()

  def forward(self, x):
    x = x.view(-1, 3*32*32)
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    x = self.relu(x)
    x = self.fc3(x)
    x = self.relu(x)
    x = self.fc4(x)
    x = self.relu(x)
    x = self.fc5(x)
    x = self.relu(x)
    x = self.fc6(x)
    return x

# EWC
class EWC(object):
  """
    @article{kirkpatrick2017overcoming,
        title={Overcoming catastrophic forgetting in neural networks},
        author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others},
        journal={Proceedings of the national academy of sciences},
        year={2017},
        url={https://arxiv.org/abs/1612.00796}
    }
  """

  def __init__(self, model: nn.Module, dataloaders: list, device):

    self.model = model
    self.dataloaders = dataloaders
    self.device = device

    self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}  # grab all trainable parameters
    self._means = {}  # snapshot of the parameters after the previous tasks (θ*)
    self._precision_matrices = self._calculate_importance()  # build the EWC Fisher (F) matrix

    for n, p in self.params.items():
      self._means[n] = p.clone().detach()  # record each parameter's value from the previous tasks

  def _calculate_importance(self):
    precision_matrices = {}
    for n, p in self.params.items():  # initialize the Fisher (F) matrix to zeros
      precision_matrices[n] = p.clone().detach().fill_(0)

    self.model.eval()
    number_data = sum([len(loader) for loader in self.dataloaders])  # total number of batches over the old tasks
    for dataloader in self.dataloaders:
      for data in dataloader:
        self.model.zero_grad()
        input = data[0].to(self.device)
        output = self.model(input)
        # use the model's own predictions as labels (empirical Fisher at the predicted class)
        label = output.max(1)[1].view(-1)

        # accumulate the EWC Fisher (F) matrix from squared gradients of the log-likelihood
        loss = F.nll_loss(F.log_softmax(output, dim=1), label)
        loss.backward()

        for n, p in self.model.named_parameters():
          precision_matrices[n].data += p.grad.data ** 2 / number_data

    return precision_matrices

  def penalty(self, model: nn.Module):
    loss = 0
    for n, p in model.named_parameters():
      _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
      loss += _loss.sum()
    return loss


class MAS(object):
  """
  @inproceedings{aljundi2017memory,
    title={Memory Aware Synapses: Learning what (not) to forget},
    author={Aljundi, Rahaf and Babiloni, Francesca and Elhoseiny, Mohamed and Rohrbach, Marcus and Tuytelaars, Tinne},
    booktitle={ECCV},
    year={2018},
    url={https://eccv2018.org/openaccess/content_ECCV_2018/papers/Rahaf_Aljundi_Memory_Aware_Synapses_ECCV_2018_paper.pdf}
  }
  """

  def __init__(self, model: nn.Module, dataloaders: list, device):
    self.model = model
    self.dataloaders = dataloaders
    self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}  # grab all trainable parameters
    self._means = {}  # snapshot of the parameters after the previous tasks (θ*)
    self.device = device
    self._precision_matrices = self.calculate_importance()  # build the MAS Omega (Ω) matrix

    for n, p in self.params.items():
      self._means[n] = p.clone().detach()

  def calculate_importance(self):
    print('Computing MAS')

    precision_matrices = {}
    for n, p in self.params.items():
      precision_matrices[n] = p.clone().detach().fill_(0)  # initialize the Omega (Ω) matrix to zeros

    self.model.eval()
    num_data = sum([len(loader) for loader in self.dataloaders])  # total number of batches over the old tasks
    for dataloader in self.dataloaders:
      for data in dataloader:
        self.model.zero_grad()
        output = self.model(data[0].to(self.device))

        # build the MAS Omega (Ω) matrix: the importance signal is the gradient of
        # the squared L2 norm of the output (out-of-place pow, so autograd keeps
        # the tensor it needs for the backward pass)
        loss = output.pow(2).sum(dim=1).mean()
        loss.backward()

        for n, p in self.model.named_parameters():
          # this is where MAS differs from EWC: absolute gradients instead of squared ones
          precision_matrices[n].data += p.grad.abs() / num_data

    return precision_matrices

  def penalty(self, model: nn.Module):
    loss = 0
    for n, p in model.named_parameters():
      _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
      loss += _loss.sum()
    return loss
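
For reference, here is a minimal standalone sketch of how these two classes are consumed, mirroring what train_process in hw14.py does; the toy model and dataloader are hypothetical stand-ins:

import torch
import torch.nn as nn
import torch.utils.data as data

from core_define import EWC, MAS

toy_model = nn.Sequential(nn.Linear(3 * 32 * 32, 64), nn.ReLU(), nn.Linear(64, 10))
toy_imgs = torch.randn(8, 3 * 32 * 32)
toy_labels = torch.randint(0, 10, (8,))
toy_loader = data.DataLoader(data.TensorDataset(toy_imgs, toy_labels), batch_size=4)

device = torch.device("cpu")
ewc = EWC(toy_model, [toy_loader], device)  # estimate the Fisher matrix on "old task" data
mas = MAS(toy_model, [toy_loader], device)  # estimate the Omega matrix on the same data

# The penalties are zero right after construction and grow as the parameters
# drift away from the stored snapshot while training on a new task.
penalty = 0.5 * ewc.penalty(toy_model) + 0.5 * mas.penalty(toy_model)
print(penalty.item())  # 0.0 at this point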

preprocess.py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms

import numpy as np
import os
import random
from copy import deepcopy
import json

# Prepare the datasets: MNIST, SVHN, USPS
class Data():

    def __init__(self, path):
        transform = get_transform()

        self.MNIST_dataset = datasets.MNIST(root=os.path.join(path, "MNIST"),
                                            transform=transform,
                                            train=True,
                                            download=True)

        self.SVHN_dataset = datasets.SVHN(root=os.path.join(path, "SVHN"),
                                          transform=transform,
                                          split='train',
                                          download=True)

        self.USPS_dataset = datasets.USPS(root=os.path.join(path, "USPS"),
                                          transform=transform,
                                          train=True,
                                          download=True)

    def get_datasets(self):
        a = [(self.SVHN_dataset, "SVHN"), (self.MNIST_dataset, "MNIST"), (self.USPS_dataset, "USPS")]
        return a


# Build the Dataloader: train/val split plus an infinite training iterator
class Dataloader():

    def __init__(self, dataset, batch_size, split_ratio=0.1):
        self.dataset = dataset[0]
        self.name = dataset[1]
        train_sampler, val_sampler = self.split_dataset(split_ratio)

        self.train_dataset_size = len(train_sampler)
        self.val_dataset_size = len(val_sampler)

        self.train_loader = data.DataLoader(self.dataset, batch_size=batch_size, sampler=train_sampler)
        self.val_loader = data.DataLoader(self.dataset, batch_size=batch_size, sampler=val_sampler)
        self.train_iter = self.infinite_iter()

    def split_dataset(self, split_ratio):
        data_size = len(self.dataset)
        split = int(data_size * split_ratio)
        indices = list(range(data_size))
        np.random.shuffle(indices)
        train_idx, valid_idx = indices[split:], indices[:split]

        train_sampler = sampler.SubsetRandomSampler(train_idx)
        val_sampler = sampler.SubsetRandomSampler(valid_idx)
        return train_sampler, val_sampler

    def infinite_iter(self):
        it = iter(self.train_loader)
        while True:
            try:
                ret = next(it)
                yield ret
            except StopIteration:
                it = iter(self.train_loader)

# Convert MNIST from (1,28,28) and USPS from (1,16,16) to (3,32,32):
# Convert2RGB replicates the single channel up to num_channel channels
class Convert2RGB(object):

    def __init__(self, num_channel):
        self.num_channel = num_channel

    def __call__(self, img):
        img_channel = img.size()[0]
        img = torch.cat([img] * (self.num_channel - img_channel + 1), 0)
        return img


class Pad(object):

    def __init__(self, size, fill=0, padding_mode='constant'):
        self.size = size
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        img_size = img.size()[1]
        assert ((self.size - img_size) % 2 == 0)
        padding = (self.size - img_size) // 2
        padding = (padding, padding, padding, padding)
        return F.pad(img, padding, self.padding_mode, self.fill)


def get_transform():
    transform = transforms.Compose([transforms.ToTensor(),
                                    Pad(32),
                                    Convert2RGB(3),
                                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    return transform
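
As a quick sanity check (not part of the homework code), the pipeline can be verified on a single MNIST sample; by construction of Pad(32) and Convert2RGB(3) the output shape is (3, 32, 32):

import os
from torchvision import datasets

from preprocess import get_transform

mnist = datasets.MNIST(root=os.path.join("./data", "MNIST"), train=True,
                       download=True, transform=get_transform())
img, label = mnist[0]
print(img.shape)  # torch.Size([3, 32, 32]): padded from 28x28, replicated to 3 channels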

4. Experimental Results

(Figure: validation-accuracy curves on SVHN, MNIST and USPS for the basic, EWC and MAS runs, as produced by plot_result.)
