# 使用增量学习中 EWC 方法来做回归简单示例
# (Minimal example: EWC — Elastic Weight Consolidation — incremental learning for regression.)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Build a synthetic univariate time-series regression dataset.
def generate_time_series_data(num_samples, sequence_length):
    """Create random sequences and noisy sum targets.

    X has shape (num_samples, sequence_length, 1) — a univariate series per
    sample; each target is the sum of its sequence plus small Gaussian
    noise, returned with shape (num_samples, 1).
    """
    X = torch.randn(num_samples, sequence_length, 1)
    noise = 0.1 * torch.randn(num_samples)
    y = (X.sum(dim=(1, 2)) + noise).reshape(-1, 1)
    return X, y

# A single-layer LSTM regressor: maps an input sequence to one output vector.
class SimpleLSTM(nn.Module):
    """LSTM encoder followed by a linear readout head.

    The final hidden state of the single-layer LSTM summarizes the input
    sequence and is projected to ``output_size`` values per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq_len, input_size); final_hidden: (num_layers, batch, hidden)
        _, (final_hidden, _) = self.lstm(x)
        last_layer_state = final_hidden[-1]
        return self.fc(last_layer_state)

# Estimate the diagonal Fisher information for all model parameters.
def calculate_fisher(model, dataloader, device):
    """Approximate the diagonal of the Fisher information matrix.

    The Fisher diagonal is estimated as the mean, over mini-batches, of the
    squared gradient of the MSE loss with respect to each parameter.

    Args:
        model: the network trained on the old task.
        dataloader: yields ``(inputs, labels)`` batches from the old task.
        device: torch device the model lives on.

    Returns:
        1-D ``np.ndarray`` of non-negative values, one per parameter, in
        ``model.parameters()`` flattened order.
    """
    model.eval()  # gradients are still computed; eval only changes layers
                  # like dropout/batch-norm (none here, but harmless)
    criterion = nn.MSELoss()

    batch_fishers = []
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward + backward to populate parameter gradients for this batch.
        loss = criterion(model(inputs), labels)
        model.zero_grad()
        loss.backward()

        # Squared per-batch gradient = one empirical sample of E[grad^2].
        grads = np.concatenate(
            [param.grad.flatten().detach().cpu().numpy() for param in model.parameters()]
        )
        batch_fishers.append(np.square(grads))

    # BUG FIX: the original divided each batch estimate by len(dataset) AND
    # then averaged across batches, double-normalizing the result.  The mean
    # of squared batch gradients is the standard empirical estimate.
    return np.mean(batch_fishers, axis=0)

# Quadratic EWC penalty anchoring current weights to the old-task weights.
def ewc_loss(fisher_information, weight, weight_old, lambda_):
    """Compute the EWC regularization term.

        loss = (lambda_ / 2) * sum_i F_i * (w_i - w*_i)^2

    Args:
        fisher_information: per-parameter Fisher diagonal (numpy array or
            tensor), in the same flattened order as ``weight``.
        weight: current flattened parameter vector (tracks gradients).
        weight_old: flattened parameter vector anchored from the old task.
        lambda_: regularization strength.

    Returns:
        Scalar tensor, differentiable with respect to ``weight``.
    """
    # BUG FIX: build the Fisher tensor on the same device/dtype as the
    # weights; the original `torch.tensor(...)` always produced a CPU
    # float64 tensor and failed with a device mismatch on CUDA.
    fisher = torch.as_tensor(fisher_information, dtype=weight.dtype, device=weight.device)
    return lambda_ / 2 * torch.sum(fisher * (weight - weight_old) ** 2)

# Initialize the model, optimizer, and loss; pick GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleLSTM(input_size=1, hidden_size=64, output_size=1).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Generate the initial (old-task) time series for the first training phase.
X_initial, y_initial = generate_time_series_data(100, 10)
initial_dataset = torch.utils.data.TensorDataset(X_initial, y_initial)
initial_dataloader = torch.utils.data.DataLoader(initial_dataset, batch_size=32, shuffle=True)

# Initial training on the old task (plain MSE, no regularization).
for epoch in range(50):
    for inputs, labels in initial_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Predict on the training inputs to visualize the fit after phase one.
# NOTE(review): the model is left in eval() mode from here on; harmless for
# this architecture (no dropout/batch-norm), but model.train() should be
# called before any further training — confirm.
model.eval()
with torch.no_grad():
    initial_predictions = model(X_initial.to(device)).cpu().numpy()

# Plot actual targets vs. predictions after the initial training phase.
plt.plot(y_initial.numpy(), label='Actual data')
plt.plot(initial_predictions, label='Initial predictions', color='red')
plt.title('Initial Training Results')
plt.xlabel('Time step')
plt.ylabel('y')
plt.legend()
plt.show()

# Compute and store the Fisher information on old-task data; it weights the
# per-parameter EWC penalty during incremental training.
fisher_info = calculate_fisher(model, initial_dataloader, device)

# Simulate the incremental-learning scenario: fresh data for a new task.
X_new, y_new = generate_time_series_data(50, 10)
new_dataset = torch.utils.data.TensorDataset(X_new, y_new)
new_dataloader = torch.utils.data.DataLoader(new_dataset, batch_size=32, shuffle=True)

# Incremental learning on the new task with the EWC penalty.
# BUG FIX: EWC must anchor the weights to the parameters learned on the OLD
# task.  The original code refreshed `prev_weights` at the end of every
# epoch, so the penalty only pulled toward the previous epoch's weights and
# never protected old-task knowledge; it also skipped the penalty for all of
# epoch 0.  Capture the anchor once, before any training on the new data.
anchor_weights = torch.cat([param.view(-1) for param in model.parameters()]).detach().clone()

for epoch in range(20):
    for inputs, labels in new_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Task loss on the new data.
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # EWC penalty: keep parameters that were important for the old task
        # close to their anchored values (applied from the first batch on).
        current_weights = torch.cat([param.view(-1) for param in model.parameters()])
        loss = loss + ewc_loss(fisher_info, current_weights, anchor_weights, 0.1)

        # Backward pass and optimization.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate on the ORIGINAL task's data after incremental learning, to see
# how much old-task performance was retained (the whole point of EWC).
model.eval()
with torch.no_grad():
    new_predictions = model(X_initial.to(device)).cpu().numpy()

# Plot old-task targets vs. post-incremental-learning predictions.
plt.plot(y_initial.numpy(), label='Actual data')
plt.plot(new_predictions, label='Incremental learning predictions', color='green')
plt.title('Incremental Learning Results')
plt.xlabel('Time step')
plt.ylabel('y')
plt.legend()
plt.show()

# (Removed: CSDN page footer — like/favorite counters, survey widget, and
# red-packet payment boilerplate captured during scraping; not part of the
# article's code or content.)