DataFountain-天气以及时间分类:CNN多标签分类 准确率0.92
数据:
dataset返回: img 标签1 标签2 标签3
模型 backbone
+ fc1
+ fc2
+ fc3
def forward(self, x):
out = self.backbone(x)
# 同时完成类别1 和 类别2 分类
logits1 = self.fc1(out)
logits2 = self.fc2(out)
return logits1, logits2
训练:
# 模型训练
model.train()
for i, (x, y1, y2) in enumerate(train_loader): #图片x 标签 y1 y2
pred1, pred2 = model(x) # 模型返回对应,标签属性下的多分类头,加入标签计算损失函数
# 类别1 loss + 类别2 loss 为总共的loss ,多个损失就变成多个分类损失相加
loss = criterion(pred1, y1) + criterion(pred2, y2)
Train_Loss.append(loss.item())
loss.backward()
optimizer.step()
optimizer.clear_grad()
Train_ACC1.append((pred1.argmax(1) == y1.flatten()).numpy().mean())
Train_ACC2.append((pred2.argmax(1) == y2.flatten()).numpy().mean())