第一步:下载数据集到本地,可以下载到同路径目录下(非常慢,可以直接复制链接(不显示路径可以ctrl进源代码查看,一般都会有)进迅雷下载,下载好将其复制到同名目录下运行过程中会自动解压)
- 注意Python 从 2.7.9版本开始,就默认开启了服务器证书验证功能,所以直接下会报错记得加
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import torchvision
# Python 从 2.7.9版本开始,就默认开启了服务器证书验证功能,
# 如果证书校验不通过,则拒绝后续操作;
# 这样可以防止中间人攻击,并使客户端确保服务器确实是它声称的身份。
# 如果是自签名证书,由于一般系统的CA证书中不存在在自签名的CA证书内容,从而导致证书验证不通过。
# 临时解决方案:如下两行
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# download=True下载数据集到本地
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False,download=True)
查看数据集中的数据,可以添加断点查看
# 查看测试集中第一个数据
print(test_set[0])
print(test_set.classes) # 类别
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target]) # target对应的classes
img.show() #展示图片
print(test_set[0]) # tensor([], 3)
在tensorboard中显示图片
- 关键步骤:将图片类型转为tensor数据类型
from torch.utils.tensorboard import SummaryWriter
# 将图片格式转化为tensor数据类型
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 将数据集中所有数据转化为tensor数据类型transform=dataset_transform
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# tensorboard进行显示
writer = SummaryWriter("kun")
# 显示测试数据集中前十张
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i) # img tensor
writer.close()
示例代码:
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 将图片格式转化为tensor数据类型
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# Python 从 2.7.9版本开始,就默认开启了服务器证书验证功能,
# 如果证书校验不通过,则拒绝后续操作;
# 这样可以防止中间人攻击,并使客户端确保服务器确实是它声称的身份。
# 如果是自签名证书,由于一般系统的CA证书中不存在在自签名的CA证书内容,从而导致证书验证不通过。
# 临时解决方案:如下两行
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# download=True下载数据集到本地
# 将数据集中所有数据转化为tensor数据类型transform=dataset_transform
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# # 查看测试集中第一个数据
# print(test_set[0])
# print(test_set.classes) # 类别
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target]) # target对应的classes
# img.show() #展示图片
# print(test_set[0]) # tensor([], 3)
# tensorboard进行显示
writer = SummaryWriter("kun")
# 显示测试数据集中前十张
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i) # img tensor
writer.close()