上一篇博客实现并实验了FedAvg,模型用的是最简单的LogisticRegression,只包含一层线性层。
class LogisticRegression(torch.nn.Module):
    """Minimal logistic-regression model: one linear layer followed by an
    element-wise sigmoid.

    Note that the sigmoid is applied per output unit (it is not a softmax),
    so each output is an independent probability in (0, 1) rather than a
    distribution that sums to 1 across classes.
    """

    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        # Single fully-connected layer mapping input features to class scores.
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        scores = self.linear(x)
        return torch.sigmoid(scores)
最终的准确率最高只能达到80%左右,下图为每次选择100个用户、通信10轮的测试情况。
所以将模型换成了CNN进行测试。
class CNN_OriginalFedAvg(torch.nn.Module):
    """The CNN model used in the original FedAvg paper:
    "Communication-Efficient Learning of Deep Networks from Decentralized Data"
    https://arxiv.org/abs/1602.05629

    Architecture (for 28x28 single-channel input):
        conv 5x5, 32 ch, "same" padding -> max-pool 2x2   (28x28 -> 14x14)
        conv 5x5, 64 ch, "same" padding -> max-pool 2x2   (14x14 -> 7x7)
        flatten (64 * 7 * 7 = 3136) -> dense 512 (ReLU) -> dense 10/62 (softmax)

    With ``only_digits=True`` the parameter count is 1,663,370, matching the
    figure reported in the paper.

    Args:
        only_digits: If True, the final layer has 10 outputs for the
            digits-only MNIST dataset (http://yann.lecun.com/exdb/mnist/).
            If False, it has 62 outputs for Federated Extended MNIST
            (FEMNIST); see https://arxiv.org/abs/1702.05373.

    Returns:
        A `torch.nn.Module`.
    """

    def __init__(self, only_digits=True):
        super(CNN_OriginalFedAvg, self).__init__()
        self.only_digits = only_digits
        # Two 5x5 convolutions with padding=2 ("same"), each followed by
        # a shared 2x2 max-pooling module (pooling has no parameters, so
        # reusing one instance is safe).
        self.conv2d_1 = torch.nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = torch.nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.flatten = nn.Flatten()
        # 64 channels * 7 * 7 spatial positions after two poolings = 3136.
        self.linear_1 = nn.Linear(3136, 512)
        self.linear_2 = nn.Linear(512, 10 if only_digits else 62)
        self.relu = nn.ReLU()
        # NOTE(review): the forward pass returns softmax probabilities. If
        # the training criterion is CrossEntropyLoss (which applies
        # log-softmax internally), softmax would effectively be applied
        # twice — confirm the criterion expects probabilities.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # (N, 28, 28) -> (N, 1, 28, 28): add the channel dimension.
        x = torch.unsqueeze(x, 1)
        features = self.max_pooling(self.conv2d_1(x))
        features = self.max_pooling(self.conv2d_2(features))
        hidden = self.relu(self.linear_1(self.flatten(features)))
        return self.softmax(self.linear_2(hidden))
