目录
一、Dataset类
获取图片地址
1.单张图片获取
需要用到Pillow库,简称PIL,它是一个Python的第三方库,是一个非常好用的图像处理库。
from PIL import Image
img_path='数据集\练手数据集\\train\\ants_image\\0013035.jpg'
img = image.open(img_path) #创建一个Image对象
img.show() #打开Image对象,弹出一个图片弹窗
即可获得一个包含了图片各种信息的Image对象。
2.将多张图片打包为列表
利用OS标准库中listdir方法,OS表示Operating System,即操作系统。
import os
dir_path="数据集\练手数据集\\train\\ants_image"
img_path_list=os.listdir(dir_path) #获得地址列表
创建dataset类
from torch.utils.data import Dataset
import os
from PIL import Image
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
# 将以上路径组合
self.path = os.path.join(self.root_dir,self.label_dir)
# 获得打包路径
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
# 获取图片名
img_name = self.img_path[idx]
# 获取图片地址
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
# 将该图片打开
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
以蚂蚁蜜蜂数据集为例
root_dir = "数据集\练手数据集\\train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = MyData(root_dir,ants_label_dir) #蚂蚁dataset
bees_dataset = MyData(root_dir,bees_label_dir) #蜜蜂dataset
由于class中使用了__getitem__方法,所以可以直接用索引取用需要的图片,例如
#直接展示
ants_dataset[0]
Out[4]: (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=768x512>, 'ants_image')
#获取img和label
img,label = ants_dataset[0]
img.show()
也可以将两个类相加,获得整个数据集
train_dataset = ants_dataset + bees_dataset # 将两个类相加
len(ants_dataset)
Out[7]: 124
# 蚂蚁类有124张图片
# 所以在train_dataset中0-123为蚂蚁类
# 124之后为蜜蜂类
img,label = train_dataset[123]
img.show() # 蚂蚁图片
img,label = train_dataset[124]
img.show() # 蜜蜂图片
二、TensorBoard的使用
1.导入SummaryWriter
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs") # 创建一个名为logs的文件夹
# 会使用到的两种方法
writer.add_image()
writer.add_scalar()
2.使用add_scalar()
可以用于绘制图像。
使用实例:
# 绘制一个y=x的图像
for i in range(100):
writer.add_scalar("y=x",i,i)
在logs文件夹中会出现一个事件,可以用Terminal打开,代码为:
tensorboard --logdir=logs # logdir表示事件所在文件夹
当多次生成事件在同一文件夹后会发现图像会冗余,这种情况可以删除整个文件夹并重建子文件夹,或删除对应文件夹下的所有事件,即可解决。
3.使用add_image()
参数:add_image( tag , img_tensor , global_step)
实例:
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
writer = SummaryWriter("logs")
# 获取图片路径
image_path = "数据集\练手数据集\\train\\ants_image\\0013035.jpg"
# 打开图像
img_PIL = Image.open(image_path)
# 转换为ndarray
img_array = np.array(img_PIL)
writer.add_image("test",img_array,1,dataformats='HWC')
# 当转换成numpy类型后,数据为(H,W,C),所以需要加上dataformats='HWC',不然之后会报错。
打开事件后会发现多了image一栏。
多次在同一img_tensor下,仅改变global_step并使用add_image()后可以滑动图片。
三、Transforms的使用
transfroms是pytorch的预处理模块。
1.ToTensor()
ToTensor能将PIL Image或numpy.ndarray数据类型转换为Tensor格式。
使用实例:
from PIL import Image
from torchvision import transforms
image_path = "数据集\练手数据集\\train\\ants_image\\0013035.jpg"
image = Image.open(image_path)
# 创建一个对象
tensor_trans = transforms.ToTensor()
# 由于ToTensor中有__call__方法
# 因此可以直接将对象作为函数使用
# def __call__(self, pic)
# 输入参数为PIL Image or numpy.ndarray
tensor_img = tensor_trans(image)