联邦学习FedAvg自编写代码

联邦学习中,联邦平均算法获得了很大的使用空间,因此常常被用于进行同步训练操作不多废话了,以下为Fedavg代码由于使用场景为NonIID场景,因此我使用了别人的一个MNIST数据集自定义的代码(见附录)FedAvg代码如下,功能具体看注释工作环境:python3.8.5 + pytorch(无cuda)# coding: utf-8# In[1]:import argparseimport torchimport osimport torch.nn as nnimport tor
摘要由CSDN通过智能技术生成

联邦学习中,联邦平均算法获得了很大的使用空间,因此常常被用于进行同步训练操作
不多废话了,以下为Fedavg代码
由于使用场景为NonIID场景,因此我使用了别人的一个MNIST数据集自定义的代码(见附录)
FedAvg代码如下,功能具体看注释
工作环境:python3.8.5 + pytorch(无cuda)
divergence模块可直接删除

# coding: utf-8

# In[1]:


import argparse
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
from PIL import Image
import torch
import copy
import pandas as pd
import random
import time
import sys
import re
import matplotlib.pyplot as plt
#import divergence

name = str(sys.argv[0])

# In[2]:

home_path = "./"
class MyDataset(torch.utils.data.Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
    def __init__(self,root,data,label,transform=None, target_transform=None): #初始化一些需要传入的参数
        super(MyDataset,self).__init__()
        imgs = []                      #创建一个名为img的空列表,一会儿用来装东西
        self.img_route = root
        for i in range(len(data)):
            imgs.append((data[i],int(label[i])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
 
    def __getitem__(self, index):    #这个方法是必须要有的,用于按照索引读取每个元素的具体内容
        fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
        route = self.img_route + str(label) + "/" + fn
        img = Image.open(route) #按照path读入图片from PIL import Image # 按照路径读取图片
        if self.transform is not None:
            img = self.transform(img) #是否进行transform
        return img,label  #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
 
    def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.imgs)


# In[3]:


filePath = home_path + 'data/MNIST/image_turn/'
train_data = []
train_label = []
for i in range(10):
    train_data.append(os.listdir(filePath+str(i)))
    train_label.append([i]*len(train_data[i]))
filePath = home_path + 'data/MNIST/image_test_turn/'
test_data = []
test_label = []
for i in range(10):
    test_data.append(os.listdir(filePath+str(i)))
    test_label.append([i]*len(test_data[i]))
test_ori = []
test_label_ori = []
for x in range(10):
    test_ori += test_data[x]
    test_label_ori += test_label[x]
test_data=MyDataset(home_path + "data/MNIST/image_test_turn/",test_ori,test_label_ori, transform=transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64)


# In[4]:


class
  • 2
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值