Hung-Yi Lee homework[14]:Life Long Learning
一、Life Long Learning原理
Life Long Learning的意思是:机器首先学习了任务一,然后学习了任务二,此时机器同时掌握了任务一和任务二,如果机器在之后的时间中继续学习别的任务,机器就能够拥有更多的技能,理想状态下,机器可以无所不能。
要想实现Life Long Learning,需要解决以下几个问题:(1)如何在学习新知识时对旧知识进行保留;(2)在训练不同的任务时,如何进行知识的迁移;(3)如何进行有效的模型扩张使模型更加符合当前实际情况而不浪费计算资源。
实际过程中,Life Long Learning容易碰到灾难性遗忘的现象,目前对于灾难性遗忘的解决方法有以下几个常见的做法:
- Dynamic Expansion:直接搞一批新的参数来学习新任务,但这样模型的参数会越来越多,往往需要搭配一些模型压缩的操作。
- Rehearsal:如果让新任务上的梯度能尽可能接近旧任务上的梯度,那就可以保留很大一部分旧知识。
- Regularization:加一些正则项来避免跟旧任务关联比较大的参数的更新幅度过大。这是因为大部分神经网络的参数规模很大,其中有相当一部分参数对模型并无决定性的作用,因此正则化的方法是有用武之地的。
作业中需要用到的EWC和MAS实际上都是基于Regularization的方法。在非Life Long Learning的问题上,模型在任务A上训练完之后,直接拿去任务B上进行微调,而这种训练出来的模型并不能完成任务A了(因为出现了灾难性遗忘),但是当我们添加一个正则项(L2)之后,使任务B上训练完的参数不能离任务A上训练完的结果太远,这就是Regularization的基本思想。
Regularization中,直接加入L2正则项并没有考虑不同的参数对于任务的重要性,会使任务B的学习陷入瓶颈,所以在进行基于Regularization的方法时,需要计算每个参数 $\theta_i$ 对任务A的重要性 $\Omega_i$,然后添加了正则项的损失函数就变成了 $L(\theta)=L_B(\theta)+\frac{\lambda}{2}\sum_i\Omega_i(\theta_i-\theta^*_{A,i})^2$。
二、作业描述
在本次作业过程中,需要实现EWC和MAS这两个解决Life Long Learning问题的方法,并将它们依次应用在三个不同的数据集上:先训练数据集SVHN,再训练数据集MNIST,最后训练数据集USPS,同时绘制各任务准确率随训练过程变化的曲线图。
三、作业实现
因为本次作业强调的是lifelong learning的训练方法,并不需要叠加模型,所以在作业的实现过程中,我们始终使用同一个模型。
hw14.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms
import numpy as np
import os
import random
from copy import deepcopy
import json
from core_define import *
from preprocess import *
# Workaround for the "SSL: CERTIFICATE_VERIFY_FAILED" error that the
# torchvision dataset downloads can raise — disable certificate verification.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# Used for plotting the result curves.
import matplotlib.pyplot as plt
# Shared by all training/eval code below: run on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Persist model + optimizer; written to <path>.ckpt and <path>.opt.
def save_model(model, optimizer, store_model_path):
    """Save the model's state dict and the optimizer's state dict.

    Two files are written: `<store_model_path>.ckpt` (model weights) and
    `<store_model_path>.opt` (optimizer state).
    """
    torch.save(model.state_dict(), f'{store_model_path}.ckpt')
    torch.save(optimizer.state_dict(), f'{store_model_path}.opt')
# Restore model + optimizer saved by save_model.
def load_model(model, optimizer, load_model_path):
    """Load the state dicts written by `save_model` into the given objects.

    Returns the same (model, optimizer) pair after loading.
    """
    print(f'Load model from {load_model_path}')
    ckpt_state = torch.load(f'{load_model_path}.ckpt')
    opt_state = torch.load(f'{load_model_path}.opt')
    model.load_state_dict(ckpt_state)
    optimizer.load_state_dict(opt_state)
    return model, optimizer
