深度学习——(3)MSE踩到的坑
最近在做图像分类,今天在调试VGG16的代码,想要把后面的loss从交叉熵换成MSE,但是其中出现了bug!是我最不喜欢但是又避不开的东西
文章目录
part 1 兴高采烈改代码
start_time = time.time()
# 定义loss函数
loss_fn=torch.nn.MSELoss()
for epoch in range(num_epochs):
#训练
model.train()
for batch_idx, (features, targets) in enumerate(trainLoader):
features = features.to(DEVICE)
# print()
#
targets=torch.zeros(32, 4).scatter_(1, targets.unsqueeze(1), 1)
targets = targets.to(DEVICE)
### FORWARD AND BACK PROP
logits = model(features)
print(logits.shape)
# cost = F.cross_entropy(logits, targets)
cost= loss_fn(logits, targets)
optimizer.zero_grad()
cost.backward()
### UPDATE MODEL PARAMETERS
optimizer.step()
### LOGGING
if not batch_idx % 50:
print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
%(epoch+1, num_epochs, batch_idx,
len(trainLoader), cost))
将cost = F.cross_entropy(logits, targets)
改为loss_fn=torch.nn.MSELoss() + cost= loss_fn(logits, targets)
part 2 error
error1
RuntimeError: The size of tensor a (1000) must match the size of tensor b (128) at non-singleton dimension 1
error2
RuntimeError: The size of tensor a (4) must match the size of tensor b (128) at non-singleton dimension 1
error3
RuntimeError: The size of tensor a (4) must match the size of tensor b (32) at non-singleton dimension 1
part 3 心情低落查资料
找回信心1
对数字的敏感性,看到1000,我的第一反应是,我将VGG16的model加载以后,没有将后面的全连接层改成适合自己的4分类。(改了,没运行!)将代码重新运行了,变了变了,error变成了error2
找回信心2
看到128,batchsize,查了资料说MSEloss是每个batch都会产生一个,所以心中在想那肯定是batchsize的问题了,改改试试,改成32试试,出现了error3
找回信心3
想着可能真的用不了了,准备了充分的道理准备明天说服师兄,还是用交叉熵,资料都准备好了:
part 4 自己动手丰衣足食
但是自己不想认死理,输出的是连续值,为什么会用不了MSE,即使的分类问题,前期出来的也是概率值。索信打印一下,这个error一看就是维度的问题。
的确:
print(targets.shape)
print(logits.shape)
两个维度的确不同,打印targets的value试试
print(targets)
明白了,没有将target转换为one-hot类型的,好家伙!
转呗
targets=torch.zeros(batch_size, 4).scatter_(1, targets.unsqueeze(1), 1)
targets = targets.to(DEVICE)
OK啦!
注: 将数据转换为onehot类型targets=torch.zeros(batch_size, 4).scatter_(1, targets.unsqueeze(1), 1)
,其中的batchsize表示所有的样本数量,4表示分为4类。如果你的数据是n个样本分为m类,一样的转一下就OK,注意其中的target已经转换为torch.tensor类型。
可以啦!886,还有一个任务是将网络中的某一层的特征提取出来并进行GRAN-CAM热图化。明天在做。做一下今天的PPT