方法三:利用PIL.Image生成图像标签数据集
提示:提示:此处独立使用图像库pillow进行读图和显示
数据集基本情况描述在pytorch生成图像标签数据集的三种方式–前言部分。
PIL安装: pip install pillow
导入应用: from PIL import Image
Pillow官方使用手册:https://pillow.readthedocs.io/en/latest/index.html
需要服装-关键点
数据集下载可留言。
Image常用函数
Image包常用函数模块:
im = Image.open(file
# 读图
Image.open(file).convert()
#颜色变换
Image.open(file).show()
# 显示
Image.open(file).save()
# 保存
Image.open(file).thumbnail()
#创建缩略图
im.crop(box)
# 复制区域
im.paste(region, box)
# 黏贴区域
im.resize()
和im.rotate()
# 调整尺寸和旋转
图像滤波:ImageFilter from PIL import ImageFilter
im.filter(ImageFilter.BLUR)
# 图像模糊
图像模糊BLUR
,边缘CONTOUR
,细节DETAIL
,边缘增强EDGE_ENHANCE
,EDGE_ENHANCE_MORE
,EMBOSS
,FIND_EDGES
,图像锐化SHARPEN
,图像平滑SMOOTH
,SMOOTH_MORE
服装类型和关键点图像-标签数据集
此例,服装类型和关键点图像-标签数据集,引入了PIL.Image
模块。
数据集结果展示:(图像,坐标,类型)
提示:image读图,制作生成器,并展示图像数组数据
代码:dataset_by_PIL.py
# -*- coding: utf-8 -*-
# @Time : 2022/5/19 15:48
# @Author : Hyan Tan
# @File : dataset_by_PIL.py
# 1.输入图像预处理,统一尺寸。
# 2.真实值ground truth变形,img的shape = (c, h, w),label的shape=(x, y, 是否存在和显隐)
# 3.返回一个数据发生器,img用于给模型做输入,label与输出做损失计算。
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image # PIL库的Image包是基于python开发的数字图片处理包。
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms # 使用torchvision中的变换
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 训练和验证集采用transform做变换
transform = transforms.Compose([
transforms.Resize([256, 256]), # 图片resize
# transforms.RandomCrop(224), # 随机裁剪224*224,但此处不能随机裁剪,因为坐标群不能及时变化
# transforms.RandomHorizontalFlip(), # 水平翻转,但此处不能随机裁剪,因为坐标群不能及时变化
transforms.ToTensor(), # 将图像转为Tensor,数据归一化了欸!img.float().div(255)
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化,为了展示图像,先注释掉
])
# 测试集采用test_transform做变换
test_transform = transforms.Compose([
transforms.Resize([288, 288]), # 把图片resize为256*256
transforms.RandomCrop(256), # 随机裁剪224*224,测试时无标签
transforms.RandomHorizontalFlip(), # 水平翻转
transforms.ToTensor(), # 将图像转为Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
])
class KeyPointsDataSet(Dataset):
"""服装关键点标记数据集"""
def __init__(self, root_dir, image_set='train', transforms=None):
"""
初始化数据集
:param root_dir: 数据目录(.csv和images的根目录)
:param image_set: train训练,val验证,test测试
:param transforms(callable,optional):图像变换-可选
标签数据文件格式为csv_file: 标签csv文件(内容:图像相对地址-category类型-标签coordination坐标)
"""
self._imgset = image_set
self._image_paths = [] # 用于存储图片地址列表
self._labels = [] # 图片标签坐标群
self._cates = [] # 标签:服装类别
self._csv_file = os.path.join(root_dir, image_set + '.csv') # csv标签文件地址
self.__getFileList() # 获取数据(图像,坐标,类型)
self._categories = ['blouse', 'outwear', 'dress', 'trousers', 'skirt', ]
self._root_dir = root_dir
self._transform = transforms
def __len__(self):
return len(self._image_paths)
def __getitem__(self, idx):
img_id = self._image_paths[idx]
img_id = os.path.join(self._root_dir, img_id)
image = Image.open(img_id).convert('RGB') # [3, 256, 256](通道数,高,宽)= (c, h, w)
imgSize = image.size # 原始图像宽高
label = np.asfortranarray(self._labels[idx]) # (x, y, 显隐)=(宽,高,显隐性)
category = self._categories.index(self._cates[idx]) # 0,1,2,3,4
if self._transform:
image = self._transform(image) # 返回torch.Size([3, 256, 256])
else:
image.resize((256, 256)) # 使用resize
afterSize = image.numpy().shape[1:] # 缩放后图像的宽高
# print(imgSize, afterSize)
bi = np.array(afterSize) / np.array(imgSize)
label[:, 0:2] = label[:, 0:2] * bi
return image, label, category
def __getFileList(self):
file_info = pd.read_csv(self._csv_file)
self._image_paths = file_info.iloc[:, 0] # 第一列,相对地址列
self._cates = file_info.iloc[:, 1] # 第二列,服装类型:blouse,trousers,skirt,dress,outwear
if self._imgset == 'train':
landmarks = file_info.iloc[:, 2:26].values # panda中DataFrame数据的读取。第3-25列为坐标群,共24组坐标,
for i in range(len(landmarks)): # 处理坐标数据84_497_1 to [84,497,1]
label = []
for j in range(24):
plot = landmarks[i][j].split('_')
coor = []
for per in plot:
coor.append(int(per))
label.append(coor)
self._labels.append(np.concatenate(label))
self._labels = np.array(self._labels).reshape((-1, 24, 3))
else:
self._labels = np.ones((len(self._image_paths), 24, 3)) * (-1)
def showImageAndCoor(img, coords):
for coor in coords:
if coor[2] == -1:
pass
else:
img[:, coor[1]-1, coor[0]] = [255, 0, 0] # (y,x)
img[:, coor[1], coor[0]] = [255, 0, 0] # 设置关键点位置坐标为红色,为了便于观察,将关键点四领域都设置为红色
img[:, coor[1]+1, coor[0]] = [255, 0, 0]
img[:, coor[1], coor[0]-1] = [255, 0, 0]
img[:, coor[1], coor[0]+1] = [255, 0, 0]
# 因为Image不能直接接受(3,256,256)多维数据,需要逐个击破
# img = Image.fromarray(img * 255, mode='RGB') # 所以这样写只能得到一条线,呵呵
img0 = Image.fromarray(255 * img[0]).convert('L')
img1 = Image.fromarray(255 * img[1]).convert('L')
img2 = Image.fromarray(255 * img[2]).convert('L')
img = Image.merge("RGB", [img0, img1, img2])
img.show()
if __name__ == "__main__":
fashionDataset = KeyPointsDataSet(root_dir=r"E:/Datasets/Fashion/Fashion AI-keypoints_24/train/",
image_set="train",
transforms=transform,
)
dataloader = DataLoader(dataset=fashionDataset, batch_size=4)
for i_batch, data in enumerate(dataloader, 0):
img, label, category = data
img, label, category = img.numpy(), label.numpy(), category.numpy() # 'torch.Tensor'不能直接显示,需要转换程io能处理的numpy数组格式。
print(img.shape, label.shape, category)
showImageAndCoor(img[0], label[0])
break
注意事项:
- Image.open()读图的数据结构也是为(c, h, w)=(通道,高,宽),坐标组是(宽x, 高y)。统一伸缩时注意对应。
- 制作数据集生成器后,再显示图像数据时,需要重构图像mat,才能显示图像。即:
img0 = Image.fromarray(255 * img[0]).convert('L')
img1 = Image.fromarray(255 * img[1]).convert('L')
img2 = Image.fromarray(255 * img[2]).convert('L')
img = Image.merge("RGB", [img0, img1, img2])
- PIL Image show图像时,调用的是系统展示图像默认的工具。