# Build the model, the optimizer, and one task (Dataloader) per dataset.
# Downloading the three datasets may raise "SSL: CERTIFICATE_VERIFY_FAILED";
# the unverified-context workaround at the top of the file handles that.
def build_model(data_path, batch_size, learning_rate):
    """Create a fresh Model with an Adam optimizer and wrap every dataset
    returned by `Data.get_datasets()` in a `Dataloader`.

    Returns (model, optimizer, tasks).
    """
    model = Model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    tasks = [Dataloader(ds, batch_size) for ds in Data(data_path).get_datasets()]
    return model, optimizer, tasks
# Plain training loop (for the 6-layer fully-connected + ReLU model).
def normal_train(model, optimizer, task, total_epochs, summary_epochs):
    """Fine-tune `model` on `task` for `summary_epochs` steps without any
    lifelong-learning penalty.

    `total_epochs` is only used to label the progress printout. Returns
    (model, optimizer, losses), where `losses` holds the mean cross-entropy
    loss of every 50-step window.
    """
    model.train()
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()
    losses = []
    running = 0.0
    for step in range(summary_epochs):
        imgs, labels = next(task.train_iter)
        imgs, labels = imgs.to(device), labels.to(device)
        ce_loss = criterion(model(imgs), labels)
        optimizer.zero_grad()
        ce_loss.backward()
        optimizer.step()
        running += ce_loss.item()
        # Every 50 steps, report and record the windowed average loss.
        if (step + 1) % 50 == 0:
            running = running / 50
            print("\r", "train task {} [{}] loss: {:.3f} ".format(task.name, (total_epochs + step + 1), running),
                  end=" ")
            losses.append(running)
            running = 0.0
    return model, optimizer, losses
# EWC training loop: cross-entropy plus the weighted EWC penalty.
def ewc_train(model, optimizer, task, total_epochs, summary_epochs, ewc, lambda_ewc):
    """Train on `task` for `summary_epochs` steps, adding the EWC penalty.

    `ewc.penalty(model)` is the quadratic penalty anchored at the previous
    tasks' parameters, scaled by `lambda_ewc`. Returns (model, optimizer,
    losses) with `losses` holding the mean total loss of every 50-step window.
    """
    model.train()
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()
    losses = []
    running = 0.0
    for step in range(summary_epochs):
        imgs, labels = next(task.train_iter)
        imgs, labels = imgs.to(device), labels.to(device)
        total_loss = criterion(model(imgs), labels) + lambda_ewc * ewc.penalty(model)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        running += total_loss.item()
        # Every 50 steps, report and record the windowed average loss.
        if (step + 1) % 50 == 0:
            running = running / 50
            print("\r", "train task {} [{}] loss: {:.3f} ".format(task.name, (total_epochs + step + 1), running),
                  end=" ")
            losses.append(running)
            running = 0.0
    return model, optimizer, losses
# MAS training loop: cross-entropy plus the weighted MAS penalties.
def mas_train(model, optimizer, task, total_epochs, summary_epochs, mas_tasks, lambda_mas, alpha=0.8):
    """Train on `task` for `summary_epochs` steps with MAS regularization.

    Only the two most recent MAS objects contribute: the newest with weight
    `alpha`, the one before it with weight `1 - alpha`, both scaled by
    `lambda_mas`. Returns (model, optimizer, losses) with `losses` holding
    the mean total loss of every 50-step window.

    Bug fix: the original called `mas_tasks.reverse()` inside the training
    loop, reversing the caller's list in place on every step, so the
    alpha / (1 - alpha) weights flipped between the two most recent tasks
    on alternating steps. The reversed view is now computed once, without
    mutating the argument.
    """
    model.train()
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()
    losses = []
    running = 0.0
    # Most recent task first; pair each penalty term with its fixed weight.
    recent_first = list(reversed(mas_tasks))
    if len(recent_first) > 1:
        weighted = list(zip(recent_first[:2], [alpha, 1 - alpha]))
    else:
        weighted = [(mas, 1.0) for mas in recent_first]
    for step in range(summary_epochs):
        imgs, labels = next(task.train_iter)
        imgs, labels = imgs.to(device), labels.to(device)
        total_loss = criterion(model(imgs), labels)
        for mas, scalar in weighted:
            total_loss += lambda_mas * mas.penalty(model) * scalar
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        running += total_loss.item()
        # Every 50 steps, report and record the windowed average loss.
        if (step + 1) % 50 == 0:
            running = running / 50
            print("\r", "train task {} [{}] loss: {:.3f} ".format(task.name, (total_epochs + step + 1), running),
                  end=" ")
            losses.append(running)
            running = 0.0
    return model, optimizer, losses
