# Display one batch of the training set together with its labels.
images, labels = next(iter(train_loader))

# torchvision.utils.make_grid() tiles a batch of images into one image
# tensor, which is convenient for visualising samples.
img = utils.make_grid(images)

# PyTorch stores images as (C, H, W) while Matplotlib expects (H, W, C),
# so move the channel axis to the end.
img = img.numpy().transpose(1, 2, 0)

# The images were normalised when they were loaded, so they would display
# incorrectly as-is; undo the normalisation first.
# NOTE(review): 0.5 / 0.5 is assumed to match the Normalize transform used
# by train_loader -- confirm against the data-loading code.
std = [0.5]
mean = [0.5]
img = img * std + mean

# Print the labels eight per row (make_grid also lays out eight images per
# row by default). Iterating over the labels themselves avoids assuming a
# fixed batch size of 64.
for position, label in enumerate(labels, start=1):
    print(label, end=" ")
    if position % 8 == 0:
        print(end='\n')

plt.imshow(img)
plt.show()
# CNN network
# CNN network: subclasses nn.Module.
class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so they
    keep the spatial size; each 2x2 max-pool halves it. The first linear
    layer expects 64 * 7 * 7 features, i.e. a 28x28 input (28 -> 14 -> 7).

    Args:
        in_channels: Number of channels of the input images (default 1).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        # A single pooling module is shared by both convolution stages.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the raw class scores for a batch of images.

        Args:
            x: Batched input of shape (N, in_channels, 28, 28).

        Returns:
            Tensor of shape (N, num_classes); no softmax is applied.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 64 * 7 * 7), this never folds a mis-sized input into
        # extra batch rows; fc1 raises a shape error instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Instantiate the CNN model.
net = CNN()