深度学习——(3)MSE踩到的坑

深度学习——(3)MSE踩到的坑

最近在做图像分类,今天在调试VGG16的代码,想要把后面的loss从交叉熵换成MSE,但是其中出现了bug!是我最不喜欢但是又避不开的东西

part 1 兴高采烈改代码

start_time = time.time()
# 定义loss函数
loss_fn=torch.nn.MSELoss()
for epoch in range(num_epochs):
    #训练
    model.train()
    for batch_idx, (features, targets) in enumerate(trainLoader):
        
        features = features.to(DEVICE)
       # print()
        #
        targets=torch.zeros(32, 4).scatter_(1, targets.unsqueeze(1), 1)
        targets = targets.to(DEVICE)
        
        ### FORWARD AND BACK PROP
        logits = model(features)
        print(logits.shape)
#        cost = F.cross_entropy(logits, targets)
        cost= loss_fn(logits, targets)
        optimizer.zero_grad()
        
        cost.backward()
        
        ### UPDATE MODEL PARAMETERS
        optimizer.step()
        
        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' 
                   %(epoch+1, num_epochs, batch_idx, 
                     len(trainLoader), cost))

cost = F.cross_entropy(logits, targets)改为loss_fn=torch.nn.MSELoss() + cost= loss_fn(logits, targets)

part 2 error

error1
RuntimeError: The size of tensor a (1000) must match the size of tensor b (128) at non-singleton dimension 1
error2
RuntimeError: The size of tensor a (4) must match the size of tensor b (128) at non-singleton dimension 1
error3
RuntimeError: The size of tensor a (4) must match the size of tensor b (32) at non-singleton dimension 1

part 3 心情低落查资料

找回信心1

对数字的敏感性,看到1000,我的第一反应是,我将VGG16的model加载以后,没有将后面的全连接层改成适合自己的4分类。(改了,没运行!)将代码重新运行了,变了变了,error变成了error2

找回信心2

看到128,batchsize,查了资料说MSEloss是每个batch都会产生一个,所以心中在想那肯定是batchsize的问题了,改改试试,改成32试试,出现了error3

找回信心3

想着可能真的用不了了,准备了充分的道理准备明天说服师兄,还是用交叉熵,资料都准备好了:
在这里插入图片描述
在这里插入图片描述

part 4 自己动手丰衣足食

但是自己不想认死理,输出的是连续值,为什么会用不了MSE,即使的分类问题,前期出来的也是概率值。索信打印一下,这个error一看就是维度的问题。
的确:

print(targets.shape)
print(logits.shape)

在这里插入图片描述
两个维度的确不同,打印targets的value试试

print(targets)

在这里插入图片描述明白了,没有将target转换为one-hot类型的,好家伙!
转呗

targets=torch.zeros(batch_size, 4).scatter_(1, targets.unsqueeze(1), 1)
targets = targets.to(DEVICE)

OK啦!

注: 将数据转换为onehot类型targets=torch.zeros(batch_size, 4).scatter_(1, targets.unsqueeze(1), 1),其中的batchsize表示所有的样本数量,4表示分为4类。如果你的数据是n个样本分为m类,一样的转一下就OK,注意其中的target已经转换为torch.tensor类型。
在这里插入图片描述
可以啦!886,还有一个任务是将网络中的某一层的特征提取出来并进行GRAN-CAM热图化。明天在做。做一下今天的PPT

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

柚子味的羊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值