# Validation: classification accuracy on the task's validation split.
def val(model, task):
    """Return the fraction of correctly classified validation samples."""
    model.eval()
    correct_cnt = 0
    for imgs, labels in task.val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs).data.max(1)[1]
        correct_cnt += (preds == labels.data).sum().item()
    return correct_cnt / task.val_dataset_size
# Main driver: combines normal_train, ewc_train and mas_train across tasks.
def train_process(model, optimizer, tasks, config):
    """Sequentially train `model` on every task according to `config.mode`.

    The first task is always trained with plain fine-tuning; subsequent
    tasks use plain / EWC / MAS training depending on the mode. After each
    summary interval the model is evaluated on every task seen so far, and
    a checkpoint is written at `config.store_epochs` intervals.

    Returns (task_loss, acc): per-task loss curves and accuracy histories.
    """
    task_loss, acc = {}, {}
    for task_id, task in enumerate(tasks):
        print('\n')
        total_epochs = 0
        task_loss[task.name] = []
        acc[task.name] = []
        # Plain training: used for the very first task in every mode,
        # and for every task in 'basic' mode.
        if config.mode == 'basic' or task_id == 0:
            while total_epochs < config.num_epochs:
                model, optimizer, losses = normal_train(
                    model, optimizer, task, total_epochs, config.summary_epochs)
                task_loss[task.name] += losses
                for seen in range(task_id + 1):
                    acc[tasks[seen].name].append(val(model, tasks[seen]))
                total_epochs += config.summary_epochs
                if total_epochs % config.store_epochs == 0 or total_epochs >= config.num_epochs:
                    save_model(model, optimizer, config.store_model_path)
        if config.mode == 'ewc' and task_id > 0:
            # One EWC object built from the val loaders of all previous tasks.
            old_dataloaders = [tasks[old].val_loader for old in range(task_id)]
            ewc = EWC(model, old_dataloaders, device)
            while total_epochs < config.num_epochs:
                model, optimizer, losses = ewc_train(
                    model, optimizer, task, total_epochs, config.summary_epochs,
                    ewc, config.lifelong_coeff)
                task_loss[task.name] += losses
                for seen in range(task_id + 1):
                    acc[tasks[seen].name].append(val(model, tasks[seen]))
                total_epochs += config.summary_epochs
                if total_epochs % config.store_epochs == 0 or total_epochs >= config.num_epochs:
                    save_model(model, optimizer, config.store_model_path)
        if config.mode == 'mas' and task_id > 0:
            # One MAS object per previous task; each sees the growing prefix
            # of old validation loaders at construction time.
            old_dataloaders = []
            mas_tasks = []
            for old in range(task_id):
                old_dataloaders += [tasks[old].val_loader]
                mas_tasks += [MAS(model, old_dataloaders, device)]
            while total_epochs < config.num_epochs:
                model, optimizer, losses = mas_train(
                    model, optimizer, task, total_epochs, config.summary_epochs,
                    mas_tasks, config.lifelong_coeff)
                task_loss[task.name] += losses
                for seen in range(task_id + 1):
                    acc[tasks[seen].name].append(val(model, tasks[seen]))
                total_epochs += config.summary_epochs
                if total_epochs % config.store_epochs == 0 or total_epochs >= config.num_epochs:
                    save_model(model, optimizer, config.store_model_path)
        # 'scp' is accepted but intentionally unimplemented (placeholder).
        if config.mode == 'scp' and task_id > 0:
            pass
    return task_loss, acc
