That covers the summary; now let's look at how these losses are used in actual code.
First, the usage of nn.CrossEntropyLoss():
import numpy as np
import torch
A = range(12)
A = np.asarray(A)
A = A.reshape(4,3)
A = torch.from_numpy(A).float()
A.shape
The shape of A is: torch.Size([4, 3])
label = np.ones((4,1))
label = torch.from_numpy(label).long()
label.shape
The shape of label is: torch.Size([4, 1])
import torch.nn as nn
loss = nn.CrossEntropyLoss()
loss(A,label)
This raises the following error:
RuntimeError: 1D target tensor expected, multi-target not supported
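Roughly, this means the target must be 1D. What I passed is a 4×1 matrix, which is two-dimensional. The target has to be a one-dimensional tensor of shape (n,) holding one class index per sample, so I used flatten() to turn the matrix into a 1D array.
The complete corrected code: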
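The same fix can also be made on the tensor side instead of in NumPy. A minimal sketch, assuming a reasonably recent PyTorch: `squeeze(1)` drops the size-1 column, so the (4, 1) label becomes the (4,) class-index tensor that nn.CrossEntropyLoss expects (input of shape (N, C) with float logits, target of shape (N,) with dtype long and values in [0, C)).

```python
import numpy as np
import torch
import torch.nn as nn

# Logits for N=4 samples and C=3 classes, same values as in the example
A = torch.from_numpy(np.arange(12).reshape(4, 3)).float()

# A (4, 1) label matrix like this would trigger the "multi-target" error
label_2d = torch.ones(4, 1, dtype=torch.long)

# squeeze(1) removes the size-1 dimension: (4, 1) -> (4,)
label_1d = label_2d.squeeze(1)
print(label_1d.shape)  # torch.Size([4])

loss = nn.CrossEntropyLoss()
print(loss(A, label_1d))  # tensor(1.4076)
```

`label_2d.view(-1)` or `label_2d.flatten()` would give the same (4,) shape; `squeeze(1)` has the advantage of only removing that one dimension.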
import numpy as np
import torch
A = range(12)
A = np.asarray(A)
A = A.reshape(4,3)
A = torch.from_numpy(A).float()
print(A.shape)
label = np.ones((4,1)).flatten()  # the change: shape (4, 1) -> (4,)
label = torch.from_numpy(label).long()
print(label.shape)
import torch.nn as nn
loss = nn.CrossEntropyLoss()
loss(A,label)
Output: torch.Size([4, 3])
torch.Size([4])
Result: tensor(1.4076)
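To see where 1.4076 comes from: nn.CrossEntropyLoss applies a log-softmax over the class dimension and then averages the negative log-probability of each sample's target class (the default reduction='mean'). A minimal sketch that recomputes it by hand with torch.log_softmax and indexing, using the same A and an all-ones label:

```python
import numpy as np
import torch
import torch.nn as nn

A = torch.from_numpy(np.arange(12).reshape(4, 3)).float()
label = torch.ones(4, dtype=torch.long)  # every sample belongs to class 1

# Built-in loss
builtin = nn.CrossEntropyLoss()(A, label)

# Manual version: log-softmax over classes (dim=1),
# pick each sample's target class, negate, then average over the batch
log_probs = torch.log_softmax(A, dim=1)
manual = -log_probs[torch.arange(A.shape[0]), label].mean()

print(builtin, manual)  # tensor(1.4076) tensor(1.4076)
```

All four samples contribute the same value here: each row of A is the previous row plus 3, and softmax does not change when the same constant is added to every logit in a row.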
Next, nn.BCEWithLogitsLoss():
B = np.random.randn(4,1)
B = torch.from_numpy(B).float()
B.shape
torch.Size([4, 1])
For the target, reuse the label from before:
label = np.ones((4,1)).flatten()
label = torch.from_numpy(label).long()
label.shape
torch.Size([4])
import torch.nn as nn
loss = nn.BCEWithLogitsLoss()
loss(B,label)
But this raises an error:
ValueError: Target size (torch.Size([4])) must be the same as input size (torch.Size([4, 1]))
Huh? This time it wants the target to have the same shape as the input B.
Fine, so let's take the flatten() back out.
B = np.random.randn(4,1)
B = torch.from_numpy(B).float()
print(B.shape)
label = np.ones((4,1))
label = torch.from_numpy(label).long()
print(label.shape)
import torch.nn as nn
loss = nn.BCEWithLogitsLoss()
loss(B,label)
Then there is another error:
RuntimeError: result type Float can't be cast to the desired output type Long
So drop .long() as well, leaving the target as a floating-point tensor. The final working code:
B = np.random.randn(4,1)
B = torch.from_numpy(B).float()
print(B.shape)
label = np.ones((4,1))
label = torch.from_numpy(label)  # no .long(): the target must be floating point
print(label.shape)
import torch.nn as nn
loss = nn.BCEWithLogitsLoss()
loss(B,label)
Output:
torch.Size([4, 1])
torch.Size([4, 1])
Result (B is random, so the exact value differs from run to run): tensor(0.7415, dtype=torch.float64)
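What the loss computes: BCEWithLogitsLoss applies a sigmoid to each logit x and then takes -[y·log σ(x) + (1-y)·log(1-σ(x))], averaged over all elements. The target y must be floating point, with the same shape as the input. The result above is float64 because np.ones produces a float64 array, and PyTorch promotes float32 and float64 to float64. A minimal sketch, assuming a float32 target built with .float() to match B. The random seed is only there for reproducibility:

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)  # only so the run is reproducible
B = torch.randn(4, 1)  # logits: float32, shape (4, 1)
label = torch.from_numpy(np.ones((4, 1))).float()  # float64 -> float32, shape (4, 1)

builtin = nn.BCEWithLogitsLoss()(B, label)

# Manual version: sigmoid, then binary cross-entropy, averaged over all elements
p = torch.sigmoid(B)
manual = -(label * torch.log(p) + (1 - label) * torch.log(1 - p)).mean()

print(builtin.dtype)  # torch.float32
print(builtin, manual)
```

In practice the built-in loss is preferable. It combines the sigmoid and the log in a numerically stable way, whereas the manual form can produce inf or nan for logits with a large magnitude.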
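I then tried passing both B and label as 1D tensors, and that also produced a result. The only requirement is that the two shapes match.
The following version also works: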
B = np.random.randn(4,1).flatten()
B = torch.from_numpy(B).float()
print(B.shape)
label = np.ones((4,1))
label = torch.from_numpy(label).flatten()
print(label.shape)
import torch.nn as nn
loss = nn.BCEWithLogitsLoss()
loss(B,label)
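The sketch below checks the claim that only matching shapes matter. It is an illustration with random logits, not the author's exact run. It evaluates the loss on (4, 1) tensors and on their flattened (4,) versions, and the two values agree. It then shows that pairing a (4, 1) input with a (4,) target raises the ValueError seen earlier instead of broadcasting:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
loss = nn.BCEWithLogitsLoss()

B = torch.randn(4, 1)     # logits, shape (4, 1)
label = torch.ones(4, 1)  # float targets, shape (4, 1)

loss_2d = loss(B, label)                      # both (4, 1)
loss_1d = loss(B.flatten(), label.flatten())  # both (4,)
print(torch.allclose(loss_2d, loss_1d))  # True

# Mismatched shapes are rejected rather than broadcast
try:
    loss(B, label.flatten())  # (4, 1) input vs (4,) target
    mismatch_error = None
except ValueError as err:
    mismatch_error = str(err)
print(mismatch_error)
```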