Communication-Efficient Learning of Deep Networksfrom Decentralized Data 代码全注释

文献:https://arxiv.org/abs/1602.05629

代码来源:https://github.com/shaoxiongji/federated-learning

参考文章:FedAvg代码详解-CSDN博客

目录

1、Utils

options.py

sampling.py

2、models

Nets.py

Update.py

Fed.py

test.py

main_fed.py

main_nn.py


1、Utils

options.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse  # 用于命令行选项,参数和子命令的解释

def args_parser():
    parser = argparse.ArgumentParser()
    # --epochs是参数名称
    # type是参数类型,从命令行输入的参数默认是字符串类型
    # default若参数不输入,默认使用该值

    # 这里进行了三类参数的设置,分别是联邦参数,模型参数,其他参数。
    # federated arguments
    # 联邦参数:
    # epochs:训练轮数;
    # num_users:用户数量;
    # frac:用户选取比例;
    # local_ep:本地训练轮数
    # local_bs:本地训练批大小
    # bs:测试批大小
    # lr:学习率
    # momentum:SGD动量
    # split:测试集划分类型,是用户还是样本
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # model arguments
    # model 模型名称
    # kernel_num 卷积核数量
    # kernel_size 卷积核大小
    # norm 归一化方式
    # num_filters 过滤器数量
    # max_pool 最大池化
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    # dataset 数据集选择
    # iid 独立同分布默认
    # num_classes 分类数量
    # num_channels 图像通道数
    # gpu 默认使用
    # stopping_rounds 停止轮数
    # verbose 日志
    # seed 随机数种子
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    # 命令行中出现了--iid选项,则该选项的值被设置为True,否则为False。
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    # 命令行中出现了--verbose选项,则该选项的值被设置为True,即显示日志,否则为False。
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    # 出现--all_clients 则表示全部客户端都参与
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    # 进行参数解析,可以使用args.epochs调用该值
    # 使用命令行运行 如 python test.py --epochs 100
    return args

sampling.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import numpy as np
from torchvision import datasets, transforms

