趁着假期,准备好好学习PyTorch,以下是torchversion的学习笔记。在此,特别感谢B站UP主【我是土堆】,土堆老师讲解得非常认真详细,各位pytorch入门学习者可以去看看~
视频链接:https://www.bilibili.com/video/BV1hE411t7RN?p=15&spm_id_from=pageDriver
一、预备
1. PyTorch官网
PyTorchhttps://pytorch.org/2. 进入torchversion
3.向下滑动,找到torchvision.datasets
在这里就可以找到各种数据集的下载方式了,比如:MINIST、COCO
4. 查看样例
以物体识别数据集CIFAR为例,点击下图红框中的内容,即可查看样例
5.其他
二、数据集的使用
1.下载数据集及简单的调用
以CIFAR数据集为例
import torchvision
#root 指定路径,train=True 下载训练数据集,train=False 下载测试数据集,download=Ture 在线下载数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
img,target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
2.与transforms结合起来使用
import torchvision
from torch.utils.tensorboard import SummaryWriter
#定义数据变换
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
#root 指定路径,train=True 下载训练数据集,train=False 下载测试数据集,download=Ture 在线下载数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True,transform=dataset_transform)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True,transform=dataset_transform)
#使用tensorboard来显示前10张图片
writer = SummaryWriter("p10")
for i in range(10):
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()
tensorboard显示如下:
以上就是我的学习笔记,再次感谢土堆老师!