我可视化的图,
出现这种问题是因为,我把dataloader里面拿出来的直接可视化了
train_dataset = torchvision.datasets.CIFAR10(root='data/',
train=True,
transform=transforms.Compose(
[transforms.Scale((32,32)),transforms.ToTensor(),transforms.Normalize(mean=(0,0,0), std=(1,1,1)),]), # 变到和vgg一样的输入
download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=64, #该参数表示每次读取的批样本个数
shuffle=True)
for i, (img, lab) in enumerate(train_loader):
show_img(img,lab_pre)
应该先把tensor变为PIL格式再可视化:
from torchvision.transforms import ToPILImage
img = ToPILImage(img)