def plot_result(mode_list, task1, task2, task3):
    """Plot the per-task accuracy history for every mode in `mode_list`.

    Reads `./<mode>_acc.txt` (written by the main script) and draws one
    subplot per task. task1 was trained first, so its history spans the
    whole run; task2/task3 joined later. Each task contributes the same
    number of summary points, so e.g. len(acc[task3]) equals the x-offset
    at which task2's training began — presumably why those lengths are
    used as offsets below (verify against train_process output).
    """
    palette = ['red', 'blue', 'purple']
    for count, reg_name in enumerate(mode_list):
        label = reg_name
        with open(f'./{reg_name}_acc.txt', 'r') as f:
            acc = json.load(f)
        color = palette[count] if count < len(palette) else 'purple'
        ax1 = plt.subplot(3, 1, 1)
        plt.plot(range(len(acc[task1])), acc[task1], color, label=label)
        ax1.set_ylabel(task1)
        ax2 = plt.subplot(3, 1, 2, sharex=ax1, sharey=ax1)
        plt.plot(range(len(acc[task3]), len(acc[task1])), acc[task2], color, label=label)
        ax2.set_ylabel(task2)
        ax3 = plt.subplot(3, 1, 3, sharex=ax1, sharey=ax1)
        ax3.set_ylabel(task3)
        plt.plot(range(len(acc[task2]), len(acc[task1])), acc[task3], color, label=label)
    plt.ylim((0.02, 1.02))
    plt.legend()
    plt.show()
# Hyper-parameter container.
class configurations(object):
    """Holds every hyper-parameter used by the training driver."""
    def __init__(self):
        self.batch_size = 256            # mini-batch size for all loaders
        self.num_epochs = 10000          # training steps per task
        self.store_epochs = 250          # checkpoint interval (in steps)
        self.summary_epochs = 250        # steps per loss/accuracy summary
        self.learning_rate = 0.0005      # Adam learning rate
        self.load_model = False          # resume from load_model_path if True
        self.store_model_path = "./model"
        self.load_model_path = "./model"
        self.data_path = "./data"
        self.mode = None                 # 'basic' / 'ewc' / 'mas'; set by the driver
        self.lifelong_coeff = 0.5        # lambda for the EWC/MAS penalty
if __name__ == '__main__':
    mode_list = ['mas', 'ewc', 'basic']
    # NOTE(review): all coefficients are 0, which disables the EWC/MAS
    # penalties entirely — presumably left over from a debug run; confirm
    # before comparing modes.
    coeff_list = [0, 0, 0]
    config = configurations()
    for count, mode in enumerate(mode_list):
        config.mode = mode
        config.lifelong_coeff = coeff_list[count]
        print("{} training".format(config.mode))
        model, optimizer, tasks = build_model(config.data_path, config.batch_size, config.learning_rate)
        print("Finish build model")
        if config.load_model:
            model, optimizer = load_model(model, optimizer, config.load_model_path)
        task_loss, acc = train_process(model, optimizer, tasks, config)
        # Persist the accuracy history so plot_result can read it back.
        with open(f'./{config.mode}_acc.txt', 'w') as f:
            json.dump(acc, f)
    mode_list = ['ewc', 'mas', 'basic']
    plot_result(mode_list, 'SVHN', 'MNIST', 'USPS')
core_define.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms
import numpy as np
import os
import random
from copy import deepcopy
import json
# Core model: six fully-connected layers with ReLU activations.
class Model(nn.Module):
    """MLP classifier: flattens a 3x32x32 image and maps it to 10 logits.

    Attribute names fc1..fc6 are kept as-is so saved state dicts stay
    compatible with save_model / load_model.
    """
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3 * 32 * 32, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 128)
        self.fc6 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Flatten whatever batch shape arrives into (N, 3*32*32).
        x = x.view(-1, 3 * 32 * 32)
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = self.relu(fc(x))
        # No activation on the output layer: raw logits for CrossEntropyLoss.
        return self.fc6(x)
