Fisher信息矩阵(Fisher Information Matrix, FIM)是一个统计学概念,用于度量观测数据中包含的关于模型参数的信息量,也就是模型输出分布(对数似然)对参数变化的敏感程度。在深度学习中,Fisher信息可以用来衡量每个参数的重要性:Fisher值越大,说明该参数对模型输出的影响越大,其估计的不确定性也越小。Fisher信息矩阵在解决灾难性遗忘问题时起到了关键作用,因为它可以帮助我们了解哪些参数在之前的任务中扮演了重要角色,从而在学习新任务时对这些参数施加更大的约束。
在EWC算法中,实际使用的是Fisher信息矩阵的对角线近似:它不是完整的矩阵,而是一组与模型参数形状完全相同的张量,每个元素对应一个参数。每个元素表示相应参数对先前任务的重要程度(值越大,该参数的不确定性越小)。较大的Fisher信息值意味着参数在先前任务中的重要性较高,因此在学习新任务时应该施加更大的约束来防止对这些参数进行过度更新。
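为了计算Fisher信息矩阵,我们首先在先前任务上训练模型,然后计算对数似然关于每个参数的梯度的平方。对数似然梯度平方的期望正是Fisher信息矩阵的对角元素,因此梯度平方可以用作其近似值(使用真实标签、而不是从模型预测分布中采样标签时,得到的是所谓的"经验Fisher")。理论上需要遍历先前任务的训练数据,计算每个数据点的梯度平方,然后求平均值;下面的实现为了效率,改为对每个小批量(batch)平均损失的梯度求平方,再对所有批次求平均,这是一种更粗略的近似。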
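最后,在训练新任务时,我们使用Fisher信息矩阵计算EWC惩罚项 $\frac{\lambda}{2}\sum_i F_i\,(\theta_i-\theta^{*}_{i})^2$,其中 $\theta^{*}$ 是在先前任务上训练完成后保存的参数,$F_i$ 是第 $i$ 个参数的Fisher值,$\lambda$ 控制约束强度。也就是说,惩罚项与参数变化量的平方成正比,并按Fisher值加权。这个惩罚项会在训练过程中添加到原始损失函数中,从而减缓对关键参数的更新,保留在先前任务中学到的知识。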
# EWC implementation
# NOTE(review): an `EWC` class is defined again further down in this file
# (after the imports); that later definition replaces this one at runtime.
class EWC:
    """Elastic Weight Consolidation (EWC) regularizer.

    On construction, snapshots the current trainable parameters of ``model``
    and estimates a diagonal Fisher information for each of them on
    ``dataloader``. ``penalty`` then returns
    ``importance / 2 * sum_i F_i * (theta_i - theta_star_i) ** 2``.

    Args:
        model: Network trained on the previous task. Its current parameter
            values become the anchor point ``theta_star``.
        dataloader: Iterable of ``(data, target)`` batches from the previous
            task; ``target`` is passed to ``F.nll_loss`` (class indices).
        device: Device each batch is moved to before the forward pass.
        importance: Scale applied to the penalty (together with a factor 1/2).
    """

    def __init__(self, model, dataloader, device, importance=1000):
        self.model = model
        self.importance = importance
        self.device = device
        # Anchor values: detached copies, so later training does not move them.
        self.params = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader):
        """Estimate the diagonal (empirical) Fisher information per parameter.

        For each batch, the gradient of the batch-mean negative log-likelihood
        is squared element-wise and the results are averaged over batches.
        This squares a *batch-averaged* gradient, which is a cheaper, smaller
        magnitude approximation of averaging per-sample squared gradients.

        The model is evaluated in eval mode and its previous train/eval mode
        is restored afterwards; gradients produced here are cleared.

        Returns:
            dict mapping parameter name -> tensor shaped like that parameter.
            All zeros when ``dataloader`` yields no batches; parameters that
            never receive a gradient keep a zero entry.
        """
        fisher = {
            n: torch.zeros_like(p.detach())
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        was_training = self.model.training
        # eval(): no dropout noise and no BatchNorm running-stat updates while
        # we only want gradients of the trained model.
        self.model.eval()
        n_batches = 0
        try:
            for data, target in dataloader:
                data, target = data.to(self.device), target.to(self.device)
                self.model.zero_grad()
                output = F.log_softmax(self.model(data), dim=1)
                loss = F.nll_loss(output, target)
                loss.backward()
                for n, p in self.model.named_parameters():
                    # Parameters not reached by this forward pass have no grad.
                    if p.requires_grad and p.grad is not None:
                        fisher[n] += p.grad.detach() ** 2
                n_batches += 1
        finally:
            # Don't leave Fisher-estimation gradients behind for a later
            # optimizer step, and restore the caller's train/eval mode.
            self.model.zero_grad(set_to_none=True)
            self.model.train(was_training)
        if n_batches:
            for n in fisher:
                fisher[n] /= n_batches
        return fisher

    def penalty(self, new_model):
        """Return the EWC penalty of ``new_model`` relative to the anchor.

        Args:
            new_model: Module whose trainable parameter names match those of
                the model given to ``__init__`` (a mismatching name raises
                ``KeyError``).

        Returns:
            ``importance / 2 * sum(F * (p - p_star) ** 2)`` as a scalar tensor
            that is differentiable w.r.t. ``new_model``'s parameters, or
            ``0.0`` if ``new_model`` has no trainable parameters.
        """
        loss = 0
        for n, p in new_model.named_parameters():
            if p.requires_grad:
                loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return loss * (self.importance / 2)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
# NOTE(review): train and test sets are each split into two *random* halves of
# the same digit distribution, so "Task 2" is statistically the same task as
# "Task 1" and little forgetting should be expected. For a genuine
# continual-learning setup, split by label instead (e.g. digits 0-4 vs 5-9),
# preferably with torch.utils.data.Subset over label-filtered indices rather
# than iterating the dataset (which runs `transform` on every image).
# Dedicated fixed-seed generator so the task splits are identical across runs.
split_generator = torch.Generator().manual_seed(0)
# Split data into two groups
train_dataset_size = len(train_dataset)
train_split_sizes = [train_dataset_size // 2, train_dataset_size - train_dataset_size // 2]
task1_data, task2_data = random_split(train_dataset, train_split_sizes, generator=split_generator)
task1_loader = DataLoader(task1_data, batch_size=64, shuffle=True)
task2_loader = DataLoader(task2_data, batch_size=64, shuffle=True)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_dataset_size = len(test_dataset)
test_split_sizes = [test_dataset_size // 2, test_dataset_size - test_dataset_size // 2]
task1_test_data, task2_test_data = random_split(test_dataset, test_split_sizes, generator=split_generator)
task1_test_loader = DataLoader(task1_test_data, batch_size=64, shuffle=False)
task2_test_loader = DataLoader(task2_test_data, batch_size=64, shuffle=False)
# Model definition
class SimpleNet(nn.Module):
    """Small two-conv / two-linear CNN for single-channel images, 10 logits out.

    Layer sizes assume 28x28 inputs: conv1(5) -> 24x24 -> pool -> 12x12,
    conv2(5) -> 8x8 -> pool -> 4x4, giving 20 * 4 * 4 = 320 features for fc1.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, 10)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Flatten per sample. Unlike view(-1, 320), this never silently
        # re-batches features when the input is not 28x28; fc1 raises a
        # shape error instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# EWC implementation
# NOTE(review): this redefines the `EWC` class declared near the top of the
# file; this definition is the one used by the training code below.
class EWC:
    """Elastic Weight Consolidation (EWC) regularizer.

    On construction, snapshots the current trainable parameters of ``model``
    and estimates a diagonal Fisher information for each of them on
    ``dataloader``. ``penalty`` then returns
    ``importance / 2 * sum_i F_i * (theta_i - theta_star_i) ** 2``.

    Args:
        model: Network trained on the previous task. Its current parameter
            values become the anchor point ``theta_star``.
        dataloader: Iterable of ``(data, target)`` batches from the previous
            task; ``target`` is passed to ``F.nll_loss`` (class indices).
        device: Device each batch is moved to before the forward pass.
        importance: Scale applied to the penalty (together with a factor 1/2).
    """

    def __init__(self, model, dataloader, device, importance=1000):
        self.model = model
        self.importance = importance
        self.device = device
        # Anchor values: detached copies, so later training does not move them.
        self.params = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader):
        """Estimate the diagonal (empirical) Fisher information per parameter.

        For each batch, the gradient of the batch-mean negative log-likelihood
        is squared element-wise and the results are averaged over batches.
        This squares a *batch-averaged* gradient, which is a cheaper, smaller
        magnitude approximation of averaging per-sample squared gradients.

        The model is evaluated in eval mode and its previous train/eval mode
        is restored afterwards; gradients produced here are cleared.

        Returns:
            dict mapping parameter name -> tensor shaped like that parameter.
            All zeros when ``dataloader`` yields no batches; parameters that
            never receive a gradient keep a zero entry.
        """
        fisher = {
            n: torch.zeros_like(p.detach())
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }
        was_training = self.model.training
        # eval(): no dropout noise and no BatchNorm running-stat updates while
        # we only want gradients of the trained model.
        self.model.eval()
        n_batches = 0
        try:
            for data, target in dataloader:
                data, target = data.to(self.device), target.to(self.device)
                self.model.zero_grad()
                output = F.log_softmax(self.model(data), dim=1)
                loss = F.nll_loss(output, target)
                loss.backward()
                for n, p in self.model.named_parameters():
                    # Parameters not reached by this forward pass have no grad.
                    if p.requires_grad and p.grad is not None:
                        fisher[n] += p.grad.detach() ** 2
                n_batches += 1
        finally:
            # Don't leave Fisher-estimation gradients behind for a later
            # optimizer step, and restore the caller's train/eval mode.
            self.model.zero_grad(set_to_none=True)
            self.model.train(was_training)
        if n_batches:
            for n in fisher:
                fisher[n] /= n_batches
        return fisher

    def penalty(self, new_model):
        """Return the EWC penalty of ``new_model`` relative to the anchor.

        Args:
            new_model: Module whose trainable parameter names match those of
                the model given to ``__init__`` (a mismatching name raises
                ``KeyError``).

        Returns:
            ``importance / 2 * sum(F * (p - p_star) ** 2)`` as a scalar tensor
            that is differentiable w.r.t. ``new_model``'s parameters, or
            ``0.0`` if ``new_model`` has no trainable parameters.
        """
        loss = 0
        for n, p in new_model.named_parameters():
            if p.requires_grad:
                loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return loss * (self.importance / 2)
# Train function
def train(model, dataloader, optimizer, criterion, device, ewc=None, ewc_lambda=0.5):
    """Run one pass (epoch) of optimization over ``dataloader``.

    Args:
        model: Network to optimize; switched to train mode before the loop.
        dataloader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        criterion: Loss, called as ``criterion(output, target)``.
        device: Device each batch is moved to.
        ewc: Optional object whose ``penalty(model)`` is added to every
            batch loss (e.g. an ``EWC`` instance).
        ewc_lambda: Multiplier applied to the penalty when ``ewc`` is given.
    """
    model.train()
    use_ewc = ewc is not None
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        if use_ewc:
            # Regularize towards the weights remembered from the earlier task.
            batch_loss = batch_loss + ewc_lambda * ewc.penalty(model)
        batch_loss.backward()
        optimizer.step()
# Test function
def test(model, dataloader, device):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in dataloader:
data, target = data.to(device), target.to(device)
output = model(data)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
accuracy = 100 * correct / total
return accuracy
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize model. Seeding the global RNG first makes the weight init and the
# DataLoader shuffling below reproducible from run to run.
torch.manual_seed(0)
model = SimpleNet().to(device)
# Train on Task 1
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    train(model, task1_loader, optimizer, criterion, device)
task1_accuracy = test(model, task1_test_loader, device)
print(f'Task 1 accuracy: {task1_accuracy}%')
# Save EWC: snapshot the Task 1 weights and estimate their Fisher importance.
ewc = EWC(model, task1_loader, device)
# Train on Task 2 with the EWC penalty. With EWC's default importance=1000,
# the effective penalty weight is ewc_lambda * importance / 2 = 10 * 1000 / 2.
# A fresh optimizer is created so Task 1 momentum does not carry over.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(10):
    train(model, task2_loader, optimizer, criterion, device, ewc=ewc, ewc_lambda=10)
task2_accuracy = test(model, task2_test_loader, device)
print(f'Task 2 accuracy: {task2_accuracy}%')
# Re-evaluate Task 1 after Task 2 training to measure forgetting.
task1_accuracy_new = test(model, task1_test_loader, device)
print(f'Tasknew 1 accuracy: {task1_accuracy_new}%')
# The model has not changed since the Task 2 evaluation above; reuse that
# result instead of running the identical evaluation pass again.
task2_accuracy_NEW = task2_accuracy
print(f'Tasknew 2 accuracy: {task2_accuracy_NEW}%')