FedProto代码复现

github clone到本地后,配置环境部分略 貌似我没有按照原论文FedProto: Federated Prototype Learning across Heterogeneous Clients标准环境来配置,原文的版本要求如下:

Requirements:
Python 3.6 or greater
PyTorch 1.6 or greater
Torchvision
Numpy 1.18.5

我的各个版本如下:

python 3.12.1
pytorch 2.2.0
numpy 1.26.4

目前来说大致没有问题 原文实验结果:

针对第一条指令 按照原文指令分别运行:

python federated_main.py --mode task_heter --dataset mnist --num_classes 10 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 100 --train_shots_max 110 --ld 1
python federated_main.py --mode task_heter --dataset mnist --num_classes 10 --num_users 20 --ways 4 --shots 100 --stdev 2 --rounds 100 --train_shots_max 110 --ld 1
python federated_main.py --mode task_heter --dataset mnist --num_classes 10 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 100 --train_shots_max 110 --ld 1

三个指令跑出来分别有如下结果:

mnist - ways = 3

mnist - ways = 4

mnist - ways = 5

对比原文:

MNIST 数据集
accurary/n_sizen=3n=4n=5
my_with proto98.3997.2296.57
the paper_with proto97.13+-0.3096.80+-0.4196.70+-0.29
可以看到与原文有一些微妙偏差 但总体来说达到甚至超过了原文的 accuary

针对第二条指令:

python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1
python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 4 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1
python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1

在运行的时候并不顺利 似乎是不存在对应设置的路径

检查系统目录后发现确实不存在这样一个路径 于是我自己根据readme文件中的链接自己下载了by_class.zip,并创建了如上图所示的对应路径,于是运行成功,但训练过程中又出现如下问题:

经检查对应目录确实没有这个图片,猜测可能是循环范围过大(?)

阅读源码后发现可能是由于shell传入的train_shorts_max过大,这可能是由于数据集是我自行下载于作者使用的不完全相同

num_img = args.train_shots_max * args.num_users

我尝试将tran_shots_maxs由110修改为100,发现训练可以正常运行了!

python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 120 --train_shots_max 100 --ld 1

悲,在训练结束后发现,仍然报错:

发现是 解包的问题 我们找到对应的代码:

原代码:

    accurary/n_sizecc_list_l, acc_list_g = test_inference_new_het_lt(args, local_model_list, test_dataset, classes_list, user_groups_lt, global_protos)
    print('For all users (with protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_g),np.std(acc_list_g)))
    print('For all users (w/o protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_l), np.std(acc_list_l)))

修改为:

    acc_list_l, acc_list_g,loss_list = test_inference_new_het_lt(args, local_model_list, test_dataset, classes_list, user_groups_lt, global_protos)
    print('For all users (with protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_g),np.std(acc_list_g)))
    print('For all users (w/o protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_l), np.std(acc_list_l)))
    print('For all users (with protos), mean of proto loss is {:.5f}, std of test acc is {:.5f}'.format(np.mean(loss_list), np.std(loss_list)))

并用comm rounds=1来测试:

结果正常!

修改代码后 最终结果为:

femnist ways = 3

femnist ways = 4

在运行ways =5这条指令的时候:

python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1

运行一轮(Global Round 0结束)后再次报错。。

IndexError: list index out of range

femnist.py中添加代码逻辑查看是不是数据索引越界了:

由输出可以知道,确实是对应数据下标溢出了

添加新的代码逻辑修复这个问题:

问题成功解决,在由Global Round 0跳转到Global Round 1时不再报错。

运行结果如下:

FEMNIST 数据集
accuracy/n_sizen=3n=4n=5
my with proto95.8494.4394.63
the paper with proto96.82+-1.7594.93+-1.6193.67+-2.23

针对第三条指令 分别运行:

python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1 --local_bs 32
python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 4 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1 --local_bs 32
python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1 --local_bs 32

结果如下:

cifar10 - ways = 3

cifar10 - ways = 4

cifar10 - ways =5

对比原文:

CIFAR10 数据集
acc/n_sizen=3n=4n=5
my_with proto68.9054.8359.10
the paper_with proto84.49+-1.9779.12+-2.0377.08+-1.98

于原文精确度相差比较大,尝试在github原仓库寻找解决办法:

发现这个问题貌似是普遍存在?

找到了两条比较新的评论 问题仍未解决...

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值