Implementing Multi-Label Classification with PyTorch

Preface

By attribute structure, multi-label classification can be divided into two kinds: multi-attribute single-label classification and single-attribute multi-label classification. For example:

Multi-attribute single-label classification: describing the attributes of a garment, for instance: its type (T-shirt, shirt, suit, etc.), its size (large, medium, small, etc.), its collar style (round collar, stand collar, etc.). The garment carries several labels, but each attribute has exactly one label; this kind of task is multi-attribute single-label classification.

Single-attribute multi-label classification: describing the colors on a garment, for instance: which colors appear on it (red + yellow, white + black + yellow, etc.). There is only one attribute, but that attribute can take several labels at once; this kind of task is single-attribute multi-label classification.

The rest of this article tackles the multi-attribute single-label task with a multi-head classification framework: one multi-label classification task is converted into several single-label classification tasks that share a single feature extractor (backbone), with a separate classification head responsible for each attribute.
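Before diving into the full code, the core idea can be sketched in a few lines (a minimal illustration only; the toy backbone and attribute names below are made up, and the real model definitions follow in model.py):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadNet(nn.Module):
    """One shared backbone, one classification head per attribute."""
    def __init__(self, backbone, feat_dim, class_info):
        super().__init__()
        self.backbone = backbone  # shared feature extractor
        self.heads = nn.ModuleDict({name: nn.Linear(feat_dim, n)
                                    for name, n in class_info.items()})

    def forward(self, x):
        feat = self.backbone(x)  # [B, feat_dim]
        return {name: head(feat) for name, head in self.heads.items()}

def multi_head_loss(outputs, targets):
    # total loss = sum of the per-head cross-entropy losses
    return sum(F.cross_entropy(outputs[name], targets[name]) for name in outputs)

# toy usage: a trivial backbone and two made-up attributes
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
net = MultiHeadNet(backbone, feat_dim=64, class_info={'type': 5, 'size': 3})
logits = net(torch.randn(4, 3, 32, 32))
loss = multi_head_loss(logits, {'type': torch.randint(0, 5, (4,)),
                                'size': torch.randint(0, 3, (4,))})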

Notes:

1. The author is not from a formal CS background and the coding skills are limited; suggestions from more experienced readers for improving this framework are very welcome.

2. Because the author's network connection is unreliable, no GitHub repository is provided; the environment setup is given at the end of this article.

Dataset

The dataset used by the author is laid out as follows:

dataset

--img

--train.txt

--val.txt

Each line of the label files has the following format (the one-hot digits are comma-separated, matching the parsing in dataset.py below):

image_name \t one-hot encoding \n

For example:

qwer.jpg\t0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1\n

zxcv.jpg\t0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1\n

..............................................

Dataset preparation

Since everyone's raw data comes in different shapes, the author's preprocessing script is not general-purpose, so you will need to write your own conversion script. If you are new to this and get stuck, leave a comment and the author will help where possible. A rough sketch of such a conversion is shown below.
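The sketch assumes, purely as an example, that the raw annotations sit in a CSV with an image column plus one column per attribute holding that attribute's class index; it writes the tab-separated, comma-delimited one-hot lines described above (the file names, column names and class counts are placeholders):

import csv

# hypothetical heads and class counts; adjust to your own attributes
class_info = {'head1': 4, 'head2': 7, 'head3': 3, 'head4': 3}

def row_to_one_hot(row, class_info):
    digits = []
    for name, n_classes in class_info.items():
        one_hot = ['0'] * n_classes
        one_hot[int(row[name])] = '1'  # class index of this attribute -> one-hot segment
        digits.extend(one_hot)
    return ','.join(digits)

with open('raw_annotations.csv') as f_in, open('dataset/train.txt', 'w') as f_out:
    for row in csv.DictReader(f_in):
        f_out.write(f"{row['image']}\t{row_to_one_hot(row, class_info)}\n")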

Model file (model.py)

The backbone is built mainly through transfer learning. The following models are provided: VGG11, MobileNetV3-Large (with h-swish replaced by swish), MobileNetV2, Wide ResNet50-2, FixEfficientNet-B3 and EfficientNet-B3; pick whichever suits your needs, or add other models of your own. H-swish is replaced with swish because the author needs to deploy the model and the deployment toolchain does not support converting the h-swish activation.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from efficientnet_pytorch import EfficientNet
import timm

# Swish activation function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

# Multi-task loss module
class MultiTaskLossModule_Dynamic(nn.Module):
    def __init__(self,class_info):
        super(MultiTaskLossModule_Dynamic, self).__init__()
        self.tasks = {task: "label_" + task for task in class_info.keys()}

    def _calculate_loss(self, output_key, label_key, net_output, ground_truth):
        """计算单个任务的损失"""
        return F.cross_entropy(net_output[output_key], ground_truth[label_key])

    def get_loss(self, net_output, ground_truth):
        # self.tasks maps each task name to its corresponding label key ("label_" + task name)

        total_loss = 0
        individual_losses = {}

        for task, label_key in self.tasks.items():
            output_key = task  # the output key is assumed to match the task name
            loss = self._calculate_loss(output_key, label_key, net_output, ground_truth)
            total_loss += loss
            individual_losses[task] = loss

        return total_loss, individual_losses 

#VGG11
class VGG11_Dynamic(MultiTaskLossModule_Dynamic):
    def __init__(self, img_bchw, class_info):
        super(VGG11_Dynamic, self).__init__(class_info)

        self.base_model = models.vgg11_bn(pretrained=True).features
        dummy_input = torch.zeros(img_bchw)
        with torch.no_grad():
            out = self.base_model(dummy_input)
        last_channel = out.size(1)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.bn = nn.BatchNorm1d(last_channel)

        # Dynamically create one classification head per attribute
        self.classifiers = nn.ModuleDict()
        for name, n_classes in class_info.items():
            self.classifiers[name] = nn.Linear(last_channel, n_classes)

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.bn(x)

        # Pass the shared features through each classification head
        outputs = {}
        for name, classifier in self.classifiers.items():
            outputs[name] = classifier(x)

        return outputs

#MobileNetV3 Large
class Mobilenet_v3_large_ML_swish_Dynamic(MultiTaskLossModule_Dynamic):
    def __init__(self, img_bchw, class_info):
        super(Mobilenet_v3_large_ML_swish_Dynamic, self).__init__(class_info)

        
        self.base_model = models.mobilenet_v3_large(pretrained=True).features
        # Recursively replace every Hardswish activation with Swish
        # (Hardswish is MobileNetV3's default activation, but the author's
        #  deployment toolchain cannot convert it)
        def replace_hardswish(module):
            for child_name, child in module.named_children():
                if isinstance(child, nn.Hardswish):
                    setattr(module, child_name, Swish())
                else:
                    replace_hardswish(child)

        replace_hardswish(self.base_model)

        dummy_input = torch.zeros(img_bchw)
        with torch.no_grad():
            out = self.base_model(dummy_input)
        last_channel = out.size(1)  # Number of channels in the last feature map

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Dynamically create one classification head per attribute
        self.classifiers = nn.ModuleDict()
        for name, n_classes in class_info.items():
            self.classifiers[name] = nn.Linear(last_channel, n_classes)

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
        x = torch.flatten(x, 1)
        # Pass the shared features through each classification head
        outputs = {}
        for name, classifier in self.classifiers.items():
            outputs[name] = classifier(x)

        return outputs

#MobileNetV2
class Mobilenet_v2_ML_Dynamic(MultiTaskLossModule_Dynamic):
    def __init__(self, img_bchw, class_info):
        super(Mobilenet_v2_ML_Dynamic, self).__init__(class_info)

        mobilenet = models.mobilenet_v2(pretrained=True)
        self.base_model = mobilenet.features  # take the model without its classifier
        last_channel = mobilenet.last_channel  # size of the layer before the classifier

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Dynamically create one classification head per attribute
        self.classifiers = nn.ModuleDict()
        for name, n_classes in class_info.items():
            self.classifiers[name] = nn.Linear(last_channel, n_classes)

    def forward(self, x):
        x = self.base_model(x)
        x = self.pool(x)

        # reshape from [batch, channels, 1, 1] to [batch, channels] to put it into classifier
        x = torch.flatten(x, 1)

        # Pass the shared features through each classification head
        outputs = {}
        for name, classifier in self.classifiers.items():
            outputs[name] = classifier(x)

        return outputs

#Wide ResNet50-2
class Wide_resnet50_2_ML_Dynamic(MultiTaskLossModule_Dynamic):
    def __init__(self, img_bchw, class_info):
        super(Wide_resnet50_2_ML_Dynamic, self).__init__(class_info)

        # Load the pretrained Wide ResNet50-2 model
        self.wide_resnet = models.wide_resnet50_2(pretrained=True)

        # Replace the original classification layer
        num_ftrs = self.wide_resnet.fc.in_features
        self.wide_resnet.fc = nn.Identity()

        # Add Batch Normalization and Dropout layers
        self.bn = nn.BatchNorm1d(num_ftrs)
        self.dropout = nn.Dropout(0.5)

        # Dynamically create one classification head per attribute
        self.classifiers = nn.ModuleDict()
        for name, n_classes in class_info.items():
            self.classifiers[name] = nn.Linear(num_ftrs, n_classes)

    def forward(self, x):
        # Use the Wide ResNet50-2 feature extractor
        x = self.wide_resnet(x)

        # Apply Batch Normalization and Dropout
        x = self.bn(x)
        x = self.dropout(x)

        # Pass the shared features through each classification head
        outputs = {}
        for name, classifier in self.classifiers.items():
            outputs[name] = classifier(x)

        return outputs

#FixEfficientNetB3
class FixEfficientNetB3_Dynamic(MultiTaskLossModule_Dynamic):
    def __init__(self, img_bchw, class_info):
        super(FixEfficientNetB3_Dynamic, self).__init__(class_info)

        # Load the pretrained FixEfficientNet-B3 model
        self.backbone = timm.create_model('tf_efficientnet_b3', pretrained=True)
        self.original_classifier = self.backbone.classifier
        self.backbone.classifier = nn.Identity()  # replace the original classifier with an identity mapping

        # Rebuild the classification part
        num_features = self.original_classifier.in_features

        self.bn = nn.BatchNorm1d(num_features)
        self.dropout = nn.Dropout(p=0.5)

        # Dynamically create one classification head per attribute
        self.classifiers = nn.ModuleDict()
        for name, n_classes in class_info.items():
            self.classifiers[name] = nn.Linear(num_features, n_classes)

    def forward(self, x):
        # Run the feature-extraction part
        x = self.backbone.forward_features(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)  # global average pooling
        x = self.bn(x)
        x = self.dropout(x)

        # Pass the shared features through each classification head
        outputs = {}
        for name, classifier in self.classifiers.items():
            outputs[name] = classifier(x)

        return outputs

#EfficientNetB3
class EfficientNetB3_Dynamic(MultiTaskLossModule_Dynamic):
    def __init__(self, img_bchw, class_info):
        super(EfficientNetB3_Dynamic, self).__init__(class_info)

        self.backbone = EfficientNet.from_pretrained('efficientnet-b3')

        # Keep a handle to the model's original classifier
        self.original_classifier = self.backbone._fc
        self.backbone._fc = nn.Identity()  # replace it with an identity mapping
        # Rebuild the classification part
        num_features = self.original_classifier.in_features

        self.bn = nn.BatchNorm1d(num_features)
        self.dropout = nn.Dropout(p=0.5)

        # Dynamically create one classification head per attribute
        self.classifiers = nn.ModuleDict()
        for name, n_classes in class_info.items():
            self.classifiers[name] = nn.Linear(num_features, n_classes)

    def forward(self, x):
        # Run the feature-extraction part
        x = self.backbone.extract_features(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)  # global average pooling
        x = self.bn(x)
        x = self.dropout(x)

        # Pass the shared features through each classification head
        outputs = {}
        for name, classifier in self.classifiers.items():
            outputs[name] = classifier(x)

        return outputs
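A quick smoke test of the model interface (illustrative only: the head names and class counts are placeholders, the pretrained backbone weights are downloaded on first use, and get_loss takes one-hot float targets keyed by "label_" + head name, matching what the dataset below produces):

import torch
import torch.nn.functional as F
from model import Mobilenet_v2_ML_Dynamic

class_info = {'type': 5, 'size': 3, 'collar': 3}  # placeholder heads
model = Mobilenet_v2_ML_Dynamic(img_bchw=(1, 3, 224, 224), class_info=class_info)
model.eval()

x = torch.randn(2, 3, 224, 224)
outputs = model(x)  # dict: one logits tensor per head
print({name: tuple(logits.shape) for name, logits in outputs.items()})

targets = {f'label_{name}': F.one_hot(torch.randint(0, n, (2,)), n).float()
           for name, n in class_info.items()}
total_loss, per_head = model.get_loss(outputs, targets)
print(total_loss.item(), {k: round(v.item(), 3) for k, v in per_head.items()})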

Dataset file (dataset.py)

import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset
import torch

class MultiLabelDataset_Dynamic(Dataset):
    def __init__(self, txt_file, img_dir, transform=None, label_info=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = []
        
        # Initialize one label list per head
        self.labels = {name: [] for name in label_info.keys()}

        with open(txt_file, 'r') as file:
            for line in file:
                parts = line.strip().split('\t')
                img_name = parts[0]
                labels = np.array([int(x) for x in parts[1].split(',')])
                
                # Slice the concatenated one-hot vector into per-head segments
                start_idx = 0
                for name, end_idx in label_info.items():
                    self.labels[name].append(labels[start_idx:end_idx])
                    start_idx = end_idx
                
                self.img_labels.append((img_name, labels))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, _ = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Build the sample dict: image plus one label tensor per head
        dict_data = {
            'img': image,
            'labels': {
                name: torch.tensor(self.labels[name][idx], dtype=torch.float32)
                for name in self.labels.keys()
            }
        }

        return dict_data
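For reference, the dataset class would be used roughly like this (a sketch; the paths and head names are placeholders, and label_info maps "label_" + head name to the cumulative end index of that head's one-hot segment, which make_label_info in train.py builds automatically):

from torchvision import transforms
from dataset import MultiLabelDataset_Dynamic

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# four hypothetical heads with 4, 7, 3 and 3 classes -> cumulative end indices 4, 11, 14, 17
label_info = {'label_head1': 4, 'label_head2': 11, 'label_head3': 14, 'label_head4': 17}

dataset = MultiLabelDataset_Dynamic('dataset/train.txt', 'dataset/img',
                                    transform=transform, label_info=label_info)
sample = dataset[0]
print(sample['img'].shape)                                # torch.Size([3, 224, 224])
print({k: v.shape for k, v in sample['labels'].items()})  # one one-hot tensor per head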

Validation file (test.py)

import os
import warnings

import numpy as np
import torch
from sklearn.metrics import balanced_accuracy_score
from tqdm.auto import tqdm

def checkpoint_load(model, name):
    print('Restoring checkpoint: {}'.format(name))
    checkpoint = torch.load(name, map_location='cpu')
    # train.py saves the whole model object; also accept a plain state_dict
    state_dict = checkpoint.state_dict() if isinstance(checkpoint, torch.nn.Module) else checkpoint
    model.load_state_dict(state_dict)
    epoch = int(os.path.splitext(os.path.basename(name))[0].split('-')[1])
    return epoch


def validate_Dynamic(model, dataloader, logger, iteration, device, checkpoint=None, label_names=None):
    if checkpoint is not None:
        checkpoint_load(model, checkpoint)

    model.eval()
    with torch.no_grad():
        avg_loss = 0
        accuracies = {name: 0 for name in label_names}

        for i, batch in enumerate(tqdm(dataloader, desc='Validating')):
            img = batch['img']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            val_train, val_train_losses = model.get_loss(output, target_labels)
            avg_loss += val_train.item()

            batch_accuracies = calculate_metrics_Dynamic(output, target_labels,label_names)
            for name, batch_accuracy in zip(label_names, batch_accuracies.values()):
                accuracies[name] += batch_accuracy

    n_samples = len(dataloader)
    avg_loss /= n_samples
    for name in label_names:
        accuracies[name] /= n_samples

    # print('-' * 72)
    # print(f"Validation loss: {avg_loss:.4f}", end=', ')
    # for name in label_names:
    #     print(f"{name}: {accuracies[name]:.4f}", end=', ')
    # print("\n")

    logger.add_scalar('val_loss', avg_loss, iteration)
    for name in label_names:
        logger.add_scalar(f'val_accuracy_{name}', accuracies[name], iteration)

    model.train()
    return avg_loss, {"accuracy_" + name: accuracies[name] for name in label_names}

def calculate_metrics_Dynamic(output, target,label_names):
    # Accuracy for each head
    accuracies = {}

    # Suppress sklearn warnings (e.g. when a class is missing from a batch)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for name in label_names:
            predicted = output[name].cpu()
            gt = target[f'label_{name}'].cpu()

            # Convert predictions and one-hot ground truth to class indices
            predicted_labels = np.argmax(predicted.detach().numpy(), axis=1)
            gt_labels = np.argmax(gt.numpy(), axis=1)

            # Balanced accuracy for this head
            accuracy = balanced_accuracy_score(gt_labels, predicted_labels)
            accuracies[f'accuracy_{name}'] = accuracy

    return accuracies
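Beyond batch validation, it is often handy to run the trained model on a single image and take the arg-max class of each head. A minimal sketch (the preprocessing follows the training setup, the checkpoint and image paths are placeholders, and the script must be run where model.py is importable since train.py saves the whole model object):

import torch
import torchvision.transforms as transforms
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('checkpoint-xxx.pth', map_location=device)  # whole model saved by train.py
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    outputs = model(image)
predictions = {name: int(logits.argmax(dim=1)) for name, logits in outputs.items()}
print(predictions)  # e.g. {'your_classify_head1': 2, ...}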

Training file (train.py)

When saving weights, the author keeps checkpoints based on validation accuracy; change the criterion to whatever suits your needs (a loss-based variant is sketched after the script).

import argparse
import os
from datetime import datetime

import torch
import torchvision.transforms as transforms
from dataset import MultiLabelDataset_Dynamic
from model import FixEfficientNetB3_Dynamic,EfficientNetB3_Dynamic,Mobilenet_v3_large_ML_swish_Dynamic,Mobilenet_v2_ML_Dynamic,Wide_resnet50_2_ML_Dynamic,VGG11_Dynamic
from test import validate_Dynamic, calculate_metrics_Dynamic
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from tqdm.auto import tqdm

def get_cur_time():
    return datetime.strftime(datetime.now(), '%Y-%m-%d-%H-%M')

def checkpoint_save_Dynamic(model, name, epoch,total_loss,val_accuracy):
    keys = list(val_accuracy.keys())
    file_name = "checkpoint-{:03d}-loss:{:.2f}-".format(epoch,total_loss)
    for key in val_accuracy.keys():
        file_name = file_name + key.split("_")[-1] + ":" + str(round(val_accuracy[key],2))
        if key != keys[-1]: file_name = file_name + "-"
    file_name = file_name + ".pth"
    f = os.path.join(name,file_name)
    torch.save(model, f)
    print('Saved checkpoint:', f)

def make_label_info(class_info):
    # Map "label_" + head name to the cumulative end index of that head's one-hot segment
    label_info = {}
    total = 0
    for key in class_info.keys():
        total = total + class_info[key]
        label_info["label_" + key] = total
    return label_info

def make_label_names(class_info):
    return list(class_info.keys())

def make_accuracy_dict(class_info):
    return {"accuracy_" + key:0 for key in class_info.keys()}



if __name__ == '__main__':
    train_txt_file = '/dataset/classify_one_hot_train.txt'  # path to the training label file
    val_txt_file = '/dataset/classify_one_hot_test.txt'  # path to the validation label file
    img_dir = '/dataset/img'  # image directory
    class_info = {'your_classify_head1': 4, 'your_classify_head2': 7, 'your_classify_head3': 3, 'your_classify_head4': 3}  # key: name of each classification head, value: number of classes for that head

    start_epoch = 1  # first epoch number
    N_epochs = 100  # number of training epochs
    batch_size = 128  # batch size
    num_workers = 8  # number of data-loading worker processes
    device = torch.device("cuda:7")  # which GPU to use; multi-GPU training is not supported yet, to be added later

    
    label_info = make_label_info(class_info)
    label_names = make_label_names(class_info)

    # Preprocessing and data augmentation parameters
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Load the datasets
    train_dataset = MultiLabelDataset_Dynamic(train_txt_file, img_dir, transform=transform,label_info=label_info)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dataset = MultiLabelDataset_Dynamic(val_txt_file, img_dir, transform=transform,label_info=label_info)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Build the model
    model = FixEfficientNetB3_Dynamic(img_bchw=(1,3,224,224),class_info=class_info).to(device)

    # Set up the optimizer
    optimizer = torch.optim.AdamW(model.parameters())

    # Log and checkpoint directories
    logdir = os.path.join('./logs-15/', get_cur_time())
    savedir = os.path.join('./checkpoints-15/', get_cur_time())
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(savedir, exist_ok=True)
    logger = SummaryWriter(logdir)

    n_train_samples = len(train_dataloader)
    print("Starting training ...")

    max_total_acc = 0  # best summed validation accuracy seen so far
    for epoch in range(start_epoch, N_epochs + 1):
        start = time.time()
        total_loss = 0
        accuracy_dict = make_accuracy_dict(class_info)
        batch_index = 0
        for i, batch in enumerate(tqdm(train_dataloader, desc=f'Epoch {epoch}/{N_epochs}')):
            optimizer.zero_grad()

            img = batch['img']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            loss_train, losses_train = model.get_loss(output, target_labels)
            total_loss += loss_train.item()
            temp_accuracy_dict = calculate_metrics_Dynamic(output, target_labels,label_names=label_names)
            
            for key in temp_accuracy_dict.keys():
                accuracy_dict[key] += temp_accuracy_dict[key]

            loss_train.backward()
            optimizer.step()

        end = time.time()

        print("epoch {:4d},loss: {:.4f}".format(epoch, total_loss / n_train_samples),{key:round(accuracy_dict[key]/n_train_samples,2) for key in accuracy_dict.keys()})
        
        print("epoch:{:3d} use time: {:.4f}".format(epoch,end - start))
        
        logger.add_scalar('train_loss', total_loss / n_train_samples, epoch)

        val_loss,val_accuracy = validate_Dynamic(model, val_dataloader, logger, epoch, device,label_names=label_names)

        print("val_loss: {:.4f}".format(val_loss),{key:round(val_accuracy[key],2) for key in val_accuracy.keys()})


        if epoch==1:
            checkpoint_save_Dynamic(model, savedir, epoch,val_loss,val_accuracy)
        elif epoch>1 and epoch!=N_epochs:
            if max_total_acc<sum(val_accuracy.values()):
                max_total_acc = sum(val_accuracy.values())
                checkpoint_save_Dynamic(model, savedir, epoch,val_loss,val_accuracy)
        else:
            checkpoint_save_Dynamic(model, savedir, epoch,val_loss,val_accuracy)
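As mentioned above the script, checkpoints are kept based on the summed validation accuracy. If you would rather keep the model with the lowest validation loss, the selection logic can be swapped out along these lines (a sketch only, reusing the same helpers; best_val_loss would be initialized to float('inf') before the epoch loop):

def should_save(epoch, N_epochs, val_loss, best_val_loss):
    """Lowest-validation-loss criterion: always keep the first and last epoch,
    otherwise keep only checkpoints that improve on the best loss so far."""
    return epoch == 1 or epoch == N_epochs or val_loss < best_val_loss

# inside the epoch loop, after validate_Dynamic(...):
#     if should_save(epoch, N_epochs, val_loss, best_val_loss):
#         best_val_loss = min(best_val_loss, val_loss)
#         checkpoint_save_Dynamic(model, savedir, epoch, val_loss, val_accuracy)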

Export file (export.py)

Only the PyTorch-to-ONNX (torch2onnx) conversion is provided.

import torch
import onnx
from torch.onnx import export
import onnxsim

def get_onnx_model_info(num_head):
    output_names=[]
    dynamic_axes={'input': {0: 'batch_size'}}
    for i in range(num_head):
        output_names.append("output"+str(i+1))
        dynamic_axes["output"+str(i+1)] = {0: 'batch_size'}
    return output_names,dynamic_axes

def model_convert_onnx(model, device, output_path, num_head):
    output_names, dynamic_axes = get_onnx_model_info(num_head)
    dummy_input = torch.randn(1, 3, 224, 224, device=device, dtype=torch.float32)  # dummy image-shaped input
    export(model,
           dummy_input,
           output_path,
           verbose=False,
           input_names=['input'],
           output_names=output_names,
           dynamic_axes=dynamic_axes)

if __name__ == '__main__':
    # create model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load model weights
    num_head = 4  # number of classification heads
    model_weight_path = "your_weight.pth"
    output_path = model_weight_path.replace(".pth", ".onnx")
    
    model = torch.load(model_weight_path).to(device)
    model.eval()

    # Convert the model to ONNX
    model_convert_onnx(model,device,output_path,num_head)
    # Load the exported ONNX model
    onnx_model = onnx.load(output_path)

    # Simplify the model with onnx-simplifier
    model_simplified, check = onnxsim.simplify(onnx_model)
    onnx.save(model_simplified, output_path)
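After simplification it is worth sanity-checking the exported model with onnxruntime (already listed in requirements.txt). A small sketch, run separately from export.py; the .onnx path is a placeholder:

import numpy as np
import onnxruntime

onnx_path = 'your_weight.onnx'  # file produced by export.py above (placeholder)
session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
onnx_outputs = session.run(None, {'input': dummy})  # one array per classification head

for out_info, out in zip(session.get_outputs(), onnx_outputs):
    print(out_info.name, out.shape)  # e.g. output1 (1, 4)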

Environment file (requirements.txt)

torch==1.13.1

onnx==1.16.1

onnxruntime==1.18.1

onnxsim==0.4.36

opencv-python

efficientnet_pytorch==0.7.1

timm==1.0.9

pillow==10.3.0

tqdm==4.66.4

scikit-learn==1.5.1

tensorboard==2.17.0
