利用scikit-image库生成图像标签数据集

方法二:利用scikit-image库生成图像标签数据集

提示:此处独立使用图像库scikit-image。即仅用io读图和显示处理服装关键点数据集

安装OpenCV的时候,安装opencv_python:
pip install scikit-image
导入的时候:from skimage import io, transform, draw

服装关键点数据集下载:链接:https://pan.baidu.com/s/1A_UEaulqsz60OhC5BStA9g?pwd=hr47
提取码:hr47

数据集描述:pytorch生成图像标签数据集的三种方式–前言

Skimage模块常用子模块

Skimage模块常用子模块:
io用于图像读取、保存,显示图片和视频。color颜色空间变换。filters包括图像增强、边缘检测、排序滤波、自动阈值。
draw基于numpy数组图像绘制,线段、矩形、圆和文本。transform几何变换包括:旋转,拉伸,收缩等非回调函数。
Exposure曝光调整包括:强度、亮度、直方图均衡化。Feature特征检测与提取。
measure图像属性测量:相似性、等高线。segmentation图像分割。
restoration图像恢复。

生成 图像-关键点坐标标签 数据集

此例,服装类型和关键点图像-标签数据集,引入 io, transform, draw的函数模块进行处理。
数据集展示:(图像,坐标,类型)和只管图像显示。
在这里插入图片描述

代码:dataset_by_skimage.py

# 1.输入图像预处理,统一尺寸。
# 2.真实值ground truth变形,img的shape = (h,w,c),label的shape=(x, y, 是否存在和显隐)
# 3.返回一个数据发生器,img用于给模型做输入,label与输出做损失计算。
import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform, draw  # skimage是基于python开发的数字图片处理包。此处使用IO和transform--里面的变换均为不可回调函数
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class KeyPointsDataSet(Dataset):
    """服装关键点标记数据集"""

    def __init__(self, root_dir, image_set='train', transforms=None):
        """
        初始化数据集
        :param root_dir: 数据目录(.csv和images的根目录)
        :param image_set: train训练,val验证,test测试
        :param transforms(callable,optional):图像变换-可选
        标签数据文件格式为csv_file: 标签csv文件(内容:图像相对地址-category类型-标签coordination坐标)
        """
        self._imgset = image_set
        self._image_paths = []  # 用于存储图片地址列表
        self._labels = []  # 图片标签坐标群
        self._cates = []  # 标签:服装类别
        self._csv_file = os.path.join(root_dir, image_set + '.csv')  # csv标签文件地址
        self._categories = ['blouse', 'outwear', 'dress', 'trousers', 'skirt', ]
        self.root_dir = root_dir
        self._transform = transforms

        self.__getFileList()  # 获取数据(图像,坐标,类型)

    def __len__(self):
        return len(self._image_paths)

    def __getitem__(self, idx):
        img_id = self._image_paths[idx]
        img_id = os.path.join(self.root_dir, img_id)
        image = io.imread(img_id)  # (高,宽,通道数)= (h, w, c)
        imgSize = image.shape[0:2]  # 原始图像宽高

        label = np.asfortranarray(self._labels[idx])  # (x, y, 显隐)=(宽,高,显隐性)
        category = self._categories.index(self._cates[idx])  # 0,1,2,3,4

        if self._transform:
            image = self._transform(image)
        else:
            image = transform.resize(image, output_shape=(256, 256))  # 使用skimage库自带transform
        afterSize = image.shape[0:2]  # 缩放后图像的宽高
        bi = np.array((afterSize[1], afterSize[0])) / np.array((imgSize[1], imgSize[0]))
        label[:, 0:2] = label[:, 0:2] * bi

        return image, label, category

    def __getFileList(self):
        file_info = pd.read_csv(self._csv_file)
        self._image_paths = file_info.iloc[:, 0]  # 第一列,相对地址列
        self._cates = file_info.iloc[:, 1]  # 第二列,服装类型:blouse,trousers,skirt,dress,outwear
        if self._imgset == 'train':
            landmarks = file_info.iloc[:, 2:26].values  # panda中DataFrame数据的读取。第3-25列为坐标群,共24组坐标,
            for i in range(len(landmarks)):  # 处理坐标数据84_497_1 to [84,497,1]
                label = []
                for j in range(24):
                    plot = landmarks[i][j].split('_')
                    coor = []
                    for per in plot:
                        coor.append(int(per))
                    label.append(coor)
                self._labels.append(np.concatenate(label))
            self._labels = np.array(self._labels).reshape((-1, 24, 3))
        else:
            self._labels = np.ones((len(self._image_paths), 24, 3)) * (-1)


def showImageAndCoor(img, coords):
    for coor in coords:
        if coor[2] == -1:
            pass
        else:
            # print(coor)
            rr, cc = draw.circle(coor[1], coor[0], 4)
            draw.set_color(img, [rr, cc], [255, 0, 0])
    io.imshow(img)
    io.show()


if __name__ == "__main__":
    fashionDataset = KeyPointsDataSet(root_dir=r"E:/Datasets/Fashion/Fashion AI-keypoints_24/train/",
                                      image_set="train",
                                      )
    dataloader = DataLoader(dataset=fashionDataset, batch_size=4)  # 因为整个类继承的是torch的Dataset,此处返回的都是'torch.Tensor'
    for i_batch, data in enumerate(dataloader):
        img, label, category = data
        img, label, category = img.numpy(), label.numpy(), category.numpy()  # 'torch.Tensor'不能直接显示,需要转换程io能处理的numpy数组格式。
        print(img.shape, label.shape, category)
        showImageAndCoor(img[0], label[0])
        # break


注意事项

  1. io读图的数据结构也是为(h, w, c)=(高,宽,通道),坐标组是(宽x, 高y)。统一伸缩时注意对应。
  2. 本文输出数据集为了显示并没有对图像数组进行归一化或标准化操作,用的时候需要加上归一化。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

柏常青

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值