今天在跑模型的时候,发现了以下的报错
RuntimeError: Found dtype Long but expected Float
这里的报错发生在BCELoss的部分
loss_fn = torch.nn.BCELoss()
logit = logit.view(logit.size()[0]*logit.size()[1],-1)
batch_label = batch_label.view(batch_label.size()[0]*batch_label.size()[1],-1)
crossentropyloss = loss_fn(logit,batch_label)
这里的batch_label类型为torch.float,而crossentropyloss要求batch_label必须为torch.int类型
因此这段代码需要修改一下:
loss_fn = torch.nn.BCELoss()
logit = logit.view(logit.size()[0]*logit.size()[1],-1)
batch_label = batch_label.view(batch_label.size()[0]*batch_label.size()[1],-1)
batch_label = batch_label.to(torch.float)
crossentropyloss = loss_fn(logit,batch_label)