今天在跑模型的时候,发现了以下的报错
RuntimeError: Found dtype Long but expected Float
代码部分只是将crossentropyloss换成了BCELoss,那么这里的报错就发生在BCELoss的部分
loss_fn = torch.nn.BCELoss()
logit = logit.view(logit.size()[0]*logit.size()[1],-1)
batch_label = batch_label.view(batch_label.size()[0]*batch_label.size()[1],-1)
crossentropyloss = loss_fn(logit,batch_label)
crossentropyloss要求batch_label必须为torch.int类型,这里BCELoss的batch_label类型为torch.float
因此这段代码需要修改一下:
loss_fn = torch.nn.BCELoss()
logit = logit.view(logit.size()[0]*logit.size()[1],-1)
batch_label = batch_label.view(batch_label.size()[0]*batch_label.size()[1],-1)
batch_label = batch_label.to(torch.float)
crossentropyloss = loss_fn(logit,batch_label)
————————————————
来自:https://blog.csdn.net/znevegiveup1/article/details/124778676