基于parcal voc2012数据集的图像多标签分类实战

上一篇文章中讲到如何将pascal voc2012数据集xml文件中的标签属性提取出来,本篇文章啊将会利用resnet相关网络实现多标签分类。

首先利用的相关代码为 https://github.com/AI-Chen/MultiLabelClassification,其中修改了src\Utils.py中的MyDataLoader的代码:

class MyDataLoader(data.Dataset):
    def __init__(self, transform, trainval='train', data_path='../dataset', random_crops=0):
        """
        Initialize the dataset. Inherited from torch.data.Dataset, __len__ and __getitem__ need to be implemented.
        VOC(Labels only) tree:
        --dataset root
         |--train
         | |--JPEGImages(dir)
         | |--annotations.txt
         |
         |--test
           |--JPEGImages(dir)
           |--annotations.txt
        :param transform: the transformation
        :param data_path: the root of the datapath
        :param random_crops:
        """
        self.data_path = data_path
        self.transform = transform
        self.random_crops = random_crops
        self.train_or_test = trainval

        self.__init_classes()
        self.names, self.labels = self.__dataset_info()

    def __getitem__(self, index):
        """
        This is the getitem func which enables enumerator. Implemented.
        :param index: the index of the picture
        :return: tuple (picture, its label(s))
        """
        x = imread(os.path.join(self.data_path, self.train_or_test, 'JPEGImages', self.names[index] + '.jpg'),
                   mode='RGB')
        x = Image.fromarray(x)

        # Resize directly
        x = x.resize((224, 224), Image.BILINEAR)

        if self.random_crops == 0:
            x = self.transform(x)
        else:
            crops = []
            for i in range(self.random_crops):
                crops.append(self.transform(x))
            x = torch.stack(crops)

        y = self.labels[index]
        return x, y

    def __len__(self):
        """
        How many images are there. Implemented.
        :return: length
        """
        return len(self.names)

    def __dataset_info(self):
        """
        Generate names(np.array, with string elements) and labels(np.array, with array(number) elements).
        The labels appears like this: [0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0]
        Those with value 1 means the object exists in this image
        :return: names labels
        """
		#修改部分 将上一篇文章得到的txt文件中的name属性转换成[0,1,0,,,,]格式
	 
        name_list=['person', 'bird', 'cat', 'cow',
                        'dog', 'horse', 'sheep', 'aeroplane', 'bicycle',
                        'boat', 'bus', 'car', 'motorbike',
                        'train', 'bottle', 'chair',
                        'diningtable', 'pottedplant', 'sofa', 'tvmonitor']

        annotation_file = os.path.join(self.data_path, self.train_or_test, 'annotations.txt')
        with open(annotation_file, 'r') as fp:
            lines = fp.readlines()

        names = []
        labels = []
        for line in lines:
            # Name
            names.append(line.strip('\n').split(' ')[0])

            # Label
            str_label = line.strip('\n').split(' ')[1:]
            flag_label = np.zeros(self.num_classes,dtype=int)
            for x in str_label:
                y=name_list.index(x)   #找到name的索引值

            # num_label = [int(x) for x in str_label]
            
                flag_label[y] = 1  #列表中对应位置修改为1

            labels.append(np.array(flag_label))

        return np.array(names), np.array(labels).astype(np.float32)

    def __init_classes(self):
        self.classes = ('person', 'bird', 'cat', 'cow',
                        'dog', 'horse', 'sheep', 'aeroplane', 'bicycle',
                        'boat', 'bus', 'car', 'motorbike',
                        'train', 'bottle', 'chair',
                        'diningtable', 'pottedplant', 'sofa', 'tvmonitor')
        # name_list=[]
        # [name_list.append(x) for x in self.classes[i][0] for i in range(len(self.classes))]
        self.num_classes = len(self.classes)
        self.class_to_ind = dict(zip(self.classes, range(self.num_classes)))

