1. 基础环境配置
1. 配置anaconda
# 创建环境
conda create -n pytorch python=3.6
# 启动环境
conda activate pytorch
# 安装pytorch包 https://pytorch.org/
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
2. 配置PyCharm
创建项目–选择解释器–pytorch
可用Python控制台进行验证是否安装成功
3. 配置Jupyter Notebook
# 在pytorch环境中安装jupyter notebook
conda instll nb_conda
# 启动jupyter notebook
jupyter notebook
4. 工具的使用
dir() 查看有什么内容
help() 查看如何使用这些工具
实战使用
5. pycharm、控制台、jupyter notebook
2. 数据集
0. Dataset和Dataloader
1. 加载数据抽象类(Dataset)
# 加载数据抽象类Dataset
from torch.utils.data import Dataset
Dataset??
2. 获取数据
数据文件结构(此次数据集没有标签数据集,而是把文件名当作数据集标签)
一般正常的数据结构(分为图片源文件和标签源文件)
数据集代码实例如下,重载Dataset
from torch.utils.data import Dataset
from PIL import Image
import os
class Mydata(Dataset):
def __init__(self, root_dir, label_dir):
# root_dir 图片文件夹
self.root_dir = root_dir
# label_dir 标签文件夹
self.label_dir = label_dir
# 获取图片文件名
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
# 获取索引item的图片文件和标签
def __getitem__(self, item):
img_name = self.img_path[item]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
# 获取数据集长度
def __len__(self):
return len(self.img_path)
# 实例化
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
# 获取数据
img, label = train_dataset[0]
train_dataset_len = len(train_dataset)
3. tensorboard使用
1. 安装
注意tensorboard版本问题,否则后面会导致很多问题
pip install tensorboard
2. 使用
样例
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
for i in range(100):
# 标题 y轴 x轴
writer.add_scalar("y = 2x", 2*i, i)
writer.close()
得到一个文件
查看该文件
tensorboard --logdir=logs --port=6007
打开页面得到样例图像
tensorboard 查看图片
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
writer = SummaryWriter("logs")
# 写入图片
image_path = "dataset/train/ants/0013035.jpg"
img_PIL = Image.open(image_path)
# 转换成tensorboard需要的图片格式
img_array = np.array(img_PIL)
# 添加,并设置成所需要的通道模式HWC
writer.add_image("ants", img_array, 1, dataformats='HWC')
writer.close()
4. transforms使用
1. 示意图
2. 代码
from PIL import Image
from torchvision import transforms
# 打开图片 格式为PIL的Image
image_path = "dataset/train/ants/0013035.jpg"
img_PIL = Image.open(image_path)
# 用transform进行格式转换
# 实例化transforms对象
tensor_trans = transforms.ToTensor()
# PIL格式 转换为 tensor格式
tensor_img = tensor_trans(img_PIL)
3. 为什么使用tensor
Numpy一个强大的数据操作的工具,但是它不能在GPU上运行,只有将Numpy中的ndarray
转换成tensor
, 才能在GPU上运行。所以我们在必要的时候,需要对ndarray
和tensor
进行操作,同时由于list
是一种我们在数据读取中经常会用到的数据结构,所以对于list
的操作也是经常用到的一种操作。下图就总结了它们之间互相转换的基本的操作。
4. 函数
- ToTensor()
- Resize()
- Normalize()
- Compose()
from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
# 写入日志文件
writer = SummaryWriter("logs")
# 打开图片
img_path = "images/tx.jpg"
img = Image.open(img_path)
# ToTensor, PIL-->Tensor
# 实例化ToTensor()
trans_tensor = transforms.ToTensor()
# 格式转换
img_tensor = trans_tensor(img)
# 写入logs
writer.add_image("tx", img_tensor, 1)
# Resize, 进行尺寸的裁剪
# 实例化Resize()
resize_tensor = transforms.Resize((33, 33))
# 尺寸转换
img_resize = resize_tensor(img_tensor)
# 写入logs
writer.add_image("tx", img_resize, 2)
# Normalize, 归一化
# 实例化Normalize()
normal_tensor = transforms.Normalize([3, 2, 1], [1, 2, 3])
# 归一化
img_normal = normal_tensor(img_tensor)
# 写入logs
writer.add_image("tx", img_normal, 3)
# 将上述操作,进行统一进行
# PIL -> ToTensor -> Resize -> Normalize -> tensor
trans_compose = transforms.Compose([trans_tensor, resize_tensor, normal_tensor])
# 进行组合操作
img_compose = trans_compose(img)
# 写入logs
writer.add_image("tx", img_compose, 4)
# 关闭写操作
writer.close()
ToTensor()
Resize()
Normalize()
Compose()
5. pytorch提供的数据集
获取torchvision视觉的数据集 CIFAR10
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
# 从torchvison加载CIFAR10数据集
data_transform = transforms.Compose([
transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10("./dataset1", train=True, transform=data_transform, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=data_transform, download=True)
# 写入log进行验证
writer = SummaryWriter("logs")
<