在 PyTorch 中,tensor.shape
返回一个包含张量各维度大小的元组。
所以,当你执行 print(img.shape)
,你看到的 (3, 32, 32)
实际上是在告诉你:
- 这是一个三维张量
- 第一维(通道)的大小是 3
- 第二维(高度)的大小是 32
- 第三维(宽度)的大小是 32
-
import torchvision from torch.utils.data import DataLoader test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True,transform=torchvision.transforms.ToTensor()) test_loader = DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False) img,targert = test_set[0] print(img.shape) print(targert)
参数含义
dataset=test_set
: 指定要加载的数据集batch_size=4
: 每批加载 4 个样本shuffle=True
: 随机打乱数据顺序num_workers=0
: 不使用多进程加载数据drop_last=False
: 不丢弃最后一个不完整的批次
import torchvision
from torch.utils.data import DataLoader
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_set,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
img,targert = test_set[0]
print(img.shape)
print(targert)
for data in test_loader:
imgs,targerts = data
print(imgs.shape)
print(targerts)
tensorboard上数据集可视化step老跳步:
终端运行的命令后面再加上一句 --samples_per_plugin=images=1000
完整代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_set,batch_size=64,shuffle=False,num_workers=0,drop_last=True)
img,targert = test_set[0]
print(img.shape)
print(targert)
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader:
imgs,targerts = data
# print(imgs.shape)
# print(targerts)
writer.add_images("epoch:{}".format(epoch),imgs,step)
step+=1
writer.close()