从github clone到本地后,配置环境部分略 貌似我没有按照原论文FedProto: Federated Prototype Learning across Heterogeneous Clients标准环境来配置,原文的版本要求如下:
Requirements:
Python 3.6 or greater
PyTorch 1.6 or greater
Torchvision
Numpy 1.18.5
我的各个版本如下:
python 3.12.1
pytorch 2.2.0
numpy 1.26.4
目前来说大致没有问题 原文实验结果:
针对第一条指令 按照原文指令分别运行:
python federated_main.py --mode task_heter --dataset mnist --num_classes 10 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 100 --train_shots_max 110 --ld 1
python federated_main.py --mode task_heter --dataset mnist --num_classes 10 --num_users 20 --ways 4 --shots 100 --stdev 2 --rounds 100 --train_shots_max 110 --ld 1
python federated_main.py --mode task_heter --dataset mnist --num_classes 10 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 100 --train_shots_max 110 --ld 1
三个指令跑出来分别有如下结果:
mnist - ways = 3
mnist - ways = 4
mnist - ways = 5
对比原文:
MNIST 数据集
accurary/n_size | n=3 | n=4 | n=5 |
my_with proto | 98.39 | 97.22 | 96.57 |
the paper_with proto | 97.13+-0.30 | 96.80+-0.41 | 96.70+-0.29 |
可以看到与原文有一些微妙偏差 但总体来说达到甚至超过了原文的 accuary
针对第二条指令:
python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1
python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 4 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1
python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1
在运行的时候并不顺利 似乎是不存在对应设置的路径
检查系统目录后发现确实不存在这样一个路径 于是我自己根据readme文件中的链接自己下载了by_class.zip,并创建了如上图所示的对应路径,于是运行成功,但训练过程中又出现如下问题:
经检查对应目录确实没有这个图片,猜测可能是循环范围过大(?)
阅读源码后发现可能是由于shell传入的train_shorts_max过大,这可能是由于数据集是我自行下载于作者使用的不完全相同
num_img = args.train_shots_max * args.num_users
我尝试将tran_shots_maxs由110修改为100,发现训练可以正常运行了!
python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 120 --train_shots_max 100 --ld 1
悲,在训练结束后发现,仍然报错:
发现是 解包的问题 我们找到对应的代码:
原代码:
accurary/n_sizecc_list_l, acc_list_g = test_inference_new_het_lt(args, local_model_list, test_dataset, classes_list, user_groups_lt, global_protos)
print('For all users (with protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_g),np.std(acc_list_g)))
print('For all users (w/o protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_l), np.std(acc_list_l)))
修改为:
acc_list_l, acc_list_g,loss_list = test_inference_new_het_lt(args, local_model_list, test_dataset, classes_list, user_groups_lt, global_protos)
print('For all users (with protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_g),np.std(acc_list_g)))
print('For all users (w/o protos), mean of test acc is {:.5f}, std of test acc is {:.5f}'.format(np.mean(acc_list_l), np.std(acc_list_l)))
print('For all users (with protos), mean of proto loss is {:.5f}, std of test acc is {:.5f}'.format(np.mean(loss_list), np.std(loss_list)))
并用comm rounds=1来测试:
结果正常!
修改代码后 最终结果为:
femnist ways = 3
femnist ways = 4
在运行ways =5这条指令的时候:
python federated_main.py --mode model_heter --dataset femnist --num_classes 62 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 120 --train_shots_max 110 --ld 1
运行一轮(Global Round 0结束)后再次报错。。
IndexError: list index out of range
在femnist.py中添加代码逻辑查看是不是数据索引越界了:
由输出可以知道,确实是对应数据下标溢出了
添加新的代码逻辑修复这个问题:
问题成功解决,在由Global Round 0跳转到Global Round 1时不再报错。
运行结果如下:
FEMNIST 数据集
accuracy/n_size | n=3 | n=4 | n=5 |
---|---|---|---|
my with proto | 95.84 | 94.43 | 94.63 |
the paper with proto | 96.82+-1.75 | 94.93+-1.61 | 93.67+-2.23 |
针对第三条指令 分别运行:
python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 3 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1 --local_bs 32
python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 4 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1 --local_bs 32
python federated_main.py --mode task_heter --dataset cifar10 --num_classes 10 --num_users 20 --ways 5 --shots 100 --stdev 2 --rounds 110 --train_shots_max 110 --ld 0.1 --local_bs 32
结果如下:
cifar10 - ways = 3
cifar10 - ways = 4
cifar10 - ways =5
对比原文:
CIFAR10 数据集
acc/n_size | n=3 | n=4 | n=5 |
---|---|---|---|
my_with proto | 68.90 | 54.83 | 59.10 |
the paper_with proto | 84.49+-1.97 | 79.12+-2.03 | 77.08+-1.98 |
于原文精确度相差比较大,尝试在github原仓库寻找解决办法:
发现这个问题貌似是普遍存在?
找到了两条比较新的评论 问题仍未解决...