需要改动地方还有:fedml_api/standalone/fedavg/client.py
if(self.args.model=='cnn'):
x = x.view(len(x), 28, 28)
在模型训练和测试之前需要改变一下张量的维度。
def train(self, w_global):
    """Run local training starting from the global model weights.

    Args:
        w_global: state_dict of the global model to initialize from.

    Returns:
        A tuple ``(state_dict, avg_loss)`` where ``state_dict`` holds the
        locally updated weights moved to CPU (ready for server-side
        averaging) and ``avg_loss`` is the training loss averaged first
        over batches, then over local epochs (0.0 if no data was seen).
    """
    self.model.load_state_dict(w_global)
    self.model.to(self.device)

    # Build the local optimizer.
    if self.args.client_optimizer == "sgd":
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr)
    else:
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.args.lr, weight_decay=self.args.wd, amsgrad=True)

    epoch_loss = []
    for epoch in range(self.args.epochs):
        batch_loss = []
        for batch_idx, (x, labels) in enumerate(self.local_training_data):
            x, labels = x.to(self.device), labels.to(self.device)
            if self.args.model == 'cnn':
                # The CNN expects (N, 28, 28) and adds the channel
                # dimension itself in forward().
                x = x.view(len(x), 28, 28)
            self.model.zero_grad()
            log_probs = self.model(x)
            loss = self.criterion(log_probs, labels)
            loss.backward()
            # NOTE: gradient clipping (clip_grad_norm_ at 0.5) was
            # previously applied here to avoid NaN loss; re-enable it if
            # training diverges.
            optimizer.step()
            batch_loss.append(loss.item())
        # Guard against a client with an empty local dataset: previously
        # this divided by len(batch_loss) unconditionally, raising
        # ZeroDivisionError when no batches were processed.
        if batch_loss:
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

    avg_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
    return self.model.cpu().state_dict(), avg_loss
def local_test(self, model_global, b_use_test_dataset=False):
    """Evaluate a global model on this client's local data.

    Args:
        model_global: the model to evaluate; moved to ``self.device`` and
            switched to eval mode.
        b_use_test_dataset: evaluate on the local test split when True,
            otherwise on the local training split.

    Returns:
        A dict with keys 'test_correct', 'test_loss', 'test_precision',
        'test_recall' and 'test_total'. Precision/recall are only
        accumulated for the multi-label 'stackoverflow_lr' dataset;
        'test_loss' is the per-sample loss summed over all samples.
    """
    model_global.eval()
    model_global.to(self.device)

    metrics = dict.fromkeys(
        ('test_correct', 'test_loss', 'test_precision', 'test_recall', 'test_total'), 0)

    data_loader = self.local_test_data if b_use_test_dataset else self.local_training_data

    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(data_loader):
            x, target = x.to(self.device), target.to(self.device)
            if self.args.model == 'cnn':
                # The CNN adds its own channel dimension; feed (N, 28, 28).
                x = x.view(len(x), 28, 28)
            pred = model_global(x)
            loss = self.criterion(pred, target)

            if self.args.dataset == "stackoverflow_lr":
                # Multi-label task: threshold each output at 0.5.
                predicted = (pred > .5).int()
                # A sample counts as correct only if every label matches.
                correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum()
                true_positive = ((target * predicted) > .1).int().sum(axis=-1)
                precision = true_positive / (predicted.sum(axis=-1) + 1e-13)
                recall = true_positive / (target.sum(axis=-1) + 1e-13)
                metrics['test_precision'] += precision.sum().item()
                metrics['test_recall'] += recall.sum().item()
            else:
                # Single-label task: argmax over the class dimension.
                _, predicted = torch.max(pred, -1)
                correct = predicted.eq(target).sum()

            metrics['test_correct'] += correct.item()
            metrics['test_loss'] += loss.item() * target.size(0)
            metrics['test_total'] += target.size(0)

    return metrics
最终的结果准确率可以达到百分之九十。
参数设置:
通信轮数都是500次,灰色线是LR模型,红色是CNN模型。
sh run_fedavg_standalone_pytorch.sh 0 1000 100 -1 mnist ./../../../data/mnist lr hetero 500 1 0.03 sgd 0
sh run_fedavg_standalone_pytorch.sh 0 1000 100 -1 mnist ./../../../data/mnist cnn hetero 500 1 0.03 sgd 0
下面是每轮通信都由1000个用户共同参与的情况,需要把内存开到12G。之后测试集中式训练(即in_total和per_round都为1)时,即使把虚拟机内存开到最大,进程也会被kill掉,不清楚是不是内存不足的原因。
sh run_fedavg_standalone_pytorch.sh 0 1000 1000 -1 mnist ./../../../data/mnist cnn hetero 10 1 0.03 sgd 0