一、简介
PyTorch是一个开源的神经网络框架,专门针对 GPU 加速的深度神经网络编程。
二、加载数据
Dataset:获取数据中有用的部分及对其编号(标签label)、计算数据量
DataLoader:以不同的方式打包并取出数据,例如按 batch_size 分批加载
🌈 Dataset类编写
首先,先准备一个数据集 ,这是一个关于蚂蚁和蜜蜂的数据集
编写代码
import os
from torch.utils.data import Dataset
from PIL import Image
class MyData(Dataset):
    """Image dataset whose label for every sample is its subdirectory name.

    Expects a layout like ``<dir_root>/<dir_label>/*.jpg`` (e.g.
    ``dataset/train/ants/0013035.jpg``): all images under the
    ``dir_label`` folder share that folder name as their label.
    """

    def __init__(self, dir_root, dir_label):
        # Keep both path components so paths can be rebuilt on access.
        self.dir_root = dir_root
        self.dir_label = dir_label  # fixed typo: was self.dir_lable
        # Directory that holds every image belonging to this label.
        self.path = os.path.join(self.dir_root, self.dir_label)
        # File names only; they are joined onto self.path in __getitem__.
        self.img_name_list = os.listdir(self.path)

    def __getitem__(self, idx):
        """Return ``(PIL.Image, label)`` for the idx-th image in the folder."""
        img_name = self.img_name_list[idx]
        # Reuse the precomputed folder path instead of re-joining root+label.
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.dir_label
        return img, label

    def __len__(self):
        """Number of images found under the label folder."""
        return len(self.img_name_list)
# Change these to match your local dataset layout.
root = "dataset/train"
Mylabel_bees = "bees"
Mylabel_ants = "ants"

# One MyData instance per class label.
ants_data = MyData(root, Mylabel_ants)  # ant images
bees_data = MyData(root, Mylabel_bees)  # bee images

# Dataset supports "+": concatenates the two datasets into one.
all_data = ants_data + bees_data
ants_data 包含了蚂蚁数据集,bees_data包含了蜜蜂数据集
🌈DataLoader类
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Convert PIL images to tensors when samples are fetched.
data_trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# CIFAR-10 training split (downloaded into ./dataset on first run).
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=data_trans, download=True)
# CIFAR-10 test split.
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=data_trans, download=True)

# DataLoader arguments:
#   dataset     - the dataset to draw samples from
#   batch_size  - number of images per batch
#   shuffle     - reshuffle the data on every pass
#   num_workers - 0 means data is loaded in the main process
#   drop_last   - whether to drop the final incomplete batch
loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# Log every batch to TensorBoard so it can be inspected visually.
writer = SummaryWriter("logs")
for step, (img, target) in enumerate(loader):
    writer.add_images("loader", img, step)
writer.close()
用tensorboard查看效果
三、TensorBoard
🌈添加标量
from torch.utils.tensorboard import SummaryWriter
# Event files are written under ./logs.
writer = SummaryWriter('logs')
for step in range(100):
    # Record the scalar value `step` at global step `step` -> plots y = x.
    writer.add_scalar("y=x", step, step)
writer.close()
pycharm终端输入
tensorboard --logdir=logs
打开日志
🌈添加图片
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
# Event files are written under ./logs.
writer = SummaryWriter('logs')
image_path = "dataset/train/ants/0013035.jpg"
# Load as PIL, then convert to a numpy array; add_image needs the
# dataformats hint because the array is height-width-channel ordered.
image_array = np.array(Image.open(image_path))
writer.add_image('tensor_test', image_array, 1, dataformats='HWC')
writer.close()
打开日志
四、Transforms
功能:转化图片格式 (tensor),以便进入神经网络处理
用法:
🌈ToTensor类:先建立一个对象,再把要转化的图片作为参数传进去
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
image_file = "dataset/train/ants/0013035.jpg"
img = Image.open(image_file)
# ToTensor: PIL image -> float tensor in [0, 1], channel-first (C, H, W).
tensor_img = transforms.ToTensor()(img)
writer = SummaryWriter('logs')
# Log the converted image so the result can be viewed in TensorBoard.
writer.add_image("tensor_image", tensor_img, 1)
writer.close()
🌈Normalize类:将tensor类型的图片按给定的均值和标准差进行归一化,归一化作用是为了消除奇异值,加快训练速度,公式如下
output[channel] = (input[channel] - mean[channel]) / std[channel]
假定均值为0.5,标准差为0.5,则若input属于[0,1],则output属于[-1,1]
# Per-channel: output = (input - mean) / std.
# With mean = std = 0.5 on all three RGB channels, [0, 1] maps to [-1, 1].
norm_trans = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
norm_img = norm_trans(tensor_img)
🌈 Resize类:将指定图片进行resize成合适大小,格式是PIL
print(img.size)  # original (width, height)
# Resize the PIL image to a fixed 512x512.
resize_trans = transforms.Resize((512, 512))
resize_img = resize_trans(img)
print(resize_img.size)  # now (512, 512)
🌈Compose类:Resize+ToTensor,需提供转换列表
# Chain two transforms: PIL -> resized PIL -> tensor.
compose_trans = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])
compose_img = compose_trans(img)
🌈RandomCrop类:随机裁剪
# RandomCrop(512): take a random 512x512 patch, then convert to tensor.
compose_trans1 = transforms.Compose([
    transforms.RandomCrop(512),
    transforms.ToTensor(),
])
compose_img1 = compose_trans1(img)
🌈torchvision数据集下载
import torchvision
# Download CIFAR-10 into ./dataset (no transform: samples stay PIL images).
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
数据集与transform联动,PIL->Tensor
# Same datasets, but with ToTensor attached so samples come out as tensors.
data_trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=data_trans, download=True)
train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=data_trans, download=True)