Continuing from the previous post, this one walks through the FedAvg code.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid   # the three data-partitioning helpers: IID/non-IID for MNIST, IID for CIFAR
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img
The imports are much the same as in centralized training; the functions from the custom packages are explained where they are used.
Parsing the command-line arguments:
args = args_parser()   # read the arguments defined in options.py
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')   # use the GPU if one is available and args.gpu != -1, otherwise the CPU
This is the same as in centralized training, so it is not repeated here.
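options.py itself is not shown in this post; a minimal sketch of what args_parser might contain is below. The default values are illustrative assumptions; only the argument names, which are used later in main, matter:

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated setup
    parser.add_argument('--epochs', type=int, default=10, help='number of communication rounds')
    parser.add_argument('--num_users', type=int, default=100, help='number of clients')
    parser.add_argument('--frac', type=float, default=0.1, help='fraction C of clients selected per round')
    parser.add_argument('--local_ep', type=int, default=5, help='local epochs per round')
    parser.add_argument('--local_bs', type=int, default=10, help='local batch size')
    parser.add_argument('--all_clients', action='store_true', help='aggregate over all clients')
    parser.add_argument('--sample_by_proportion', action='store_true', help='sample clients proportionally to data volume')
    # model / data
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--dataset', type=str, default='mnist', help='dataset name')
    parser.add_argument('--iid', action='store_true', help='use an IID split')
    parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
    # training
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum')
    parser.add_argument('--gpu', type=int, default=0, help='GPU id; -1 for CPU')
    parser.add_argument('--verbose', action='store_true', help='verbose printing')
    return parser.parse_args()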
Next, load the dataset and split it across the clients:
# load dataset and split users
if args.dataset == 'mnist':
    trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, transform=trans_mnist)    # training set
    dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, transform=trans_mnist)    # test set
    # sample users
    if args.iid:
        dict_users = mnist_iid(dataset_train, args.num_users)      # assign IID data to the users
    else:
        dict_users = mnist_noniid(dataset_train, args.num_users)   # otherwise assign non-IID data
elif args.dataset == 'cifar':
    trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_train = datasets.CIFAR10('./data/cifar', train=True, download=True, transform=trans_cifar)   # training set
    dataset_test = datasets.CIFAR10('./data/cifar', train=False, download=True, transform=trans_cifar)   # test set
    if args.iid:
        dict_users = cifar_iid(dataset_train, args.num_users)      # assign IID data to the users
    else:
        exit('Error: only consider IID setting in CIFAR10')        # no non-IID split is implemented for CIFAR-10
else:
    exit('Error: unrecognized dataset')
img_size = dataset_train[0][0].shape   # shape of a single image
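For reference, the image shapes the two datasets produce:

print(img_size)   # torch.Size([1, 28, 28]) for MNIST; torch.Size([3, 32, 32]) for CIFAR-10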
Three partitioning functions are used here: mnist_iid, mnist_noniid and cifar_iid. Their code follows.
mnist_iid splits the MNIST dataset into IID shares:
def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset: the dataset, e.g. dataset_train
    :param num_users: number of clients, args.num_users
    :return: dict of image indices, e.g. {0: {0, 9}, 1: {5, 7}, 2: {2, 4}, 3: {3, 6}, 4: {8, 1}} -- five clients, each holding two sample indices
    """
    num_items = int(len(dataset)/num_users)   # number of samples per client
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        # set() builds an unordered collection without duplicates; it supports membership tests and set algebra (intersection, difference, union)
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))   # draw num_items indices from all_idxs without replacement as client i's share
        all_idxs = list(set(all_idxs) - dict_users[i])   # remove the assigned indices, so every client receives an equally sized, disjoint IID share
    return dict_users
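A quick sanity check of the resulting partition (a minimal sketch; the 100 clients are just an example value):

dict_users = mnist_iid(dataset_train, 100)     # 100 example clients
print(len(dict_users[0]))                      # 600 = 60000 / 100 samples per client
print(len(set.union(*dict_users.values())))    # 60000: the shares are disjoint and cover the whole dataset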
The non-IID split instead cuts the label-sorted dataset into 200 shards of 300 images each and hands every client two shards:
def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset: the dataset, e.g. dataset_train
    :param num_users: number of clients, args.num_users
    :return: dict of image indices
    """
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}   # initialize {0: array([], dtype=int64), 1: array([], dtype=int64), ...}
    idxs = np.arange(num_shards*num_imgs)   # [0, 1, 2, ..., 59999]
    labels = dataset.train_labels.numpy()   # .numpy() gives the raw values; dataset.train_labels is a tensor (newer torchvision exposes it as dataset.targets)  # e.g. [5 0 ... 5 6 8], shape (60000,)
    # sort labels
    idxs_labels = np.vstack((idxs, labels))   # stack into shape (2, 60000): first row indices, second row labels
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]   # argsort returns the indices that sort the labels in ascending order
    # sort by label while keeping the index-label pairing, e.g.
    # [[0 1 2 3]      [[0 1 3 2]
    #  [0 1 9 5]]  ->  [0 1 5 9]]   second row now ascending
    idxs = idxs_labels[0, :]   # image indices in label-sorted order
    # divide and assign
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))   # pick two shard ids from [0, 1, ..., 199]
        idx_shard = list(set(idx_shard) - rand_set)                     # remove the chosen shards from the pool
        for rand in rand_set:
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)   # append the shard's 300 indices
    return dict_users
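Because the indices are sorted by label before sharding, each client ends up with very few digit classes, typically two (up to four when a shard straddles a label boundary). A quick check, assuming a recent torchvision where the labels live in dataset.targets:

dict_users = mnist_noniid(dataset_train, 100)
targets = np.array(dataset_train.targets)
for i in range(3):
    print(i, np.unique(targets[dict_users[i]]))   # e.g. [3 4] -- only a couple of distinct digits per client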
cifar_iid partitions the data in exactly the same way as mnist_iid:
def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset: the dataset, e.g. dataset_train
    :param num_users: number of clients, args.num_users
    :return: dict of image indices
    """
    num_items = int(len(dataset)/num_users)   # same logic as mnist_iid
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users
All three functions return dict_users, a dictionary whose keys are client ids and whose values are the indices of the training samples each client owns (a set for the IID splits, an int64 array for the non-IID split).
Next, build the model, much as in the centralized setting:
if args.model == 'cnn' and args.dataset == 'cifar':
    net_glob = CNNCifar(args=args).to(args.device)
elif args.model == 'cnn' and args.dataset == 'mnist':
    net_glob = CNNMnist(args=args).to(args.device)
elif args.model == 'mlp':
    len_in = 1
    for x in img_size:
        len_in *= x   # input dimension of the MLP = product of all image dims (flattened image)
    net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
else:
    exit('Error: unrecognized model')
print(net_glob)
net_glob.train()
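models/Nets.py is not reproduced in this post; a minimal sketch of an MLP compatible with the call above could look like this (the dropout and exact layer layout are assumptions, not necessarily the repo's code):

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)    # flattened image -> hidden layer
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)  # hidden layer -> class logits

    def forward(self, x):
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])   # flatten (B, C, H, W) -> (B, C*H*W)
        x = self.relu(self.dropout(self.layer_input(x)))
        return self.layer_hidden(x)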
Next comes the training loop, which starts by copying the global weights:
w_glob = net_glob.state_dict()
state_dict in torch.nn.Module holds the weights and biases learned during training; it is essentially a Python dictionary (an OrderedDict mapping parameter names to tensors).
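A quick way to see what w_glob holds:

for name, tensor in w_glob.items():
    print(name, tuple(tensor.shape))   # e.g. layer_input.weight (200, 784) for the MLP sketched above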
loss_train = []
cv_loss, cv_acc = [], []
val_loss_pre, counter = 0, 0
net_best = None
best_loss = None
val_acc_list, net_list = [], []
if args.all_clients:   # aggregate over all clients
    print("Aggregation over all clients")
    w_locals = [w_glob for i in range(args.num_users)]   # initialize every local w to the global w
for iter in range(args.epochs):
    loss_locals = []           # reset the local losses at the start of each round
    if not args.all_clients:   # if not aggregating over all clients
        w_locals = []          # the local weights no longer all match the global ones
    m = max(int(args.frac * args.num_users), 1)
    # in each round, a C-fraction (C in (0, 1)) of the workers is selected for training; m is the number selected
    if args.sample_by_proportion:
        # sample clients with probability proportional to their local data volume
        local_data_volumes = np.array([len(dict_users[i]) for i in range(args.num_users)])
        probabilities = local_data_volumes / local_data_volumes.sum()
        idxs_users = np.random.choice(range(args.num_users), m, replace=False, p=probabilities)
    else:
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    # choose m workers (m = num_users * C_fraction) out of all workers (0, 1, ..., num_users-1) without replacement
    for idx in idxs_users:   # for each of the m selected workers
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])   # local update on this worker
        # TODO: understand LocalUpdate(); what does it return? (7.24 15:27)
        # GOT IT: local training; returns net.state_dict() and the average loss
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))   # locally trained weights and loss
        if args.all_clients:
            w_locals[idx] = copy.deepcopy(w)
        else:
            w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))
    # update global weights
    w_glob = FedAvg(w_locals)   # aggregate the selected local weights; w_glob is the updated global model
    # TODO: understand FedAvg(); is the output the global model? (7.24 15:34)
    # GOT IT: federated averaging; the output is the updated global model
    # copy weight to net_glob
    net_glob.load_state_dict(w_glob)
    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
    loss_train.append(loss_avg)
The args.sample_by_proportion branch sets each client's selection probability in proportion to the amount of data it holds.
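For example, three clients holding 100, 300 and 600 samples would be picked with probabilities 0.1, 0.3 and 0.6:

volumes = np.array([100, 300, 600])
print(volumes / volumes.sum())   # [0.1 0.3 0.6]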
LocalUpdate is the class that runs a client's local training:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

class DatasetSplit(Dataset):
    """Wraps a dataset so that only the samples whose indices are in idxs are visible."""
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

# called in main as: local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)   # DataLoader over this client's share only

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        epoch_loss = []
        for iter in range(self.args.local_ep):   # local_ep epochs of local SGD
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
The training step itself mirrors centralized training; the only difference is that the DataLoader is restricted to this client's own sample indices via DatasetSplit.
FedAvg performs the server-side aggregation:
def FedAvg(w):   # main passes w_locals, i.e. the weights computed by the workers
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():          # for each parameter tensor in the state dict
        for i in range(1, len(w)):  # sum the corresponding local updates
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))   # divide by the number of participants
    return w_avg
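In formula form, this is a plain unweighted mean over the m participating clients:

w_{t+1} = \frac{1}{m} \sum_{k=1}^{m} w_{t+1}^{k}

Note that the FedAvg paper weights each client by its data share, w_{t+1} = \sum_{k} \frac{n_k}{n} w_{t+1}^{k}; since mnist_iid and mnist_noniid give every client the same number of samples, the two coincide here.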
Finally, the training loss is plotted and the model is evaluated, much as in centralized training.
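For completeness, a minimal sketch of that last part (the save path is a made-up example; test_img is assumed to return accuracy and loss, matching how it is imported above):

# plot the average training loss per round
plt.figure()
plt.plot(range(len(loss_train)), loss_train)
plt.xlabel('round')
plt.ylabel('train loss')
plt.savefig('./save/fed_loss.png')   # hypothetical output path

# evaluate the final global model on the train and test sets
net_glob.eval()
acc_train, loss_tr = test_img(net_glob, dataset_train, args)
acc_test, loss_te = test_img(net_glob, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))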