pytorch加载数据
主要涉及到两个类:
Dataset
提供一种方法去获取数据和label。
- 如何获取每一个数据及其label。
- 告诉我们一共有多少数据。
Dataloader
为后面的网络提供不同的数据形式。
举例
如数据集:https://download.pytorch.org/tutorial/hymenoptera_data.zip
from torch.utils.data import Dataset #需要重写__getitem__方法和选择重写__len__方法。
import PIL import Image
import os
img_path = "D:\\PycharmProjects\\pytorch\\dataset\\train\\ants\\0013035.jpg"
dir_path = "dataset/train/ants"
img_path_list = os.listdir(dir.path) #这里是一个list,list里面是所有蚂蚁图片的名称(str类型)
class MyData(Dataset):
def __init__(self, root_dir, label_dir): #root_dir = "dataset/train",label_dir = "ants"
self.root_dir = root_dir
self.label_dir = label_dir
self.img_path_list = os.listdir(self_path) #这里是一个list,list里面是所有蚂蚁图片的名称(str类型)
def __getitem__(self,idx):
img_name = self.img_path_list[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path )
label = self.label_dir
return img, label
def __len__():
return len(self.img_path_list)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
img, label = ants_dataset[0]#获取蚂蚁数据集的第一个数据和label
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + ants_dataset #可以直接相加整合
TensorBoard的使用
TensorBoard主要用来显示训练过程的一些结果,如训练过程loss是如何变化的,图像是如何变化的。。
先到需要的环境下安装该包,pip install tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs") #创建logs文件夹,后续的事件文件都会输入到该文件夹内。
# writer.add_image()
# 例:y = x
for i in range(100):
writer.add_scalar("y = x", i, i)
# 例 y = 2x
for i in range(100):
writer.add_scalar("y = 2x", 2*i, i)
for i in range(100):
writer.add_scalar("y = 2x", 3*i, i) #如果名字不改的话,数据都会写入"y = 2x"这幅图中。有后续的
#数据要写入的话,可以删掉原文件,或者新建文件夹存起来
writer.close()
然后在terminal中输入:
tensorboard --logdir=logs --port=6007 #logs为存放事件的文件夹,并且指定端口
进入到:http://localhost:6007即可,结果如下: