问题一:数据集下载失败
解决方案:
通过百度网盘手动下载文件,网盘链接 密码:locq。将文件拷贝到\raw下(无需解压)
mnist数据集下载的代码如下:
mnist = datasets.MNIST(
root='./data/',
train=True,
transform=img_transform,
download=True
)
修改mnist.py
(改之前)
resources = [
("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
]
改之后:换成如下自己的mnist下载存放的路径
resources = [
("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\train-images-idx3-ubyte.gz",None),
("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\train-labels-idx1-ubyte.gz",None),
("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz",None),
("file:///E:\\pycharm-workspace\\untitled\\data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz",None)
]
问题二:输入输出不匹配
ValueError: Using a target size (torch.Size([128])) that is different to the input size (torch.Size([128, 1])) is deprecated. Please ensure they have the same size.
解决
real_label = Variable(torch.ones(num_img)).cuda() # 定义真实的图片label为1
fake_label = Variable(torch.zeros(num_img)).cuda() # 定义假的图片的label为0
real_out = D(real_img) # 将真实图片放入判别器中
# print(real_out.size()) 128*1
# print(real_label.size()) 128
定位到损失函数:即real_out([128,1])和real_label([128])不匹配
d_loss_real = criterion(real_out, real_label) # 得到真实图片的loss