Pytorch Vision操作和CIFAR10数据集

        以下代码都会用到以下库

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

1.下载数据集

        首先要定义对图片的处理方式,Compose 是 torchvision.transforms 模块中的一个类。它允许你将多个图像转换操作串联起来。当你创建一个 Compose 对象时,你需要提供一个包含多个转换操作的列表,这些操作将按照列表中的顺序逐个应用到图像上。简单来说,就是将多个图像转换操作组合成一个单一的操作序列。

        比如以下代码的([torchvision.transforms.ToTensor()])就代表这个dataset_transforms会对图片进行 

#将获得的数据集PIL图片转化为tensor
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

        再复杂一点和常见点的图像变化组合如下,只是方便理解,不加入代码里,本文的代码比较基础,是上面的代码加入后续代码:

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),  # 随机裁剪到指定大小
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.ToTensor(),              # 转换为张量
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 归一化
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),              # 调整大小
        transforms.CenterCrop(224),          # 中心裁剪到指定大小
        transforms.ToTensor(),               # 转换为张量
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 归一化
    ]),
}

        接下来下载数据集,并划分训练集和测试集:

# 选择使用CIFAR10数据集,root是当前相对地址下的文件夹,没有会创建一个
# download = true的时候会在运行的时候,自己下载对应的数据集
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transforms,download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False,transform=dataset_transforms, download=True)

        root是指定数据集的存储目录。在这个例子中,数据集会存储在当前目录下的 dataset 文件夹中。如果该文件夹不存在,会自动创建。

        train指定加载训练集还是测试集。当 train=True 时,加载训练集;当 train=False 时,加载测试集。

        transform指定要应用于图像的转换操作。在这个例子中,使用之前定义的dataset_transforms,将图像从 PIL 图像转换为 PyTorch 张量。

2.探索数据集图像

# 输出数据集中的第一个样本
print(test_set[0])
print("________________________")

# 输出数据集的类别标签
# test_set.classes 返回一个列表,包含 CIFAR-10 数据集的所有类别标签及其对应的索引。例如,'airplane': 0 表示类别 "airplane" 对应的索引是 0。
print(test_set.classes)
print("________________________")

# 提取第一个样本的图像和标签,并分别输出
img, target = test_set[0]
print(img)    # 图像数据,是一个张量。
print("________________________")
print(target)    # 标签,表示图像所属的类别索引。
print("________________________")
print(test_set.classes[target])    # 标签对应的类别名称。

3.记录到tensorboard

        从 test_set 中提取前 10 张图像,并将这些图像记录到 TensorBoard 日志文件中,以便可视化。

# 输出测试集里的第一个样本,通常是(tensor,target)形式
print(test_set[0])

# 写入当前logs文件夹里
writer = SummaryWriter('./logs')

# 10次循环,命名为test_img,每个img为第i步
for i in range(10):
    img,target = test_set[i]
    writer.add_image('test_img',img,i)

writer.close()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值