I. Introduction
For the principles behind Per-FedAvg, see: arXiv | Per-FedAvg: A Federated Meta-Learning Method.
II. Dataset
In federated learning there are multiple clients, each holding its own dataset that it is unwilling to share.
The dataset here consists of wind power measurements from ten districts of a city. We assume the power authorities of these 10 districts are unwilling to share their data with each other, yet they still want a global model trained on all of the data combined.
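As a rough sketch of that setup, each client only ever reads its own region's series; the file layout and helper below are hypothetical illustrations, not code from the original repo:

import pandas as pd

# hypothetical layout: one CSV of wind power per district, kept on that district's side
def load_local_data(region_id):
    # each power authority loads only its own file; nothing is pooled centrally
    return pd.read_csv(f'data/wind_zone_{region_id}.csv')

local_datasets = {i: load_local_data(i) for i in range(10)}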
III. Per-FedAvg
The algorithm's pseudocode is given in the Per-FedAvg paper referenced above.
1. Server
The server side is identical to that of FedAvg, so it is not covered in detail here; see my earlier posts.
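For completeness, here is a minimal sketch of that server-side step, assuming the server holds a list of returned client models and weights each one by its local data size (the .len attribute set in the client code below). This is the standard FedAvg weighted average, not code from the original repo:

import copy

def aggregate(server_model, client_models):
    # FedAvg: parameter-wise weighted average, weights proportional to local data size
    total = sum(m.len for m in client_models)
    global_state = copy.deepcopy(server_model.state_dict())
    for key in global_state.keys():
        global_state[key] = sum((m.len / total) * m.state_dict()[key] for m in client_models)
    server_model.load_state_dict(global_state)
    return server_model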
2. Client
For each client $i$, we define its meta function $F_i(w)$:

$$F_i(w) = f_i\bigl(w - \alpha \nabla f_i(w)\bigr),$$

where $f_i$ is client $i$'s local loss and $\alpha$ is the step size of the one-step inner update.
To update $F_i(w)$ during local training, we need to compute its gradient:

$$\nabla F_i(w) = \bigl(I - \alpha \nabla^2 f_i(w)\bigr)\, \nabla f_i\bigl(w - \alpha \nabla f_i(w)\bigr).$$
The code implementation is as follows:
import copy

import numpy as np
import torch
from torch import nn

# nn_seq_wind is the repo's data utility; it returns train/test loaders for one client

def train(args, model):
    model.train()
    Dtr, Dte = nn_seq_wind(model.name, args.B)
    model.len = len(Dtr)
    print('training...')
    data = [x for x in iter(Dtr)]
    for epoch in range(args.E):
        origin_model = copy.deepcopy(model)  # w: parameters before the inner update
        final_model = copy.deepcopy(model)   # receives the meta update below
        # step 1: inner update, w' = w - alpha * grad(f(w))
        model = one_step(args, data, model, lr=args.alpha)
        # step 2: grad(f(w')) on an independently sampled batch
        model = get_grad(args, data, model)
        # step 3: diagonal of the Hessian of f, evaluated at the original w
        hessian_params = get_hessian(args, data, origin_model)
        # step 4: meta update, w = w - beta * (I - alpha * hess) * grad(f(w'))
        cnt = 0
        for param, param_grad in zip(final_model.parameters(), model.parameters()):
            hess = hessian_params[cnt]
            cnt += 1
            I = torch.ones_like(param.data)
            # element-wise product, so only the Hessian's diagonal is used
            grad = (I - args.alpha * hess) * param_grad.grad.data
            param.data = param.data - args.beta * grad
        model = copy.deepcopy(final_model)
    return model
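Note that one_step, get_grad, and get_hessian each draw their own random batch, so steps 1, 2, and 3 are in general evaluated on three different samples, mirroring the three independent batches used per client in the Per-FedAvg paper. Also, because (I - args.alpha * hess) * param_grad.grad.data is an element-wise product, only the diagonal of the Hessian computed in step 3 actually enters the update.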
def one_step(args, data, model, lr):
    # sample a random batch for the inner update
    ind = np.random.randint(0, high=len(data), size=None, dtype=int)
    seq, label = data[ind]
    seq = seq.to(args.device)
    label = label.to(args.device)
    y_pred = model(seq)
    # note: Adam is used here, so the inner step is adaptive rather than
    # the plain SGD step w' = w - alpha * grad(f(w)) from the paper
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_function = nn.MSELoss().to(args.device)
    loss = loss_function(y_pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return model
def get_grad(args, data, model):
    # sample a random batch and compute grad(f(w')) at the updated parameters
    ind = np.random.randint(0, high=len(data), size=None, dtype=int)
    seq, label = data[ind]
    seq = seq.to(args.device)
    label = label.to(args.device)
    y_pred = model(seq)
    loss_function = nn.MSELoss().to(args.device)
    loss = loss_function(y_pred, label)
    model.zero_grad()  # clear gradients left over from one_step before backward
    loss.backward()
    return model
def get_hessian(args, data, model):
    # sample a random batch and compute the diagonal of the Hessian of f at w
    ind = np.random.randint(0, high=len(data), size=None, dtype=int)
    seq, label = data[ind]
    seq = seq.to(args.device)
    label = label.to(args.device)
    y_pred = model(seq)
    loss_function = nn.MSELoss().to(args.device)
    loss = loss_function(y_pred, label)
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True, create_graph=True)
    hessian_params = []
    for k in range(len(grads)):
        hess_params = torch.zeros_like(grads[k])
        for i in range(grads[k].size(0)):
            # 2-D parameters are weight matrices, 1-D parameters are biases
            if len(grads[k].size()) == 2:
                for j in range(grads[k].size(1)):
                    hess_params[i, j] = torch.autograd.grad(grads[k][i][j], model.parameters(), retain_graph=True)[k][i, j]
            else:
                hess_params[i] = torch.autograd.grad(grads[k][i], model.parameters(), retain_graph=True)[k][i]
        hessian_params.append(hess_params)
    return hessian_params
This involves computing a Hessian matrix; see my other post: Computing the Hessian of a Loss Function with Respect to Model Parameters in PyTorch.
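As a quick sanity check of the diagonal extraction above, the following self-contained sketch (not part of the original code) compares the same grad-of-grad loop against torch.autograd.functional.hessian on a tiny least-squares problem:

import torch

torch.manual_seed(0)
x = torch.randn(8, 3)
y = torch.randn(8)
w = torch.randn(3, requires_grad=True)

def loss_fn(w):
    return ((x @ w - y) ** 2).mean()

# full 3x3 Hessian from PyTorch's functional API
H = torch.autograd.functional.hessian(loss_fn, w)

# diagonal entries via the same grad-of-grad loop used in get_hessian
(g,) = torch.autograd.grad(loss_fn(w), w, create_graph=True)
diag = torch.stack([torch.autograd.grad(g[i], w, retain_graph=True)[0][i] for i in range(3)])

print(torch.allclose(diag, torch.diagonal(H)))  # expected: True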
3. Local adaptation
After obtaining the initial global model, each client performs one epoch of local fine-tuning:
def local_adaptation(args, model):
    model.train()
    Dtr, Dte = nn_seq_wind(model.name, 50)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.alpha)
    loss_function = nn.MSELoss().to(args.device)
    loss = 0
    for epoch in range(1):  # a single epoch of fine-tuning
        for seq, label in Dtr:
            seq, label = seq.to(args.device), label.to(args.device)
            y_pred = model(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # print('local_adaptation loss', loss.item())
    return model
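In use, each client would copy the aggregated global model and fine-tune it once before evaluating; a minimal sketch, where global_model and the evaluation step are assumed rather than taken from the code above:

import copy

# personalize the shared global model for one client without modifying it
personalized = local_adaptation(args, copy.deepcopy(global_model))
# evaluate 'personalized' on this client's own test set (evaluation code omitted)

The deep copy keeps the global model intact, so the same aggregated weights can be personalized independently for each of the ten clients.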
IV. Source Code
The full source code will be released later.