联邦学习non-iid用户数据采样python实现

联邦学习non-iid用户数据采样python实现


问题描述

联邦学习中,如何按 non-IID(非独立同分布)方式把数据划分给各个用户:每个用户只分到部分类别的样本。下文以 CWRU 轴承故障数据集为例。

解决办法

直接上代码

import os
from scipy.io import loadmat
import numpy as np
from collections import Counter
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from datasets.SequenceDatasets import dataset
from datasets.sequence_aug import *
from tqdm import tqdm
from options import args_parser


# Module-level CLI options, parsed once at import time. get_dataset() takes its
# own ``args`` parameter, which shadows this name inside that function.
args = args_parser()

def get_files(root, N):
    '''
    Load every CWRU ``.mat`` file of one operating condition, one class per file.

    root: directory that contains the CWRU sub-folders listed in ``datasetname``.
    N: task string; its first character selects the load condition, i.e. the
        key of ``dataname`` (``"0"`` -> the 1797 rpm files, ... ``"3"`` -> 1730 rpm).

    Returns ``([data, lab], num)``. For the k-th file of the condition,
    ``data[k]`` and ``lab[k]`` are one-element lists wrapping the segment list
    and the label list returned by ``data_load`` (label ``k``), and ``num[k]``
    is the number of segments read from that file. File 0 is read from the
    "Normal Baseline Data" folder, the others from
    "12k Drive End Bearing Fault Data".

    Raises ValueError if ``N[0]`` is not a digit and KeyError if it is not 0-3.
    '''

    dataname = {0: ["97.mat", "105.mat", "118.mat", "130.mat", "169.mat", "185.mat", "197.mat", "209.mat", "222.mat",
                    "234.mat"],  # 1797rpm
                1: ["98.mat", "106.mat", "119.mat", "131.mat", "170.mat", "186.mat", "198.mat", "210.mat", "223.mat",
                    "235.mat"],  # 1772rpm
                2: ["99.mat", "107.mat", "120.mat", "132.mat", "171.mat", "187.mat", "199.mat", "211.mat", "224.mat",
                    "236.mat"],  # 1750rpm
                3: ["100.mat", "108.mat", "121.mat", "133.mat", "172.mat", "188.mat", "200.mat", "212.mat", "225.mat",
                    "237.mat"]}  # 1730rpm

    datasetname = ["12k Drive End Bearing Fault Data", "12k Fan End Bearing Fault Data",
                   "48k Drive End Bearing Fault Data",
                   "Normal Baseline Data"]

    files = dataname[int(N[0])]
    # Size the per-class containers by the file list (one class per file) rather
    # than by the global ``args.num_classes``. With any other value, ``data[n]``
    # either overflowed (IndexError) or kept trailing empty slots that crashed
    # the caller.
    data = [[] for _ in files]
    lab = [[] for _ in files]
    num = []
    for n, fname in enumerate(tqdm(files)):
        # The first file of every condition is the normal baseline recording.
        folder = datasetname[3] if n == 0 else datasetname[0]
        path1 = os.path.join(root, folder, fname).replace("\\", "/")
        data1, lab1 = data_load(path1, fname, label=n)
        data[n].append(data1)
        lab[n].append(lab1)
        num.append(len(data1))

    return [data, lab], num


def data_load(filename, axisname, label, signal_size=1024):
    '''
    Load one CWRU ``.mat`` file and cut its ``_DE_time`` channel into segments.

    filename: path of the ``.mat`` file to read.
    axisname: bare file name such as ``"97.mat"``. Its numeric stem selects the
        MATLAB variable: ``X097_DE_time`` for ``97.mat``, ``X105_DE_time`` for
        ``105.mat`` (stems below 100 get an extra leading ``0``).
    label: class label attached to every segment produced from this file.
    signal_size: length of each non-overlapping segment (default 1024). A
        trailing remainder shorter than ``signal_size`` is dropped.

    Returns ``(data, lab)``: a list of consecutive ``signal_size``-row slices of
    the loaded array, and a list of the same length filled with ``label``.

    Raises ValueError if ``signal_size`` is not positive or the stem of
    ``axisname`` is not an integer; KeyError if the variable is missing.
    '''
    if signal_size <= 0:
        # A zero step would never advance through the signal.
        raise ValueError(f"signal_size must be positive, got {signal_size}")
    axis = ["_DE_time", "_FE_time", "_BA_time"]

    stem = axisname.split(".")[0]
    # int() instead of eval(): a file name must never be executed as code.
    prefix = "X0" if int(stem) < 100 else "X"
    realaxis = prefix + stem + axis[0]
    fl = loadmat(filename)[realaxis]

    count = fl.shape[0] // signal_size
    data = [fl[k * signal_size:(k + 1) * signal_size] for k in range(count)]
    lab = [label] * count
    return data, lab

def del_file(path_data):
    '''
    Recursively delete every regular file below ``path_data``.

    The directory tree itself is kept: only files are removed.

    path_data: directory to clear. A path that does not exist is ignored, so
        calling this before any output has been written does not fail.
    '''
    if not os.path.exists(path_data):
        return
    for entry in os.listdir(path_data):
        # os.path.join instead of a hard-coded "\\" separator. On POSIX the
        # backslash is an ordinary file-name character, so the joined path never
        # existed and the recursion crashed inside os.listdir.
        file_data = os.path.join(path_data, entry)
        if os.path.isfile(file_data):
            os.remove(file_data)
        else:
            del_file(file_data)

