Fisher Information and Elastic Weight Consolidation (EWC) for Catastrophic Forgetting in PyTorch

The Fisher Information Matrix (FIM) is a concept from statistics that describes how sensitive a probabilistic model's output is to its parameters. In deep learning, the Fisher information matrix is used to measure the uncertainty of model parameters. It plays a key role in combating catastrophic forgetting, because it tells us which parameters were important for previous tasks, so that we can constrain those parameters more strongly when learning a new task.
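Concretely, for a model with likelihood $p(x \mid \theta)$ and parameter vector $\theta$, the Fisher information matrix is the standard expected outer product of the score function (the gradient of the log-likelihood):

\[
F(\theta) = \mathbb{E}_{x \sim p(x \mid \theta)}\left[\nabla_\theta \log p(x \mid \theta)\, \nabla_\theta \log p(x \mid \theta)^\top\right]
\]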

In the EWC algorithm, the Fisher information appears concretely as a set of values with the same shape as the model's parameters. Each entry reflects the uncertainty about the corresponding parameter: a larger Fisher value means the parameter was more important for the previous task, so it should be constrained more strongly when learning the new task to prevent it from being over-updated.
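The full matrix is quadratic in the number of parameters, so EWC keeps only its diagonal, i.e. one importance value per parameter $\theta_i$:

\[
F_i = \mathbb{E}\left[\left(\frac{\partial \log p(x \mid \theta)}{\partial \theta_i}\right)^{2}\right]
\]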

To compute the Fisher information, we first train the model on the task and then compute the squared gradient of the log-likelihood with respect to each parameter. The squared gradient grows with the sensitivity of the model's output to that parameter, so it serves as an approximation of the Fisher information. To estimate it, we iterate over the training set, compute the squared gradients for each data point, and average.
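In practice, the expectation is replaced by an empirical average over the $N$ training examples of the old task A, evaluated at the parameters $\theta^*_A$ learned on that task:

\[
F_i \approx \frac{1}{N} \sum_{n=1}^{N} \left(\frac{\partial \log p(y_n \mid x_n, \theta^*_A)}{\partial \theta_i}\right)^{2}
\]

Note that the implementation below averages squared gradients per mini-batch rather than per sample, a common and cheaper approximation.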

Finally, when training on the new task, we use the Fisher information to form the EWC penalty, which is proportional to the squared change of each parameter, weighted by its Fisher value. This penalty is added to the original loss during training, which slows down updates to the important parameters and preserves the knowledge learned on earlier tasks.
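Putting it together, with $\mathcal{L}_B$ the loss on the new task B and $\lambda$ a hyperparameter weighting the penalty, the EWC training objective from Kirkpatrick et al. (2017) is

\[
\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2}\, F_i \left(\theta_i - \theta^*_{A,i}\right)^2
\]

This is exactly what penalty() computes in the implementation below. The full PyTorch example follows, starting with imports and data preparation.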

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Split the training data into two "tasks". A class-based split (e.g. digits
# 0-4 vs. 5-9, as in the commented lines below) would make the two tasks more
# distinct and the forgetting effect more pronounced; here a random 50/50
# split is used, so both tasks share the same data distribution.
# task1_data = [data for data in train_dataset if data[1] < 5]
# task2_data = [data for data in train_dataset if data[1] >= 5]
train_dataset_size = len(train_dataset)
train_split_sizes = [train_dataset_size // 2, train_dataset_size - train_dataset_size // 2]
task1_data, task2_data = random_split(train_dataset, train_split_sizes)

task1_loader = DataLoader(task1_data, batch_size=64, shuffle=True)
task2_loader = DataLoader(task2_data, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Split the test set the same way (the class-based alternative is commented out)
# task1_test_data = [data for data in test_dataset if data[1] < 5]
# task2_test_data = [data for data in test_dataset if data[1] >= 5]
test_dataset_size = len(test_dataset)
test_split_sizes = [test_dataset_size // 2, test_dataset_size - test_dataset_size // 2]
task1_test_data, task2_test_data = random_split(test_dataset, test_split_sizes)

task1_test_loader = DataLoader(task1_test_data, batch_size=64, shuffle=False)
task2_test_loader = DataLoader(task2_test_data, batch_size=64, shuffle=False)


# Model definition
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)  # flatten: 20 channels * 4 * 4 spatial positions
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# EWC implementation
class EWC:
    def __init__(self, model, dataloader, device, importance=1000):
        self.model = model
        self.importance = importance
        self.device = device
        # Snapshot of the parameters learned on the previous task (theta*_A)
        self.params = {n: p.clone().detach() for n, p in self.model.named_parameters() if p.requires_grad}
        # Diagonal Fisher information estimated on the previous task's data
        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader):
        # Empirical diagonal Fisher: average of squared log-likelihood gradients,
        # accumulated per mini-batch rather than per sample for efficiency
        fisher = {}
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                fisher[n] = torch.zeros_like(p.data)

        self.model.train()
        for data, target in dataloader:
            data, target = data.to(self.device), target.to(self.device)
            self.model.zero_grad()
            output = F.log_softmax(self.model(data), dim=1)
            loss = F.nll_loss(output, target)
            loss.backward()

            for n, p in self.model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher[n] += (p.grad ** 2) / len(dataloader)

        return fisher

    def penalty(self, new_model):
        # EWC penalty: (importance / 2) * sum_i F_i * (theta_i - theta*_A,i)^2
        loss = 0
        for n, p in new_model.named_parameters():
            if p.requires_grad:
                _loss = self.fisher[n] * (p - self.params[n]) ** 2
                loss += _loss.sum()
        return loss * (self.importance / 2)


# Train function
def train(model, dataloader, optimizer, criterion, device, ewc=None, ewc_lambda=0.5):
    model.train()
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        if ewc is not None:
            # Add the EWC penalty so parameters important for the previous task resist change
            ewc_loss = ewc.penalty(model)
            loss += ewc_lambda * ewc_loss

        loss.backward()
        optimizer.step()


# Test function
def test(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    return accuracy


# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model
model = SimpleNet().to(device)

# Train on Task 1
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    train(model, task1_loader, optimizer, criterion, device)
task1_accuracy = test(model, task1_test_loader, device)
print(f'Task 1 accuracy: {task1_accuracy}%')

# Build the EWC object: snapshot Task 1 parameters and estimate their Fisher information
ewc = EWC(model, task1_loader, device)

# Train on Task 2 with the EWC penalty active
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(10):
    train(model, task2_loader, optimizer, criterion, device, ewc=ewc, ewc_lambda=10)
task2_accuracy = test(model, task2_test_loader, device)

print(f'Task 2 accuracy: {task2_accuracy}%')


# Re-evaluate both tasks after training on Task 2; the drop in Task 1 accuracy
# relative to the earlier measurement shows how much has been forgotten
task1_accuracy_after = test(model, task1_test_loader, device)
print(f'Task 1 accuracy after Task 2 training: {task1_accuracy_after}%')
task2_accuracy_after = test(model, task2_test_loader, device)
print(f'Task 2 accuracy after Task 2 training: {task2_accuracy_after}%')


