PyTorch学习
Dataset类的使用
from torch.utils.data import Dataset
from PIL import Image
import os
class MyDataset(Dataset):
def __init__(self, root_dir, label_dir):
# 初始化文件路径或文件名列表
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.image_path = os.listdir(self.path)
def __getitem__(self, idx):
# 1. 从文件中读取一个数据(例如使用 numpy.fromfile、Image.open)。
# 2. 预处理数据(如 torchvision.Transform)。
# 3. 返回图像和标签。
img_name = self.image_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
# 返回数据集的总大小
return len(self.image_path)
root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
ants_dataset = MyDataset(root_dir, ants_label_dir)
bees_dataset = MyDataset(root_dir, bees_label_dir)
# 类型相同可以进行相加操作
train_dataset = ants_dataset + bees_dataset
Tensorboard的使用
SummaryWriter
from torch.utils.tensorboard import SummaryWriter
以下是一些 SummaryWriter
的常用函数及其用法和参数:
-
__init__
: 用于创建 SummaryWriter 实例。-
log_dir: 指定保存日志的目录,默认为当前目录下的
runs
文件夹。例如:‘./runs/my_experiment’。 -
comment: 为日志目录添加注释,方便区分不同的运行结果。
writer = SummaryWriter() # 默认保存到./runs/May04_22-14-54_s-MacBook-Pro.local/ writer = SummaryWriter(log_dir="logs") # 将日志文件保存到当前文件夹下的logs文件夹内 ./logs/ writer = SummaryWriter(comment="LR_0.1_BATCH_16") # 输出尾缀增加comment内容
-
-
add_scalar: 记录标量数据。
-
tag: 数据的标签,用于在 TensorBoard 中标识数据。
-
scalar_value: 要记录的标量值。(Y)
-
global_step: 当前步骤的编号,用于在图表中定位数据点。(X轴)
for i in range(100): writer.add_scalar("y=x", i, i) # 绘制y=x的图像
若使用同一个tag则会将数据混在同一幅图中
-
-
add_image: 记录图像数据。
-
tag: 图像的标签。
-
img_tensor: 要记录的图像张量。
可以是
torch.Tensor, numpy.array, or string/blobname
-
global_step: 当前步骤的编号。
-
dataformats: 指定图像数据的格式,如 ‘CHW’ 或 ‘HWC’。
dataformats默认为 'CHW’格式(tensor),若shape为其他格式将会报错
如果图片是以numpy格式输入,需要指定dataformats为’HWC’。
-
-
close: 关闭 SummaryWriter,释放资源。
需要指定tensorboard端口号时
tensorboard --logdir=logs --port=6007