import os.path
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
#需要安装pip install tb-nightly
#正则表达式匹配出最后的数字:12
#print(re.findall("(\d+)","flower")[-1])
#创建自定义DataSet类
class myDataSet(Dataset):
#每个分类的子文件夹独立成一个标签数据集,标签例如flower0
def __init__(self,rootdir,labeldir):
self.rootdir=rootdir
self.labeldir=labeldir
self.imagePaths=os.path.join(rootdir,labeldir)
'''
#item作为编号:opencv版本
def __getitem__(self, item):
imagePath=os.listdir(self.imagePaths)[item]
imagePath=os.path.join(self.imagePaths,imagePath)
img=cv2.imdecode(np.fromfile(imagePath,np.uint8),-1)
#bgr转rgb
img = img[:, :, ::-1]
labelComopent =re.findall("(\d+)",self.labeldir)
#如果在标签中取不出对应tag
if len(labelComopent)==0:
raise ValueError
label=int(labelComopent[-1])
return img,label
'''
#item作为编号:opencv版本
def __getitem__(self, item):
imagePath=os.listdir(self.imagePaths)[item]
imagePath=os.path.join(self.imagePaths,imagePath)
img=Image.open(imagePath)
img = np.array(img)
labelComopent =re.findall("(\d+)",self.labeldir)
#如果在标签中取不出对应tag
if len(labelComopent)==0:
raise ValueError
label=int(labelComopent[-1])
return img,label
def __len__(self):
return len(self.imagePaths)
#使用r标识路径防止转义:
rootdir=r"D:\17flowers"
labelList=os.listdir(rootdir)
allDataSet=[]
#生成各子数据集
for label in labelList:
allDataSet.append(myDataSet(rootdir,label))
'''
reduce() 函数会对参数序列中元素进行累积。
函数将一个数据集合(链表,元组等)中的所有数据进行下列操作:
用传给 reduce 中的函数 function(有两个参数)先对集合中的第 1、2 个元素进行操作,得到的结果再与第三个数据用 function 函数运算,最后得到一个结果。
'''
trainDataSet=reduce(lambda x,y:x+y,allDataSet)
#载入日志写入器:
writer=Writer("./myBorderText")
for index,datas in enumerate(trainDataSet):
#print(datas[1])
writer.add_scalar("labelb标识",scalar_value=datas[1],global_step=index)
writer.close()
#查看命令:tensorboard --logdir=./myBorderText
当绘制图像时,使用:
writer=Writer("./myBorderText")
for index,datas in enumerate(trainDataSet):
#存储100张图像:
if index>100:
break
#(500, 689, 3)
#print(datas[0].shape)
#注意使用dataformats转变输入图像通道顺序:
writer.add_image("图片预览",img_tensor=datas[0],global_step=index,dataformats="HWC")
writer.close()