上一篇文章中讲到如何将pascal voc2012数据集xml文件中的标签属性提取出来,本篇文章将会利用resnet相关网络实现多标签分类。
首先利用的相关代码为 https://github.com/AI-Chen/MultiLabelClassification,其中修改了src\Utils.py中的MyDataLoader的代码:
class MyDataLoader(data.Dataset):
    """Multi-label classification dataset for Pascal VOC 2012 labels.

    Inherits from ``torch.utils.data.Dataset``; implements ``__len__`` and
    ``__getitem__``. Expected directory tree (labels only)::

        --dataset root
          |--train
          |  |--JPEGImages(dir)
          |  |--annotations.txt
          |
          |--test
             |--JPEGImages(dir)
             |--annotations.txt

    Each line of ``annotations.txt`` is ``<image name> <class> [<class> ...]``.
    """

    def __init__(self, transform, trainval='train', data_path='../dataset', random_crops=0):
        """
        Initialize the dataset.

        :param transform: transformation applied to each PIL image
        :param trainval: which split sub-folder to load ('train' or 'test')
        :param data_path: the root of the dataset tree shown above
        :param random_crops: if > 0, apply the transform this many times per
            image and stack the results (multi-crop evaluation)
        """
        self.data_path = data_path
        self.transform = transform
        self.random_crops = random_crops
        self.train_or_test = trainval
        self.__init_classes()
        self.names, self.labels = self.__dataset_info()

    def __getitem__(self, index):
        """
        Fetch one sample; enables enumeration over the dataset.

        :param index: the index of the picture
        :return: tuple (picture tensor(s), its multi-hot label vector)
        """
        img_path = os.path.join(self.data_path, self.train_or_test,
                                'JPEGImages', self.names[index] + '.jpg')
        # scipy.misc.imread(mode='RGB') was removed in scipy >= 1.2; open the
        # image with PIL directly and force RGB so grayscale inputs also yield
        # 3 channels (same result as the old imread + Image.fromarray path).
        x = Image.open(img_path).convert('RGB')
        # Resize directly to the network's input size.
        x = x.resize((224, 224), Image.BILINEAR)
        if self.random_crops == 0:
            x = self.transform(x)
        else:
            # Multi-crop evaluation: stack several independently transformed crops.
            x = torch.stack([self.transform(x) for _ in range(self.random_crops)])
        y = self.labels[index]
        return x, y

    def __len__(self):
        """
        Number of images listed in annotations.txt.

        :return: length
        """
        return len(self.names)

    def __dataset_info(self):
        """
        Parse annotations.txt into image names and multi-hot label vectors.

        A label vector looks like ``[0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0]``;
        a 1 means the corresponding class (see ``self.classes``) is present.

        :return: (names, labels) as numpy arrays; labels are float32
        """
        annotation_file = os.path.join(self.data_path, self.train_or_test, 'annotations.txt')
        with open(annotation_file, 'r') as fp:
            lines = fp.readlines()

        names = []
        labels = []
        for line in lines:
            # Split once: first field is the image name, the rest are class names.
            fields = line.strip('\n').split(' ')
            names.append(fields[0])
            flag_label = np.zeros(self.num_classes, dtype=int)
            for class_name in fields[1:]:
                # Use the lookup table built in __init_classes instead of a
                # duplicated hard-coded class list; mark the class as present.
                flag_label[self.class_to_ind[class_name]] = 1
            labels.append(flag_label)
        return np.array(names), np.array(labels).astype(np.float32)

    def __init_classes(self):
        """Define the 20 VOC classes and a name -> index lookup table."""
        self.classes = ('person', 'bird', 'cat', 'cow',
                        'dog', 'horse', 'sheep', 'aeroplane', 'bicycle',
                        'boat', 'bus', 'car', 'motorbike',
                        'train', 'bottle', 'chair',
                        'diningtable', 'pottedplant', 'sofa', 'tvmonitor')
        self.num_classes = len(self.classes)
        self.class_to_ind = dict(zip(self.classes, range(self.num_classes)))
src\Train.py部分修改如下:
选用的网络为resnet50,我用resnet18网络训练在测试集的准确率仅为53%,利用resnet50训练的准确率可以达到93%。两个网络都训练了100轮且都使用了预训练网络。
# Command-line options for training a multi-label classifier on Pascal VOC 2012.
parser = argparse.ArgumentParser(description='Train network on Pascal VOC 2012')

# Dataset, pretrained weights, and backbone selection.
parser.add_argument('--pascal_path', default='../dataset/', type=str,
                    help='Path to Pascal VOC 2012 folder')
parser.add_argument('--finetune', default=None, type=int,
                    help='whether to use pytorch pretrained model and finetune')
parser.add_argument('--model', default='resnet50', type=str,
                    help='which backbone network to use',
                    choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'])
parser.add_argument('--modelpath', default='../pretrained/resnet50-19c8e357.pth', type=str,
                    help='pretrained model path')
parser.add_argument('--fc', default=None, type=int,
                    help='load fc6 and fc7 from model')

# Runtime and optimization settings.
parser.add_argument('--gpu', default=None, type=int,
                    help='gpu id')
parser.add_argument('--epochs', default=100, type=int,
                    help='max training epochs')
parser.add_argument('--iter_start', default=0, type=int,
                    help='Starting iteration count')
parser.add_argument('--batch', default=4, type=int,
                    help='batch size')
parser.add_argument('--checkpoint', default='../checkpoints/', type=str,
                    help='checkpoint folder')
parser.add_argument('--lr', default=0.001, type=float,
                    help='learning rate for SGD optimizer')
parser.add_argument('--crops', default=10, type=int,
                    help='number of random crops during testing')
训练结果为:设置的是每10轮保存一次模型。
测试图片如下:
测试结果如下:
另外,我开始训练的时候采用的是07-08作为测试集,09-12作为训练集,但是我发现10-12的数据集标注文件不完整,很多都只包含人这个类别,这样会造成类别不均衡,会导致训练结果精度很高,却没有什么意义,所以我将07-08作为训练集,共5096张,09作为测试集,共2722张图片,我还没有统计这些图片中每个类别有多少图片数量,所以接下来的事情是继续分析数据集以及采取其他优化方法。除此之外,训练图片不够也是造成训练结果精度不够高的原因,大家如果想自己做这方面的项目,可以采取本博客中的方法做一遍。
下一篇文章我会继续优化这个结果,敬请期待!!