【pytorch学习笔记01】加载数据&Tensorboard

B站我是土堆视频学习笔记,链接:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.999.0.0

1. 加载数据

Dataset

提供一种方式去获取数据及其label

  • 如何获取每个数据及其label
  • 告诉我们总共有多少的数据

下面这个示例,是根据一个自己的数据集定义的,主要包括3步:

  1. def __init__(self, root_dir):初始化,需要定义好各种路径
  2. def __getitem__(self, idx):获取每个数据及其label,输入为第几个数据
  3. def __len__(self):数据集长度
from torch.utils.data import Dataset
import cv2
import os

# 构建一个自己的Dataset
class MyData(Dataset):

    # 初始化,需要定义好各种路径
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # path下面就是所有的数据
        self.path = os.path.join(self.root_dir, self.label_dir)
        # img_names,保存了path下面所有数据的名称
        self.img_names = os.listdir(self.path)
	
    # 获取每个数据及其label,输入为第几个数据
    def __getitem__(self, idx):
        img_name = self.img_names[idx]  # 读取数据名称
        img_item_path = os.path.join(self.path, img_name)  # 构建指定数据路径
        img = cv2.imread(img_item_path)  # 读取数据
        label = self.label_dir  # 读取标签
        return img, label

    # 数据集长度
    def __len__(self):
        return len(self.img_path)

Dataloader

加载数据

属性

这里参考了:https://blog.csdn.net/rogerfang/article/details/82291464

1、dataset:(数据类型 dataset)

输入的数据类型,是原始数据的输入。

2、batch_size:(数据类型 int)

批训练数据量的大小,根据具体情况设置即可(默认:1)。

3、shuffle:(数据类型 bool)

洗牌。默认设置为False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

4、collate_fn:(数据类型 callable,没见过的类型)

将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)

5、batch_sampler:(数据类型 Sampler)

批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。

6、sampler:(数据类型 Sampler)

采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。

7、num_workers:(数据类型 Int)

进程数量,默认是0。设置为0,就是使用主进程来导入数据。

8、pin_memory:(数据类型 bool)

内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

9、drop_last:(数据类型 bool)

丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。

10、timeout:(数据类型 numeric)

超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

11、worker_init_fn(数据类型 callable,没见过的类型)

子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。

示例
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10('./dataset',
                                         train=False,
                                         transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data,
                         batch_size=64,
                         shuffle=True,
                         num_workers=0,  # 只有一个主进程
                         drop_last=True)  # 是否保留最后的数据,barch_size的余数

# 测试数据集中的第一张图片
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter('logs')
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, target = data
        writer.add_images('test_data_drop_last', imgs, step)
        step += 1

writer.close()

2. Tensorboard

安装、启动

安装
pip install tensorboard
启动
tensorboard --logdir=logs --port=6006
  • logdir:保存日志的地址
  • port:启动端口,默认为6006

注意:终端启动的路径要和 logdir 对应上

SummaryWriter

初始化

将条目直接写入log_dir中的事件文件,以供TensorBoard使用。

参数
def __init__(self, log_dir=None, comment='', purge_step=None, max_queue=10,
                 flush_secs=120, filename_suffix=''):
  • log_dir:保存事件文件路径
示例
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')

# 其他操作

writer.close()

add_scalar()

向摘要中添加标量数据

参数
add_scalar(
    self,
    tag,
    scalar_value,
    global_step=None,
    walltime=None,
    new_style=False,
    double_precision=False,
):
  • tag:标题
  • scalar_value:要保存的值,x 轴
  • global_step (int):记录的全局步长值,y 轴
示例
# 一条 y=x 的曲线
for i in range(100):
    writer.add_scalar('y=x', i, i)
image-20230722105038950

add_image()

向摘要中添加图片

参数
add_image(
    self, 
    tag, 
    img_tensor, 
    global_step=None, 
    walltime=None, 
    dataformats='CHW'
):
  • tag:标题
  • img_tensor:图像数据
  • global_step (int):记录的全局步长值,y 轴

注意img_tensor的形状:

  • 默认为 ’ (3, H, W) '。
  • 可以使用’ ’ torchvision.utils.make_grid() ’ '将一批张量转换为3xHxW格式
  • 或者传递相应的’ ‘dataformats’ ‘参数。’ chw ', ’ hwc ', ’ hw '。
示例
img_path = 'data/train/ants_image/0013035.jpg'
img = Image.open(img_path)
img_array = np.array(img)

writer.add_image('test', img_array, 1, dataformats='HWC')

报错:AttributeError: module ‘PIL.Image‘ has no attribute ‘ANTIALIAS‘

原因:版本太高

解决:

pip uninstall -y Pillow
pip install Pillow==9.5.0
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值