1.首先建立数据集;包括:训练集、测试集;设置相关参数
import torchvision
from torch.utils.tensorboard import SummaryWriter
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
train_set代表训练集,test_set代表测试集;CIFAR10数据集可以在pytorch网站下载;
root参数代表数据集保存地址;train的参数:为True代表训练集,False代表测试集
transform参数:使用dataset_transform将图片转换为tensor类型
2.将图片从PIL类型转换为tensor类型:dataset_transform
# 将图片类型转换为tensor类型
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
3.建立日志文件,保存到logs文件中,用于TensorBoard可视化展示
writer = SummaryWriter("logs")
# 显示测试集中的前10张图片
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i) # TensorBoard展示,step设置为i
writer.close()
4.TensorBoard可视化展示,使用终端命令:tensorboard --logdir="logs" 打开TensorBoard可视化网页界面,其中logs代表生成的日志文件