PyTorch学习(二)

PyTorch学习

Dataset类的使用

from torch.utils.data import Dataset
from PIL import Image
import os


class MyDataset(Dataset):

 def __init__(self, root_dir, label_dir):
     # 初始化文件路径或文件名列表
     self.root_dir = root_dir
     self.label_dir = label_dir
     self.path = os.path.join(self.root_dir, self.label_dir)
     self.image_path = os.listdir(self.path)

 def __getitem__(self, idx):
     # 1. 从文件中读取一个数据(例如使用 numpy.fromfile、Image.open)。
     # 2. 预处理数据(如 torchvision.Transform)。
     # 3. 返回图像和标签。
     img_name = self.image_path[idx]
     img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
     img = Image.open(img_item_path)
     label = self.label_dir
     return img, label

 def __len__(self):
     # 返回数据集的总大小
     return len(self.image_path)


root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
ants_dataset = MyDataset(root_dir, ants_label_dir)
bees_dataset = MyDataset(root_dir, bees_label_dir)
# 类型相同可以进行相加操作
train_dataset = ants_dataset + bees_dataset


Tensorboard的使用

SummaryWriter

from torch.utils.tensorboard import SummaryWriter

以下是一些 SummaryWriter 的常用函数及其用法和参数:

  1. __init__: 用于创建 SummaryWriter 实例。

    • log_dir: 指定保存日志的目录,默认为当前目录下的 runs 文件夹。例如:‘./runs/my_experiment’。

    • comment: 为日志目录添加注释,方便区分不同的运行结果。

    writer = SummaryWriter()
    # 默认保存到./runs/May04_22-14-54_s-MacBook-Pro.local/
    writer = SummaryWriter(log_dir="logs")
    # 将日志文件保存到当前文件夹下的logs文件夹内 ./logs/
    writer = SummaryWriter(comment="LR_0.1_BATCH_16")
    # 输出尾缀增加comment内容
    
  2. add_scalar: 记录标量数据。

    • tag: 数据的标签,用于在 TensorBoard 中标识数据。

    • scalar_value: 要记录的标量值。(Y)

    • global_step: 当前步骤的编号,用于在图表中定位数据点。(X轴)

    for i in range(100):
        writer.add_scalar("y=x", i, i)
    # 绘制y=x的图像
    

    若使用同一个tag则会将数据混在同一幅图中

  3. add_image: 记录图像数据。

    • tag: 图像的标签。

    • img_tensor: 要记录的图像张量。

      可以是torch.Tensor, numpy.array, or string/blobname

    • global_step: 当前步骤的编号。

    • dataformats: 指定图像数据的格式,如 ‘CHW’ 或 ‘HWC’。

    dataformats默认为 'CHW’格式(tensor),若shape为其他格式将会报错

    如果图片是以numpy格式输入,需要指定dataformats为’HWC’

  4. close: 关闭 SummaryWriter,释放资源。

需要指定tensorboard端口号时

tensorboard --logdir=logs --port=6007
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值