def _assign_user_labels(num_users, num_classes, ways):
    '''
    Draw ``ways`` distinct classes for every user.

    Redraws until each class is owned by at least one user.

    Returns ``(local_lb, owners)``: the per-user class lists and a Counter
    giving how many users own each class.

    Raises ValueError when ``num_users * ways < num_classes``: full coverage
    is then impossible and the redraw loop would never end.
    '''
    import random

    if num_users * ways < num_classes:
        raise ValueError(
            f"{num_users} users x {ways} classes each cannot cover "
            f"{num_classes} classes; increase num_users or degree_noniid")
    while True:
        # ``ways`` distinct integers from [0, num_classes) per user.
        local_lb = [random.sample(range(num_classes), ways) for _ in range(num_users)]
        owners = Counter(c for classes in local_lb for c in classes)
        if len(owners) == num_classes:
            return local_lb, owners


def _save_split(path, samples, labels):
    '''
    Save one user's split as ``{"samples": ..., "labels": ...}`` with torch.save.

    ``samples`` is stacked with np.stack (dtype preserved) and its last two axes
    are swapped with ``permute(0, 2, 1)``; for column-vector segments this maps
    ``(n, 1024, 1)`` to ``(n, 1, 1024)``. ``labels`` is stored as int64.
    '''
    # np.stack first: torch.tensor() on a list of ndarrays is very slow.
    stacked = torch.from_numpy(np.stack(samples)).permute(0, 2, 1)
    torch.save({"samples": stacked,
                "labels": torch.from_numpy(np.asarray(labels, dtype=np.int64))},
               path)


def get_dataset(args):
    '''
    Build a non-IID federated split of the CWRU data and save it per user.

    Each user is assigned ``args.ways = int(args.num_classes * args.degree_noniid)``
    distinct classes. Classes are redrawn until every class has an owner.

    Each class's segments are split 80/20 into train/test pools, in file order.
    Every owner of a class receives ``pool_size // owners`` consecutive segments
    of each pool, so no segment is given to two users.

    User ``i`` is written to
    ``<root>/data/<args.dataset>/<args.data_number>/user<i>/{train,test}.pt``.

    args: options namespace. Reads ``num_classes``, ``degree_noniid``,
        ``num_users``, ``data_number`` and ``dataset``; sets ``args.ways``.

    Returns ``(userdatapath, local_lb)``: the per-user output directories and
    the list of classes assigned to each user.

    Raises ValueError if the users cannot jointly cover every class.
    '''
    rootpath = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))).replace("\\", "/")
    data_dir = os.path.join(rootpath, "data_preprocessing", "cwru").replace("\\", "/")
    output_dir = os.path.join(rootpath, "data").replace("\\", "/")
    os.makedirs(output_dir, exist_ok=True)
    task = f'{args.data_number}'
    print('task:', task)

    # Clear files left by a previous run; on the first run the directory does
    # not exist yet and listing it would raise.
    delpath = os.path.join(rootpath, "data", args.dataset).replace("\\", "/")
    if os.path.isdir(delpath):
        del_file(delpath)

    list_data, num = get_files(data_dir, task)
    samples_by_class, labels_by_class = list_data

    args.ways = int(args.num_classes * args.degree_noniid)
    local_lb, owners = _assign_user_labels(args.num_users, args.num_classes, args.ways)

    # Share of each class's train/test pool given to every user that owns it.
    trclasslength = [int(num[c] * 0.8 // owners[c]) for c in range(args.num_classes)]
    teclasslength = [int(num[c] * 0.2 // owners[c]) for c in range(args.num_classes)]

    # 80/20 train/test split per class, in file order. Plain lists replace
    # DataFrames that were stored under dynamic names through locals(): since
    # PEP 667 (Python 3.13) every locals() call returns a fresh snapshot, so
    # those names were lost between calls and the lookup raised KeyError.
    tr_pool, te_pool = [], []
    for c in range(len(samples_by_class)):
        segments, labels = samples_by_class[c][0], labels_by_class[c][0]
        cut = int(num[c] * 0.8)
        tr_pool.append((segments[:cut], labels[:cut]))
        te_pool.append((segments[cut:], labels[cut:]))
    # Read position in each pool: users consume a class front to back.
    tr_next = [0] * len(tr_pool)
    te_next = [0] * len(te_pool)

    userdatapath = []
    for i, classes in enumerate(local_lb):
        xtrain, ytrain, xtest, ytest = [], [], [], []
        for c in classes:
            start, stop = tr_next[c], tr_next[c] + trclasslength[c]
            xtrain += tr_pool[c][0][start:stop]
            ytrain += tr_pool[c][1][start:stop]
            tr_next[c] = stop

            start, stop = te_next[c], te_next[c] + teclasslength[c]
            xtest += te_pool[c][0][start:stop]
            ytest += te_pool[c][1][start:stop]
            te_next[c] = stop

        user_dir = os.path.join(output_dir, args.dataset, f"{args.data_number}", f"user{i}").replace("\\", "/")
        os.makedirs(user_dir, exist_ok=True)
        _save_split(os.path.join(user_dir, "train.pt").replace("\\", "/"), xtrain, ytrain)
        _save_split(os.path.join(user_dir, "test.pt").replace("\\", "/"), xtest, ytest)
        userdatapath.append(user_dir)

    return userdatapath, local_lb

if __name__ == '__main__':
    # Re-parse the command line; this rebinds the module-level ``args``, which
    # get_files() reads as a global.
    args = args_parser()
    get_dataset(args)
    print('finish!')

# NOTE(review): this line is outside the main guard, so it also executes on
# import. It asks PyTorch to release cached CUDA memory, but no visible code in
# this file allocates GPU tensors — consider removing it or moving it into the guard.
torch.cuda.empty_cache()
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

甜度超标°

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值