✨博客主页:王乐予🎈
✨年轻人要:Living for the moment(活在当下)!💪
🏆推荐专栏:【图像处理】【千锤百炼Python】【深度学习】【排序算法】
目录
在图像分类领域,可能会遇到需要确定对象的多个属性的场景。例如,这些可以是类别、颜色、大小等。与通常的图像分类相比,此任务的输出将包含 2 个或更多属性。
在本教程中,我们将重点讨论一个问题,即我们事先知道属性的数量。此类任务称为多输出分类。事实上,这是多标签分类的一种特例,还可以预测多个属性,但它们的数量可能因样本而异。
本文程序已解耦,可当做通用型多标签图像分类框架使用。
数据集下载地址:Fashion-Product-Images
完整代码:GitHub:Multi-Label-Image-Classification
😺一、数据集介绍
我们将使用时尚产品图片数据集。它包含超过 44 000 张衣服和配饰图片,每张图片有 9 个标签。
从 kaggle 上下载到数据集后解压可以一个文件夹和一个csv表格,分别是images
和styles.csv
。
其中images
里存放了数据集中所有的图片。
styles.csv
中写入了图片的相关信息,包括 id(图片名称)、gender(性别)、masterCategory(主要类别)、subCategory(二级类别)、articleType(服装类型)、baseColour(描述性颜色)、season(季节)、year(年份)、usage(使用说明)、productDisplayName(品牌名称)。
😺二、工程文件夹目录
工程文件夹目录如下,每个py文件具有不同的功能,这么写的好处是未来修改程序更加方便,而且每个py程序都没有很长。如果全部写到一个py程序里,则会显得很臃肿,修改起来也不轻松。
对每个文件的解释如下:
- checkpoints:存放训练的模型权重;
- datasets:存放数据集。并对数据集划分;
- logs:存放训练日志。包括训练、验证时候的损失与精度情况;
- option.py:存放整个工程下需要用到的所有参数;
- utils.py:存放各种函数。包括模型保存、模型加载和损失函数等;
- split_data.py:划分数据集;
- model.py:构建神经网络模型;
- train.py:训练模型;
- predict.py:评估训练模型。
😺三、option.py
import argparse
def get_args():
parser = argparse.ArgumentParser(description='ALL ARGS')
parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu')
parser.add_argument('--start_epoch', type=int, default=0, help='start epoch')
parser.add_argument('--epochs', type=int, default=100, help='Total Training Times')
parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
parser.add_argument('--num_workers', type=int, default=0, help='number of processes to handle dataset loading')
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')
parser.add_argument('--datasets_path', type=str, default='./datasets/', help='Path to the dataset')
parser.add_argument('--image_path', type=str, default='./datasets/images', help='Path to the style image')
parser.add_argument('--original_csv_path', type=str, default='./datasets/styles.csv', help='Original csv file dir')
parser.add_argument('--train_csv_path', type=str, default='./datasets/train.csv', help='train csv file dir')
parser.add_argument('--val_csv_path', type=str, default='./datasets/val.csv', help='val csv file dir')
parser.add_argument('--log_dir', type=str, default='./logs/', help='log dir')
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/', help='checkpoints dir')
parser.add_argument('--checkpoint', type=str, default='./checkpoints/2024-05-24_13-50/checkpoint-000002.pth', help='choose a checkpoint to predict')
parser.add_argument('--predict_image_path', type=str, default='./datasets/images/1163.jpg', help='show ground truth')
return parser.parse_args()
😺四、split_data.py
由于数据集的各个属性严重不均衡,为简单起见,在本教程中仅使用三个标签:gender、articleType 和 baseColour
import csv
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
from option import get_args
def save_csv(data, path, fieldnames=['image_path', 'gender', 'articleType', 'baseColour']):
with open(path, 'w', newline='') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()
for row in data:
writer.writerow(dict(zip(fieldnames, row)))
if __name__ == '__main__':
args = get_args()
input_folder = args.datasets_path
output_folder = args.datasets_path
annotation = args.original_csv_path
all_data = []
with open(annotation) as csv_file:
reader = csv.DictReader(csv_file)
for row in tqdm(reader, total=reader.line_num):
img_id = row['id']
# only three attributes are used: gender articleType、baseColour
gender = row['gender']
articleType = row['articleType']
baseColour = row['baseColour']
img_name = os.path.join(input_folder, 'images', str(img_id) + '.jpg')
# Determine if the image exists
if os.path.exists(img_name):
# Check if the image is 80 * 60 size and if it is in RGB format
img = Image.open(img_name)
if img.size == (60, 80) and img.mode == "RGB":
all_data.append([img_name, gender, articleType, baseColour])
np.random.seed(42)
all_data = np.asarray(all_data)
# Randomly select 40000 data points
inds = np.random.choice(40000, 40000, replace=False)
# Divide training and validation sets
save_csv(all_data[inds][:32000], args.train_csv_path)
save_csv(all_data[inds][32000:40000], args.val_csv_path)
😺五、dataset.py
该代码实现了两个类,AttributesDataset
用于处理属性标签,FashionDataset
类继承自Dataset类,用于处理带有图片路径和属性标签的数据集。关键地方的解释在代码中已经进行了注释。
get_mean_and_std
函数用于获取数据集图像的均值与标准差
import csv
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset
from torchvision import transforms
from option import get_args
args = get_args()
mean = [0.85418772, 0.83673165, 0.83065592]
std = [0.25331535, 0.26539705, 0.26877365]
class AttributesDataset():
def __init__(self, annotation_path):
color_labels = []
gender_labels = []
article_labels = []
with open(annotation_path) as f:
reader = csv.DictReader(f)
for row in reader:
color_labels.append(row['baseColour'])
gender_labels.append(row['gender'])
article_labels.append(row['articleType'])
# Remove duplicate values to obtain a unique label set
self.color_labels = np.unique(color_labels)
self.gender_labels = np.unique(gender_labels)
self.article_labels = np.unique(article_labels)
# Calculate the number of categories for each label
self.num_colors = len(self.color_labels)
self.num_genders = len(self.gender_labels)
self.num_articles = len(self.article_labels)
# Create label mapping: Create two dictionaries: one from label ID to label name, and the other from label name to label ID.
# Mapping results:self.gender_name_to_id:{'Boys': 0, 'Girls': 1, 'Men': 2, 'Unisex': 3, 'Women': 4}
# Mappin