# EWC: Elastic Weight Consolidation.
class EWC(object):
    """
    @article{kirkpatrick2017overcoming,
    title={Overcoming catastrophic forgetting in neural networks},
    author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others},
    journal={Proceedings of the national academy of sciences},
    year={2017},
    url={https://arxiv.org/abs/1612.00796}
    }
    """
    def __init__(self, model: nn.Module, dataloaders: list, device):
        self.model = model
        self.dataloaders = dataloaders
        self.device = device
        # All trainable parameters of the model.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        # Diagonal Fisher information estimate, one tensor per parameter.
        self._precision_matrices = self._calculate_importance()
        for n, p in self.params.items():
            # Anchor point: the parameter values after the previous task(s).
            self._means[n] = p.clone().detach()

    def _calculate_importance(self):
        """Estimate the diagonal Fisher matrix from the old tasks' data.

        Bug fix: the original flattened the whole batch output with
        `.view(1, -1)` before taking the argmax, producing a single
        meaningless label per batch; labels and the loss are now computed
        per sample (empirical Fisher with the model's own predictions).
        """
        precision_matrices = {}
        for n, p in self.params.items():
            # Start every importance tensor at zero.
            precision_matrices[n] = p.clone().detach().fill_(0)
        self.model.eval()
        # NOTE(review): this counts batches, not samples, so the Fisher scale
        # depends on the loader batch size; the lambda coefficient absorbs it.
        number_data = sum(len(loader) for loader in self.dataloaders)
        for dataloader in self.dataloaders:
            for batch in dataloader:
                self.model.zero_grad()
                inputs = batch[0].to(self.device)
                output = self.model(inputs)
                # Use the model's own argmax predictions as labels.
                label = output.max(1)[1]
                loss = F.nll_loss(F.log_softmax(output, dim=1), label)
                loss.backward()
                for n, p in self.model.named_parameters():
                    # Accumulate squared gradients (diagonal Fisher).
                    precision_matrices[n].data += p.grad.data ** 2 / number_data
        return precision_matrices

    def penalty(self, model: nn.Module):
        """Quadratic EWC penalty: sum_i F_i * (theta_i - theta*_i)^2."""
        loss = 0
        for n, p in model.named_parameters():
            loss += (self._precision_matrices[n] * (p - self._means[n]) ** 2).sum()
        return loss
class MAS(object):
    """
    @article{aljundi2017memory,
    title={Memory Aware Synapses: Learning what (not) to forget},
    author={Aljundi, Rahaf and Babiloni, Francesca and Elhoseiny, Mohamed and Rohrbach, Marcus and Tuytelaars, Tinne},
    booktitle={ECCV},
    year={2018},
    url={https://eccv2018.org/openaccess/content_ECCV_2018/papers/Rahaf_Aljundi_Memory_Aware_Synapses_ECCV_2018_paper.pdf}
    }
    """
    def __init__(self, model: nn.Module, dataloaders: list, device):
        self.model = model
        self.dataloaders = dataloaders
        self.device = device
        # All trainable parameters of the model.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        # Omega importance weights, one tensor per parameter.
        self._precision_matrices = self.calculate_importance()
        for n, p in self.params.items():
            # Anchor point: the parameter values after the previous task(s).
            self._means[n] = p.clone().detach()

    def calculate_importance(self):
        """Estimate the Omega weights from gradients of the squared L2 norm
        of the model output, averaged over the old tasks' data (MAS)."""
        print('Computing MAS')
        # Start every importance tensor at zero.
        omega = {n: p.clone().detach().fill_(0) for n, p in self.params.items()}
        self.model.eval()
        # NOTE(review): counts batches, not samples; lambda absorbs the scale.
        num_data = sum(len(loader) for loader in self.dataloaders)
        for dataloader in self.dataloaders:
            for batch in dataloader:
                self.model.zero_grad()
                output = self.model(batch[0].to(self.device))
                # Squared L2 norm of the output, averaged over the batch.
                output.pow_(2)
                loss = torch.sum(output, dim=1).mean()
                loss.backward()
                for n, p in self.model.named_parameters():
                    # Difference from EWC: accumulate |grad|, not grad^2.
                    omega[n].data += p.grad.abs() / num_data
        return omega

    def penalty(self, model: nn.Module):
        """Quadratic MAS penalty: sum_i Omega_i * (theta_i - theta*_i)^2."""
        loss = 0
        for n, p in model.named_parameters():
            loss += (self._precision_matrices[n] * (p - self._means[n]) ** 2).sum()
        return loss
