add函数 pytorch_add-pytorch代码实现

运行的是GitHub关于GAN算法的最高赞代码https://github.com/corenel/pytorch-adda

这个代码直接运行是不能直接出结果的,需要进行以下修改:

大佬的安装环境是

Python 3.6

PyTorch 0.2.0

但是我在anaconda上搭建的pytorch0.2.0显示的是cuda版本报错

所以我换成了

torch1.7.0

torchvision 0.8.1

这里torch和torchvision有版本对应要求,需要注意

torch和anaconda安装教程见我另一篇文档。

修改方案

首先会有一些版本不兼容需要修改的提示,如data[0]变成item()等,按照提示修改就可以

这里主要介绍一些难以寻找的大的修改。

导入MNIST和UPSP数据集

从torchvision.datasets.MNIST下载即可,代码在datasets文件夹里mnist.py 和usps.py下载后的MNIST图像都是灰度图像,只有一个通道。所以运行原来的程序会报错:RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]。

报错原因:这是因为mnist图像都是灰度图像,只有一个通道,而上面的transforms.Normalize 却对三个通道都归一化了,这肯定会报错,所以只要像下面修改即可:

pre_process = transforms.Compose([transforms.ToTensor(),

transforms.Normalize(

(0.5,),(0.5,))])

需要注意的是mnist.py和usps.py这两个函数里面都需要这样修改

float->long的字符类型出错

天知道我找这个错误找了多久,修改为

pretrain.py

acc += pred_cls.eq(labels.data).long().cpu().sum().item()#增加了.item(),.long()

优化后的精度没有之前的精度高

torchvision版本要为0.2.0版本,0.2.1版本每次加载数据都减去了一次平均值

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值