【代码解析(6)】Communication-Efficient Learning of Deep Networks from Decentralized Data

federated_main.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    # 取出参数
    args = args_parser()

    # 输出实验模型参数等细节
    exp_details(args)

    '''
    if args.gpu_id:
        torch.cuda.set_device(args.gpu_id)
    device = 'cuda' if args.gpu else 'cpu'
    '''
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    '''
        加载数据集和用户组
    '''
    train_dataset, test_dataset, user_groups = get_dataset(args)
    '''
        print('++++++++++++++++')
        print(user_groups)
        cifar_iid(dataset, num_users)情况下:
        
    '''
    '''
    for dict_key in user_groups.keys():
        print(dict_key)
        0~100
    一共100个用户
    num_users=100
    
    '''

    # BUILD MODEL构建模型
    if args.model == 'cnn':
        # Convolutional neural netork
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer preceptron
        img_size = train_dataset[0][0].shape
        '''
            cifar数据集下:
            print('+++++++++++++++++++')
            print(img_size)
            torch.Size([3, 32, 32])
            
            print(train_dataset[0][0].shape)
            torch.Size([3, 32, 32])
            
            print(train_dataset[0][0][0].shape)
            torch.Size([32, 32])
        '''

        len_in = 1
        # torch.Size([3, 32, 32])
        for x in img_size:
            len_in *= x
            global_model = MLP(dim_in=len_in, dim_hidden=64,
                               dim_out=args.num_classes)
        '''
            len_in = 3
            MLP(dim_in=3, dim_hidden=64, dim_out=10)
            
            len_in = 96
            MLP(dim_in=94, dim_hidden=64, dim_out=10)
            
            这个符合要求:
            len_in = 3072
            MLP(dim_in=3072, dim_hidden=64, dim_out=10)
            输出:
                torch.Size([10, 10])
        '''
    else:
        exit('Error: unrecognized model')

    '''
        上面已经构建好训练模型了
    '''

    # Set the model to train and send it to device.
    global_model.to(device)

    global_model.train()
    '''
        MLP模型里面有:dropout
        CNN模型里面有:Batch Normalization 和 Dropout
        model.train()的作用是启用 
        Batch Normalization 和 Dropout。
        如果模型中有BN(Batch Normalization)和Dropout,
        需要在训练时添加model.train()。model.train()
        是保证BN层能够用到每一批数据的均值和方差。
        对于Dropout,model.train()是随机取一部分网络连接
        来训练更新参数。
    '''
    # print('+++++++++++++++++++++++')
    print(global_model)
    '''
        对于MLP模型:
            MLP(
                  (layer_input): Linear(in_features=3072, out_features=64, bias=True)
                  (relu): ReLU()
                  (dropout): Dropout(p=0.5, inplace=False)
                  (layer_hidden): Linear(in_features=64, out_features=10, bias=True)
                  (softmax): Softmax(dim=1)
                )
        
    
    '''

    # copy weights 复制权值
    global_weights = global_model.state_dict()
    '''
        state_dict变量存放训练过程中需要学习的权重和偏执系数
        state_dict作为python的字典对象将每一层的参数映射成tensor张量
        需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数
        
        
        
        举例:
        self.conv1=nn.Conv2d(3,6,5)
        self.pool=nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)


        model.state_dict()[param_tensor].size()
        conv1.weight 	torch.Size([6, 3, 5, 5])
            63*5*5的卷积核
        conv1.bias 	    torch.Size([6])
            6个偏置
        conv2.weight 	torch.Size([16, 6, 5, 5])
            
        conv2.bias 	    torch.Size([16])
        
        fc1 = nn.Linear(16 * 5 * 5, 120) 权重却是torch.Size([120, 400])
        fc1.weight 	    torch.Size([120, 400])
            
        fc1.bias 	    torch.Size([120])
        
        fc2.weight 	    torch.Size([84, 120])
        fc2.bias 	    torch.Size([84])
        
        fc3.weight 	    torch.Size([10, 84])
        fc3.bias 	    torch.Size([10])

    '''

    # Training
    train_loss, train_accuracy = [], []
    '''
        训练损失,训练准确率
    '''
    val_acc_list, net_list = [], []
    '''
        
    '''
    cv_loss, cv_acc = [], []

    print_every = 2

    val_loss_pre, counter = 0, 0

    # 本地训练,epoch选为10
    '''
        一个epoch指所有的数据送入网络完成一次前向计算和反向传播的过程
        一个epoch中:
            数据分为几个batch
            batch_size为一个batch里面的数据量
    '''
    for epoch in tqdm(range(args.epochs)):
        '''
            args.epochs为options.py中的全局轮数
            
            epoch[0,1,2,3,4,5,6,7,8,9]
            
            tqdm模块tqdm函数 tqdm 是 Python 进度条库
              0%|          | 0/10 [00:00<?, ?it/s]Train Epoch: 1 [0/50000 (0%)]
             10%|| 1/10 [00:58<08:45, 58.42s/it]Train Epoch: 2 [0/50000 (0%)]
        '''
        local_weights, local_losses = [], []

        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        '''
            如果模型中有BN(Batch Normalization)和Dropout,
            需要在训练时添加model.train()。model.train()
            是保证BN层能够用到每一批数据的均值和方差。
            对于Dropout,model.train()是随机取一部分网络连接来训练更新参数
        '''

        # 随机选择10个用户
        m = max(int(args.frac * args.num_users), 1)
        '''
            frac = 0.1
            num_users = 100
            m = 10
        '''

        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        ''' 
            frac = 0.1
            num_users = 100
            dict_users[i] dict类型{'0':{1,3,4}}

            replace表示是否重用元素
            numpy.random.choice(a, size=None, replace=True, p=None)
            a : 如果是一维数组,就表示从这个一维数组中随机采样;如果是int型,
            就表示从0到a-1这个序列中随机采样
            从[0,1,2,3 ... len(dataset)]采样num_items个元素

            这很合理,dataset相当于矩阵,行为user,列为Item
            每个user为一行,列为item数量,所以对每个user采样num_item个元素
            
            idxs_users = np.random.choice(range(args.num_users), m, replace=False)100个用户下标
            随机选择10'''
        # print('++++++++++')
        # print(idxs_users)
        # [33 55 99 17  1 68 31 20 77 93]

        '''
            本地客户端训练
        '''
        for idx in idxs_users:
            '''
                对于每个用户
            '''
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            # print('+++++++++++++++++')
            # print(user_groups[idx])
            '''
                idx的作用是什么?
                返回用户组:
                    dict类型{key:value}
                        key:用户的索引
                        value:这些用户的相应数据
                    
                    user_group就是dict_users
                    dict_users[i]类似{'0':{1,3,4}}
                
                dict_users[i]存的都是下标啊??????!!!!   
                    
                user_groups = get_dataset(args)是utils.py传过来的
                utils.py中的get_dataset是接受的sampling.py传过来的:
                    user_groups = mnist_iid(train_dataset, args.num_users)
                sampling.py中的:
                    dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
                    dict_users[i] = 
                
                
                得到idxs=user_groups[idx]
                
            '''

            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            '''
                调用update.py中的update_weights函数
                for iter in range(self.args.local_ep)
                一个用户经过10个本地batch
                返回:
                return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
            '''

            # print('99999999999999')
            # print(len(w))  # 4
            # print(type(w))  # <class 'collections.OrderedDict'>
            # print(len(w[0]))

            local_weights.append(copy.deepcopy(w))
            '''
                一开始我以为local_weights是接收的w
                其实是local_weights添加的w到自己的list中
            '''

            # print(local_weights)
            '''
            local_weights是<class 'list'>类型:
            
            [OrderedDict([ 
                ('layer_input.weight', tensor([[ 0.0153, -0.0127, -0.0007,  ..., -0.0012,  0.0151,  0.0063],
                    [ 0.0041, -0.0144,  0.0049,  ...,  0.0157,  0.0007, -0.0153],
                    ...,
                    [-0.0050,  0.0112,  0.0106,  ..., -0.0072, -0.0117, -0.0039]])),
                
                ('layer_input.bias', tensor([ 0.0052, -0.0118, -0.0131,  0.0120, ...,  0.0047])),
                ('layer_hidden.weight', tensor([[ 0.0995,  0.1264, -0.0289, -0.0753, ...0.1195],
                                                ..., 
                                                [ 0.0995,  0.1264, -0.0289, -0.0753, ...0.1195]])
                )
                ('layer_hidden.bias', tensor([ 0.0422,  0.0937, -0.1109, -0.1184,  0.0178, -0.0370, -0.0875, -0.0417,
                                            0.1082, -0.0144])
                )
                ])
            ]
                
            '''

            local_losses.append(copy.deepcopy(loss))

        # print('000000000000000000')
        # print(local_weights[9])

        # print('+++++++++++++++++++++++++++')
        # print(local_weights[0]['layer_input.weight'])
        '''
            获得tensor的shape:
                test.shape
                print(local_weights[1]['layer_input.weight'].shape)
                torch.Size([64, 3072])
            3072=3*32*32
            
            (layer_input): Linear(in_features=3072, out_features=64, bias=True)
            fc1 = nn.Linear(16 * 5 * 5, 120) 权重却是torch.Size([120, 400])
            
            得出来的3072正好是上面的:
                args.model == 'mlp':
            
            local_weights[1]['layer_input.weight']:
                tensor([[ 0.0053, -0.0159,  0.0121,  ...,  0.0102,  0.0121,  0.0087],
                        [-0.0181,  0.0125,  0.0130,  ...,  0.0134,  0.0020,  0.0107],
                        ...,
                        [ 0.0023, -0.0054, -0.0015,  ...,  0.0022,  0.0147,  0.0071]])
        '''
        w_avg = copy.deepcopy(local_weights[0])
        # print(len(local_weights))  # 10
        for key in w_avg.keys():
            for i in range(1, len(local_weights)):
                # range(1, 10):1,2,3,4,5,6,7,8,9
                # 他这里加了9个tensor但是后面div(10)'''
                    知道了!!
                    w_avg = copy.deepcopy(local_weights[0])
                    已经将[0]用来初始化w_avg了
                '''

                w_avg[key] += local_weights[i][key]
                #
                '''
                    local_weights[1]['layer_input.weight']:
                        torch.Size([64, 3072])
                    
                    local_weights[1]['layer_input.weight']:
                tensor([[ 0.0053, -0.0159,  0.0121,  ...,  0.0102,  0.0121,  0.0087],
                        [-0.0181,  0.0125,  0.0130,  ...,  0.0134,  0.0020,  0.0107],
                        ...,
                        [ 0.0023, -0.0054, -0.0015,  ...,  0.0022,  0.0147,  0.0071]])
                '''
            # print('+==========================')
            # print(w_avg['layer_input.weight'])
            '''
                类似:
                tensor([[-0.0397, -0.0698,  0.0882,  ..., -0.1417,  0.1702, -0.1217],
                        ...
                        [-0.1882,  0.0659,  0.0673,  ...,  0.0293, -0.0446,  0.0299]])
            '''

            w_avg[key] = torch.div(w_avg[key], len(local_weights))

            '''源代码却是写错了不应该除以len(w)'''
            # w_avg[key] = torch.div(w_avg[key], len(w))
            # print(')))))))))))))))))))))))0')
            # print(w[0])
            # print(w)  # 和local_weights[i]一样
            # print(len(w))  # 4表示w有四个键值对
            # print(type(w))  # <class 'collections.OrderedDict'>
            '''
                w是<class 'list'>类型,
                w_avg[key]每个值除以len(w)=4
                len(w)除以4可就不对了呀
                应该除以len(local_weights)len(w)
            '''

        '''
            print(w_avg.keys())odict_keys(['layer_input.weight', 
                        'layer_input.bias', 
                        'layer_hidden.weight', 
                        'layer_hidden.bias']
        '''
        '''
        OrderedDict有序字典:
        
        print('wwwwwwwwwwwwwww')
        print(local_weights)
            [OrderedDict([('layer_input.weight', tensor([[-0.0026,  0.0088,  0.0002,  ...,  0.0076,  0.0010,  0.0020],
            [-0.0053,  0.0178,  0.0017,  ..., -0.0125,  0.0014, -0.0052],
            [-0.0076,  0.0077,  0.0046,  ...,  0.0105,  0.0071,  0.0129],
        print(local_weights[0]) 
            OrderedDict([('layer_input.weight', tensor([[-0.0068, -0.0142,  0.0133,  ..., -0.0147,  0.0048,  0.0156],
            [ 0.0089, -0.0090, -0.0009,  ..., -0.0061, -0.0013, -0.0005],
        '''
        # print('类型')
        # print(type(local_weights))  # <class 'list'>
        # print(type(local_weights[0]))  # <class 'collections.OrderedDict'>

        # update global weights
        '''
            更新全局权重
            传参为local_weights
            local_weights.append(copy.deepcopy(w))
            
            len(local_weights) = 10
        '''
        # print('传参local_weights')
        # print(len(local_weights))

        # 一个一个用户更新全局权重
        global_weights = average_weights(local_weights)
        '''
            上面一个for循环存在没有意义
            得到的w_avg没有用
            上面一个for循环在uitls.py中的average_weights()函数中
            global_weights = average_weights(local_weights)
            调用了这个函数
        '''

        # update global weights
        global_model.load_state_dict(global_weights)
        '''
            torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中
        '''

        # print('loooooooooooooooooo')
        # print(local_losses)
        loss_avg = sum(local_losses) / len(local_losses)
        '''
            local_losses为list类型
            [-0.1604325039871037, -0.14957306474447252, 
            -0.1479335972107947, -0.17096682694740595, 
            -0.15407370103523135, -0.15217236945405604, 
            -0.14514834607020022, -0.1494329896569252, 
            -0.1533350457623601, -0.1353217322193086]
        '''

        train_loss.append(loss_avg)
        '''
            [[], [], []....]
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            loss由这里来的:
            update.py中的update_weights函数:
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
            下面: loss = local_model.inference(model=global_model)
            由update.py中的inference函数:
                outputs = model(images)
                batch_loss = self.criterion(outputs, labels)
            结果一样的
            
        '''

        # Calculate avg training accuracy over all users at every epoch
        '''
            计算每一epoch所有用户平均训练精度和损失
        '''
        list_acc, list_loss = [], []

        global_model.eval()

        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        '''
            每个用户推断的精度和损失存到
            list_acc
            list_loss
            这里的list_loss没用
            用的是:
                train_loss.append(loss_avg)
        '''

        train_accuracy.append(sum(list_acc)/len(list_acc))
        # print('000000000000')
        # print(train_accuracy)  # [0.2199999999999999]
        # print(train_accuracy[-1])  # 0.2199999999999999

        # print global training loss after every 'i' rounds
        '''每一轮输出全局训练损失 print_every=2'''

        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')

            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            '''
                train_loss:[[], [], [],...]
                np.array(train_loss):
                    [[]
                     []
                     []
                     ]
                np.mean(np.array(train_loss)):
                    所有值之和除以总的数
            '''

            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # Test inference after completion of training
    '''
        完成所有轮训练的推断损失
    '''
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')

    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))

    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))


  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值