多标签分类,数据,模型,训练输入输出

DataFountain-天气以及时间分类:CNN多标签分类 准确率0.92

数据:
	dataset返回: img   标签1 标签2 标签3

模型 backbone 
              + fc1
              + fc2
              + fc3
      
	    def forward(self, x):
	        out = self.backbone(x)
	        # 同时完成类别1 和 类别2 分类
	        logits1 = self.fc1(out)
	        logits2 = self.fc2(out)
	        return logits1, logits2
训练:
    # 模型训练
    model.train()
    for i, (x, y1, y2) in enumerate(train_loader): #图片x  标签  y1 y2
        pred1, pred2 = model(x)   # 模型返回对应,标签属性下的多分类头,加入标签计算损失函数

        # 类别1 loss + 类别2 loss 为总共的loss  ,多个损失就变成多个分类损失相加
        loss = criterion(pred1, y1) + criterion(pred2, y2)
        Train_Loss.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        Train_ACC1.append((pred1.argmax(1) == y1.flatten()).numpy().mean())
        Train_ACC2.append((pred2.argmax(1) == y2.flatten()).numpy().mean())
        
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值