preprocess.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms
import numpy as np
import os
import random
from copy import deepcopy
import json
# Prepare the three datasets: MNIST, SVHN and USPS.
class Data():
    """Download (if needed) and hold the three training datasets."""
    def __init__(self, path):
        transform = get_transform()
        self.MNIST_dataset = datasets.MNIST(root=os.path.join(path, "MNIST"),
                                            transform=transform, train=True, download=True)
        self.SVHN_dataset = datasets.SVHN(root=os.path.join(path, "SVHN"),
                                          transform=transform, split='train', download=True)
        self.USPS_dataset = datasets.USPS(root=os.path.join(path, "USPS"),
                                          transform=transform, train=True, download=True)

    def get_datasets(self):
        """Return (dataset, name) pairs in the training order SVHN -> MNIST -> USPS."""
        return [(self.SVHN_dataset, "SVHN"),
                (self.MNIST_dataset, "MNIST"),
                (self.USPS_dataset, "USPS")]
# Wrap one (dataset, name) pair with loaders and an endless train iterator.
class Dataloader():
    """Splits a dataset into train/val loaders and exposes `train_iter`,
    a never-ending iterator over training batches."""
    def __init__(self, dataset, batch_size, split_ratio=0.1):
        self.dataset, self.name = dataset
        train_sampler, val_sampler = self.split_dataset(split_ratio)
        self.train_dataset_size = len(train_sampler)
        self.val_dataset_size = len(val_sampler)
        self.train_loader = data.DataLoader(self.dataset, batch_size=batch_size, sampler=train_sampler)
        self.val_loader = data.DataLoader(self.dataset, batch_size=batch_size, sampler=val_sampler)
        self.train_iter = self.infinite_iter()

    def split_dataset(self, split_ratio):
        """Shuffle all indices and carve off `split_ratio` of them for
        validation; return (train_sampler, val_sampler)."""
        total = len(self.dataset)
        indices = list(range(total))
        np.random.shuffle(indices)
        cut = int(total * split_ratio)
        train_sampler = sampler.SubsetRandomSampler(indices[cut:])
        val_sampler = sampler.SubsetRandomSampler(indices[:cut])
        return train_sampler, val_sampler

    def infinite_iter(self):
        """Yield training batches forever, restarting the loader whenever it
        is exhausted."""
        it = iter(self.train_loader)
        while True:
            try:
                yield next(it)
            except StopIteration:
                it = iter(self.train_loader)
# Replicate single-channel images so every image has `num_channel` channels
# (MNIST/USPS arrive with 1 channel; SVHN already has 3).
class Convert2RGB(object):
    """Tile the channel dimension of `img` up to `num_channel` channels."""
    def __init__(self, num_channel):
        self.num_channel = num_channel

    def __call__(self, img):
        missing = self.num_channel - img.size()[0]
        # cat [img] * (missing + 1): 1-channel -> 3 copies, 3-channel -> 1 copy.
        return torch.cat([img] * (missing + 1), 0)
class Pad(object):
    """Pad a square CxHxW image up to `size` x `size`, splitting the padding
    equally on each side; the size difference must be even."""
    def __init__(self, size, fill=0, padding_mode='constant'):
        self.size = size
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        diff = self.size - img.size()[1]
        assert diff % 2 == 0
        pad = diff // 2
        return F.pad(img, (pad, pad, pad, pad), self.padding_mode, self.fill)
def get_transform():
    """Shared preprocessing pipeline: to tensor, zero-pad to 32x32,
    replicate to 3 channels, then normalize each channel to [-1, 1]."""
    return transforms.Compose([
        transforms.ToTensor(),
        Pad(32),
        Convert2RGB(3),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])