def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset: 数据集
    :param num_users: 用户数量
    :return: dict of image index 返回一个字典,键为用户编号,值为分配给用户的样本索引集合
    使得每个用户获得相同数量的随机样本
    """
    num_items = int(len(dataset)/num_users) #计算每个用户获得的样本数量
    dict_users, all_idxs = {}, [i for i in range(len(dataset))] #创建一个空字典和列表,从0到样本数减1的整数依次放入列表中。
    for i in range(num_users):
        # np.random.choice函数从all_idxs列表中随机选择num_items个元素,且不允许重复选择(replace = False),将结果赋给第i个键对应的键值
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i]) #从全部的索引中删除已经分配的索引
    return dict_users


def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    # num_shards 表示分片数量,num_imgs 表示每个分片中的图像数量,一共有60000个训练图片
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)] # 将0-199索引存入idx_shard
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} #创建字典,包含num_users个键,每个键对应一个空的int64类型的Numpy数组
    idxs = np.arange(num_shards*num_imgs) # 是一个包含所有样本索引的一维数组,范围从0到num_shards*num_imgs-1。
    labels = dataset.train_labels.numpy() # 提取数据集中的训练标签,并将其转换为NumPy数组

    # sort labels
    idxs_labels = np.vstack((idxs, labels)) # 将索引和标签进行堆叠形成一个二维数组,一行索引,一行标签
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] # 根据标签行的值对样本索引和标签进行重新排序,从小到大
    idxs = idxs_labels[0,:] # 将排序好的索引赋给idxs

    # divide and assign
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False)) #从200个索引中随机选择两个放在rand_set中
        idx_shard = list(set(idx_shard) - rand_set) # 去除已经选择的索引
        for rand in rand_set:
            # 将随机的两片rand索引与字典中的值进行连接,并赋值给字典
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users


def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


if __name__ == '__main__': # 表示如果变量等于'__main__'则表示直接运行当前脚本
    # 创建 MNIST 数据集的训练集实例,并对图像数据进行预处理,transforms.Compose 是一种组合多个图像预处理操作的方法
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),# 张量转换
                                       transforms.Normalize((0.1307,), (0.3081,)) # 图像归一化处理
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)

2、models

Nets.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out): # 输入维度,隐藏层维度,输出维度
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden) # 线性层,将输入维度映射到输出维度
        self.relu = nn.ReLU() # relu激活函数,进行非线性变换
        self.dropout = nn.Dropout() # 随机丢弃神经元,防止过拟合
        self.layer_hidden = nn.Linear(dim_hidden, dim_out) # 线性层,调整输出维度

    def forward(self, x):
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1]) # 将张量x从任意形状的多维张量展平为一个二维张量,第一维度是根据-1自动推断的。
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x


class CNNMnist(nn.Module):
    def __init__(self, args): # 接收一个参数
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5) # 两次卷积操作
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d() # 二维dropout层,用于随机丢弃部分特征图
        self.fc1 = nn.Linear(320, 50)  # 全连接层,320维到50维
        self.fc2 = nn.Linear(50, args.num_classes) # 50维映射到预测的类别数量

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))  # 经过卷积,池化,relu激活函数
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Update.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset): # 构建数据集
    def __init__(self, dataset, idxs): # 接收数据集以及索引
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):  # 返回构建的数据集大小
        return len(self.idxs)

    def __getitem__(self, item): # 返回索引为self.idxs[item]处的图像和标签数据
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object): # 本地更新模型构建模块
    def __init__(self, args, dataset=None, idxs=None):# args 是一些训练参数 dataset 是整个数据集 idxs 是当前客户端用于训练的样本索引列表。
        self.args = args
        self.loss_func = nn.CrossEntropyLoss() # 交叉熵损失函数创建
        self.selected_clients = [] # 空的客户端选择列表
        # DatasetSplit(dataset, idxs) 创建一个只包含idxs索引对应的子集数据的数据集对象 用DataLoader进行数据加载,设置批大小,以及每轮训练数据打乱
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net): # 本地模型训练,接收一个网络模型
        net.train() # 设置为训练模式
        # train and update
        # 设置一个梯度下降优化器,用lr和momentum进行优化,即学习率和动量
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        epoch_loss = [] # 存储每个迭代周期的损失值
        for iter in range(self.args.local_ep): # 迭代本地epoch训练
            batch_loss = [] # 存储每个批次的损失值
            # 通过 enumerate(self.ldr_train) 遍历数据加载器,获取每个批次的图像数据 images 和标签 labels。
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                # 将数据加载到指定的设备上,通常是将数据移动到GPU上进行加速计算,设备由device决定
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad() # 将网络的梯度清零
                log_probs = net(images) # 通过向前传播计算网络模型对图像的预测值
                loss = self.loss_func(log_probs, labels) # 利用损失函数计算损失值
                loss.backward() # 反向传播计算梯度
                optimizer.step() # 优化器更新网络参数
                if self.args.verbose and batch_idx % 10 == 0: # 是否打印日志,控制打印频率
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item()) # 将损失值 loss.item() 添加到 batch_loss 列表中
            # 在每次迭代结束后,计算当前迭代周期中的平均损失值 sum(batch_loss)/len(batch_loss),并将其添加到 epoch_loss 列表中
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss) # 返回网络的状态以及本地epoch_loss平均值

Fed.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
from torch import nn


def FedAvg(w): # w为多个模型参数
    # 生成了一个新的独立的对象w_avg,使得w_avg与w[0]的值相同但是不共享内存。这种操作常用于避免在后续操作中对w_avg的修改影响到原始的w[0]
    # w[0]是其中一组模型参数的索引
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys(): # 遍历所有的键
        for i in range(1, len(w)): # 遍历所有的模型参数
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w)) # 求均值
    return w_avg

test.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


def test_img(net_g, datatest, args): # 用于评估训练好的模型在测试集上的性能 三个参数,模型,数据集,其它参数
    net_g.eval() # 将模型设置为评估模式,这意味着在推理阶段不会进行梯度计算。
    # testing
    test_loss = 0 # 初始化损失
    correct = 0 # 正确分类的数量
    data_loader = DataLoader(datatest, batch_size=args.bs) # 数据载入
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader): # 在每次循环中,从数据加载器中获取一批测试样本和对应的标签
        if args.gpu != -1: # 用GPU计算
            data, target = data.cuda(), target.cuda()
        log_probs = net_g(data) # data传入模型,得到预测输出
        # sum up batch loss 计算当前批次的损失
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability 通过这段代码,可以得到log_probs张量中每一行的最大值对应的索引,即预测结果y_pred
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        # 比较y_pred和target是否相等,返回一个布尔张量,long转换成长征型,移到CPU上,累加
        # 统计预测值y_pred和目标值target之间匹配正确的数量,并将这个数量累加到变量correct中
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset) # 计算平均损失值
    accuracy = 100.00 * correct / len(data_loader.dataset) # 准确率
    if args.verbose: # 打印
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss

main_fed.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img


if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)
    net_glob.train() # 模型设置为训练模式

    # copy weights 复制模型权重
    w_glob = net_glob.state_dict()

    # training
    loss_train = [] # 训练过程损失
    cv_loss, cv_acc = [], [] # 验证集损失,验证集准确率
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], [] # 验证准确率列表,模型列表

    if args.all_clients: # 对所有客户端全局聚合
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)] # 将全局模型权重复制num_users次放在w_locals中
    for iter in range(args.epochs): # 每个epoch循环中
        loss_locals = []   # 保存每个客户端的损失值
        if not args.all_clients: # 如过不是所有客户端,则创建空列表
            w_locals = []
        m = max(int(args.frac * args.num_users), 1) # 根据参数 args.frac(比例) 和 args.num_users(总用户数量) 计算每轮要选择的用户数量m
        idxs_users = np.random.choice(range(args.num_users), m, replace=False) # 随机选择m个不重复的用户索引,生成一个列表 idxs_users
        for idx in idxs_users:# 对于每个用户
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx]) # 进行本地训练
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = FedAvg(w_locals) # 联邦聚合

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob) #将更新后的全局模型权重 w_glob 加载到 net_glob 中

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals) # 平均损失
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

    # testing
    net_glob.eval()
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))

main_nn.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms

from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar


def test(net_g, data_loader):# 用于模型在测试集上进行评估的函数
    # testing
    net_g.eval() # 模型设置为评估模式
    test_loss = 0   # 损失初始化
    correct = 0 # 正确数量初始化
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):  # 按批次索引从数据加载器中取数据以及对应的标签
        data, target = data.to(args.device), target.to(args.device) # 将数据和标签转移到device
        log_probs = net_g(data) # 将数据传入网络,得到预测的对数概率
        test_loss += F.cross_entropy(log_probs, target).item() # 计算交叉熵损失并累加
        y_pred = log_probs.data.max(1, keepdim=True)[1] #通过这段代码,可以得到log_probs张量中每一行的最大值对应的索引,即预测结果y_pred
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()  # 正确的类别累加

    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))

    return correct, test_loss


if __name__ == '__main__':
    # parse args
    args = args_parser() # 解析命令行参数,并将返回的参数存储在 args 变量中。
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')# 设备选择

    torch.manual_seed(args.seed) # 设置随机数种子,使用相同的种子将导致随机数生成器生成相同的随机数序列

    # load dataset and split users
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        img_size = dataset_train[0][0].shape # 将数据的第一个样本的形状存储在img_size中
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model 根据输入的情况构建不同的模型
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # training
    # 定义优化器 optimizer,使用随机梯度下降(SGD)算法,将模型 net_glob 的参数传递给优化器
    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

    list_loss = [] # 存储每个epoch的损失
    net_glob.train() # 设置为训练模式
    for epoch in range(args.epochs):
        batch_loss = [] # 批次损失
        for batch_idx, (data, target) in enumerate(train_loader): # 训练数据集中安批次索引取数据和标签
            data, target = data.to(args.device), target.to(args.device) # 转移到device
            optimizer.zero_grad() # 优化器梯度置零
            output = net_glob(data) # 输出
            loss = F.cross_entropy(output, target) # 交叉熵损失
            loss.backward() # 反向传播计算梯度
            optimizer.step() # 根据计算的梯度更新模型的参数
            if batch_idx % 50 == 0:# 设置打印频率
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))
            batch_loss.append(loss.item())
        loss_avg = sum(batch_loss)/len(batch_loss) # 批损失均值
        print('\nTrain loss:', loss_avg)
        list_loss.append(loss_avg)

    # plot loss
    # 绘制训练损失随着 epoch 变化的折线图,并保存为图片
    plt.figure()
    plt.plot(range(len(list_loss)), list_loss)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))

    # testing
    if args.dataset == 'mnist':
        dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')

    print('test on', len(dataset_test), 'samples')
    test_acc, test_loss = test(net_glob, test_loader)

  • 8
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值