import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Generate a random univariate time-series regression dataset.
def generate_time_series_data(num_samples, sequence_length):
    """Create random sequences with noisy sum-of-sequence targets.

    Args:
        num_samples: Number of sequences to draw.
        sequence_length: Number of time steps per sequence.

    Returns:
        A tuple ``(X, y)``. ``X`` has shape ``(num_samples, sequence_length, 1)``
        and holds standard-normal values. ``y`` has shape ``(num_samples, 1)``;
        each entry is the sum of its sequence plus Gaussian noise scaled by 0.1.
    """
    # One feature per time step (univariate series).
    sequences = torch.randn(num_samples, sequence_length, 1)
    noise = torch.randn(num_samples)
    targets = sequences.sum(dim=(1, 2)) + 0.1 * noise
    return sequences, targets.view(-1, 1)
# A minimal LSTM regressor: encode the sequence, then map the final hidden
# state through a linear head.
class SimpleLSTM(nn.Module):
    """Single-layer LSTM followed by a linear output layer.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of outputs produced per sequence.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``fc(h_T)``, where ``h_T`` is the last layer's final hidden state."""
        _outputs, (hidden, _cell) = self.lstm(x)
        last_hidden = hidden[-1]
        return self.fc(last_hidden)
# Estimate the diagonal (empirical) Fisher information of the model parameters.
def calculate_fisher(model, dataloader, device):
    """Estimate the diagonal empirical Fisher information over ``dataloader``.

    For each batch, the MSE loss is back-propagated and the squared gradient
    of every parameter is recorded. The per-batch squares are then averaged,
    weighted by batch size, so every sample contributes equally even when the
    last batch is smaller.

    Note: gradients are taken of the batch-mean loss. This is the common
    batch-level approximation of the per-sample empirical Fisher.

    Args:
        model: The ``nn.Module`` whose parameters are analysed.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to (should match ``model``).

    Returns:
        A 1-D NumPy array with one Fisher value per scalar parameter. Values
        are concatenated in ``model.parameters()`` order.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    criterion = nn.MSELoss()
    params = list(model.parameters())
    was_training = model.training
    model.eval()
    fisher_sum = None
    total_samples = 0
    try:
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.shape[0]
            loss = criterion(model(inputs), labels)
            model.zero_grad()
            loss.backward()
            # Parameters that did not take part in the forward pass have no
            # gradient; treat them as carrying zero Fisher information.
            grads = [
                (p.grad if p.grad is not None else torch.zeros_like(p))
                .detach()
                .flatten()
                .cpu()
                .numpy()
                for p in params
            ]
            # Weight by batch size so the final result is a per-sample mean,
            # not an unweighted mean over (possibly uneven) batches.
            weighted = batch_size * np.square(np.concatenate(grads))
            fisher_sum = weighted if fisher_sum is None else fisher_sum + weighted
            total_samples += batch_size
    finally:
        # Don't leak the last batch's gradients or the eval() switch to callers.
        model.zero_grad()
        model.train(was_training)
    if total_samples == 0:
        raise ValueError("calculate_fisher: dataloader yielded no samples")
    return fisher_sum / total_samples
# Elastic Weight Consolidation (EWC) penalty.
def ewc_loss(fisher_information, weight, weight_old, lambda_):
    """Return the EWC quadratic penalty ``lambda_/2 * sum(F * (w - w_old)**2)``.

    Args:
        fisher_information: Per-parameter Fisher values (NumPy array or
            tensor), aligned element-wise with ``weight``.
        weight: Current parameters as a tensor; gradients flow through it.
        weight_old: Anchor parameters to stay close to, same shape as ``weight``.
        lambda_: Penalty strength.

    Returns:
        A scalar tensor on ``weight``'s device, in ``weight``'s dtype.
    """
    # as_tensor avoids the copy (and warning) that torch.tensor() makes for
    # tensor input. Placing the result on weight's device and dtype prevents a
    # CPU/CUDA mismatch on GPU, and stops float64 NumPy input from promoting
    # the loss to float64.
    fisher = torch.as_tensor(fisher_information, dtype=weight.dtype, device=weight.device)
    return lambda_ / 2 * torch.sum(fisher * (weight - weight_old) ** 2)
# Initialise the device, model, loss function and optimizer.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = SimpleLSTM(input_size=1, hidden_size=64, output_size=1).to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Generate the initial time-series data and wrap it in a shuffling loader.
X_initial, y_initial = generate_time_series_data(num_samples=100, sequence_length=10)
initial_dataset = torch.utils.data.TensorDataset(X_initial, y_initial)
initial_dataloader = torch.utils.data.DataLoader(
    initial_dataset,
    shuffle=True,
    batch_size=32,
)
# Initial training: plain MSE regression on the first dataset.
for epoch in range(50):
    for inputs, labels in initial_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Clear stale gradients before this step's backward pass.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# Evaluate the trained model on the initial data without tracking gradients.
model.eval()
with torch.no_grad():
    initial_predictions = model(X_initial.to(device))
    initial_predictions = initial_predictions.cpu().numpy()
# Plot targets against predictions for the initial training run.
# NOTE(review): each point is one sequence, so the x-axis is the sample index
# rather than a time step within a sequence, despite the axis label.
plt.plot(y_initial.numpy(), label='Actual data')
plt.plot(initial_predictions, color='red', label='Initial predictions')
plt.title('Initial Training Results')
plt.xlabel('Time step')
plt.ylabel('y')
plt.legend()
plt.show()
# Compute and store the Fisher information for the initial task.
fisher_info = calculate_fisher(model=model, dataloader=initial_dataloader, device=device)
# Simulate an incremental-learning scenario with freshly generated sequences.
# NOTE(review): the new data is drawn from the same generator as the initial
# data, so this is not a genuinely different task — confirm this is intended.
X_new, y_new = generate_time_series_data(num_samples=50, sequence_length=10)
new_dataset = torch.utils.data.TensorDataset(X_new, y_new)
new_dataloader = torch.utils.data.DataLoader(new_dataset, shuffle=True, batch_size=32)
# Incremental learning on the new data with an EWC penalty that keeps the
# parameters close to the values learned on the initial task.
ewc_lambda = 0.1  # EWC penalty strength
model.train()
# EWC anchors to the parameters learned on the previous task. Snapshot them
# once, before any update on the new data, and never refresh them during this
# loop: refreshing would only penalise drift since the last refresh, which
# does nothing to protect the initial task.
prev_weights = torch.cat([param.detach().reshape(-1) for param in model.parameters()]).clone()
# Move the Fisher diagonal to the parameters' device/dtype once, instead of
# rebuilding it on every step.
fisher_tensor = torch.as_tensor(fisher_info, dtype=prev_weights.dtype, device=prev_weights.device)
for epoch in range(20):
    for inputs, labels in new_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # The EWC penalty applies on every step, including the first epoch.
        # Concatenation order matches calculate_fisher (model.parameters()).
        current_weights = torch.cat([param.reshape(-1) for param in model.parameters()])
        ewc_loss_value = ewc_loss(fisher_tensor, current_weights, prev_weights, ewc_lambda)
        loss = loss + ewc_loss_value
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# After incremental learning, predict on the *initial* data (X_initial) to
# see how well the original task is still fitted.
model.eval()
with torch.no_grad():
    new_predictions = model(X_initial.to(device))
    new_predictions = new_predictions.cpu().numpy()
# Plot targets against post-incremental-learning predictions.
# NOTE(review): each point is one sequence, so the x-axis is the sample index
# rather than a time step within a sequence, despite the axis label.
plt.plot(y_initial.numpy(), label='Actual data')
plt.plot(new_predictions, color='green', label='Incremental learning predictions')
plt.title('Incremental Learning Results')
plt.xlabel('Time step')
plt.ylabel('y')
plt.legend()
plt.show()
12-11
562
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
03-09
02-09
1017
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
07-09
3758
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
06-10
“相关推荐”对你有帮助么?
-
非常没帮助
-
没帮助
-
一般
-
有帮助
-
非常有帮助
提交