# -*- coding: utf-8 -*-
# @Author : LPH
# @File : MyDataset.py
# @Software: PyCharm
import os
import glob
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
# 加载我的自定义图片数据集
class LoadMyDataset(Dataset):
def __init__(self, root, img_file_list, transform=None):
super(LoadMyDataset, self).__init__()
self.root = root
self.transform = transform
# 存储图片的路径以及对应标签的元组列表
self.samples = []
for i, file_path in enumerate(img_file_list):
self.samples.append((file_path, 0))
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
# 对图像进行处理再返回
file_path, target = self.samples[index]
image = Image.open(file_path).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, target
def main():
# 设置自定义图片的路径
root = "./MyImg/HuGe"
# 读取图像列表,使用glob库,匹配路径下的所有PNG图片 会返回一个数组里面包含所有的图像的路径
image_list = glob.glob(os.path.join(root, "*.png"))
# print(image_list)
# 定义图像转换方式 先修改图像 之后再转化为Tensor类型
trans = transforms.Compose([transforms.Resize((100, 100)), transforms.ToTensor()])
# 实例化对象
myDataset = LoadMyDataset(root=root, img_file_list=image_list, transform=trans)
# 数据加载
batch_size = 16
dataloader = DataLoader(dataset=myDataset, batch_size=batch_size, shuffle=False)
# 读取所有的图片 使用tensorboard进行展示
writer = SummaryWriter("./logs")
for step in range(len(dataloader)):
for data in dataloader:
imgs, targets = data
writer.add_images(f"HuGe-{step}", imgs, step)
writer.close()
if __name__ == '__main__':
main()
tensorboard展示
tensorboard --logdir="logs"
效果如下: