Pytorch实现gan

问题一:数据集下载失败

解决方案:

通过百度网盘手动下载文件,网盘链接 密码:locq。将文件拷贝到\raw下(无需解压)
在这里插入图片描述
mnist数据集下载的代码如下:

mnist = datasets.MNIST(
    root='./data/', 
    train=True, 
    transform=img_transform, 
    download=True
)

修改mnist.py
(改之前)

resources = [
        ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
     ]

改之后:换成如下自己的mnist下载存放的路径

 resources = [
        ("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\train-images-idx3-ubyte.gz",None),
        ("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\train-labels-idx1-ubyte.gz",None),
        ("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz",None),
        ("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz",None)
    ]

问题二:输入输出不匹配

ValueError: Using a target size (torch.Size([128])) that is different to the input size (torch.Size([128, 1])) is deprecated. Please ensure they have the same size.

解决

 real_label = Variable(torch.ones(num_img)).cuda()  # 定义真实的图片label为1  
 fake_label = Variable(torch.zeros(num_img)).cuda()  # 定义假的图片的label为0  
 real_out = D(real_img)  # 将真实图片放入判别器中
 # print(real_out.size()) 128*1
 # print(real_label.size())  128
 定位到损失函数:即real_out([128,1])和real_label([128])不匹配
 d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss
 

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值