在使用DCGAN进行网络训练时,出现了以下报错:
RuntimeError: Found dtype Long but expected Float
出此报错的代码片段如下:
label = torch.full((b_size,), real_label, device=device)
# 将带有正样本的batch,输入到判别网络 中进行前向计算,得到结果放到变量output中
output = netD(real_cpu).view(-1)
# 计算loss
errD_real = criterion(output, label)
其原因在于将输出数据和标签值传入损失函数中的数据类型与需要的数据类型不匹配,需要的是float类型的数据,传入的是long类型的数据。
因此我们需要将传入的数据转换成float类型。
修改后的代码如下:
label = torch.full((b_size,), real_label, device=device)
# 将带有正样本的batch,输入到判别网络 中进行前向计算,得到结果放到变量output中
output = netD(real_cpu).view(-1)
#将传入的数据转换成float类型
output = output.to(torch.float32)
label = label.to(torch.float32)
# 计算loss
errD_real = criterion(output, label)
问题解决!