联邦学习开山代码报错整理

参考:联邦学习开山之作代码解读与收获


最近学习联邦学习开山代码,运行上面汇总的代码时,遇到了一些警告或报错,遂记录下来。

一、UserWarning警告

        正常运行代码就会在初始发出警告。

报错信息:To copy construct from a tensor...

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).   y_support=torch.tensor(y_support,dtype=torch.int64)

原因:由于update.py下的datasplit类里torch.tensor函数为深拷贝,不记录历史更新数据,所以有可能导致意想不到的错误,故警告。

改正:将返回值改为下面as_tensor就不会报错,该函数会共享后面对象的历史数据。

return torch.as_tensor(image), torch.as_tensor(label)

二、ValueError错误

        当我将options里的数据集分布默认改成mnist的非独立同分布且非平均分配时,极大概率发生报错。

报错信息:batch_size should be a positive integer value, but got batch_size=0

  0%|          | 0/10 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "F:\1-offical\FLearn\src\federated_main.py", line 90, in <module>
    local_model = LocalUpdate(args=args, dataset=train_dataset,
  File "F:\1-offical\FLearn\src\update.py", line 39, in __init__
    self.trainloader, self.validloader, self.testloader = self.train_val_test(dataset, list(idxs))
  File "F:\1-offical\FLearn\src\update.py", line 54, in train_val_test
    validloader = DataLoader(DatasetSplit(dataset, idxs_val),
  File "C:\Users\lenovo\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\dataloader.py", line 357, in __init__
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
  File "C:\Users\lenovo\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\sampler.py", line 232, in __init__
    raise ValueError("batch_size should be a positive integer value, "
ValueError: batch_size should be a positive integer value, but got batch_size=0

原因:我的理解是,由于验证集列表太短(仅有5个数据,batch/10=0),导致一个批次的训练都进行不了,这样即产生报错。归根结底是因为用户整个数据集就很短,导致验证集分到的太少。

改正:最直接的方法就是调整update里面的batch_size大小,这样使得每轮都有合适的训练数据。

#parser.add_argument('--local_ep', type=int, default=5,help="本地训练轮次E,默认为10轮")        
validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=int(len(idxs_val)/self.args.local_ep), shuffle=False)
testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=int(len(idxs_test)/self.args.local_ep), shuffle=False)

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值