src\Train.py部分修改如下:
选用的网络为resnet50,我用resnet18网络训练在测试集的准确率仅为53%,利用resnet50训练的准确率可以达到93%。两个网络都训练了100轮且都使用了预训练网络。

parser = argparse.ArgumentParser(description='Train network on Pascal VOC 2012')
parser.add_argument('--pascal_path', default='../dataset/',type=str, help='Path to Pascal VOC 2012 folder')
parser.add_argument('--finetune', default=None, type=int, help='whether to use pytorch pretrained model and finetune')
parser.add_argument('--model', default='resnet50', type=str, help='which backbone network to use',
                    choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'])
parser.add_argument('--modelpath', default='../pretrained/resnet50-19c8e357.pth', type=str, help='pretrained model path')
parser.add_argument('--fc', default=None, type=int, help='load fc6 and fc7 from model')
parser.add_argument('--gpu', default=None, type=int, help='gpu id')
parser.add_argument('--epochs', default=100, type=int, help='max training epochs')
parser.add_argument('--iter_start', default=0, type=int, help='Starting iteration count')
parser.add_argument('--batch', default=4, type=int, help='batch size')
parser.add_argument('--checkpoint', default='../checkpoints/', type=str, help='checkpoint folder')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate for SGD optimizer')
parser.add_argument('--crops', default=10, type=int, help='number of random crops during testing')

训练结果为:设置的是每10轮保存一次模型。
在这里插入图片描述
测试图片如下:
在这里插入图片描述
测试结果如下:
在这里插入图片描述
另外,我开始训练的时候采用的是07-08作为测试集,09-12作为训练集,但是我发现10-12的数据集标注文件不完整,很多都只包含人这个类别,这样会造成类别不均衡,会导致训练结果精度很高,却没有什么意义,所以我将087-08作为训练集,共5096张,09作为测试集,共2722张图片,我还没有统计这些图片中每个类别有多少图片数量,所以接下来的事情是继续分析数据集以及采取其他优化方法。除此之外,训练图片不够也是造成训练结果精度不够高的原因,大家如果想自己做这方面的项目,可以采取本博客中的方法做一遍。

下一篇文章我会继续优化这个结果,敬请期待!!

  • 3
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
VOC2012数据集是一个常用的计算机视觉数据集,用于目标分类、检测和分割任务。根据引用\[1\],VOC2012数据集包含了训练集、验证集和训练集与验证集的图像信息。其中,Segmentation文件夹存放的是目标分割图像信息,包括train.txt(训练集1464个)、val.txt(验证集1449个)和trainval.txt(训练集+验证集2913个)。 根据引用\[2\],VOC2012数据集的文件夹结构包括Annotations、ImageSets、ActionLayout、Main和Segmentation。其中,Annotations文件夹存放的是目标的标注信息;ImageSets包含了不同任务的图像集合;ActionLayout存放的是动作布局相关的信息;Main存放的是分类、检测和分割任务的主要文件;Segmentation存放的是分割任务的图像和标注信息。 关于目标检测网络的训练流程,根据引用\[3\],大致包括以下步骤: 1. 设置各种超参数,如学习率、批大小等。 2. 定义数据加载模块,用于加载训练数据。 3. 定义网络模型,用于目标检测。 4. 定义损失函数,用于衡量预测结果与真实标签之间的差异。 5. 定义优化器,如Adam或SGD,用于更新网络参数。 6. 遍历训练数据,进行预测、计算损失和反向传播更新参数。 7. 训练过程中可以打印损失值等信息进行监控。 8. 保存训练好的模型。 以上是关于VOC2012数据集和目标检测网络训练流程的简要介绍。 #### 引用[.reference_title] - *1* [PASCAL VOC2012数据集分析](https://blog.csdn.net/One2332x/article/details/121915764)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [[ 数据集 ] VOC 2012 数据集介绍](https://blog.csdn.net/weixin_45084253/article/details/124332044)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [一、目标检测入门VOC2012](https://blog.csdn.net/qq_56551150/article/details/126508127)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值