1. ToTensor
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
From the torchvision docstring: "Convert a PIL Image or ndarray to tensor and scale the values accordingly.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8. In the other cases, tensors are returned without scaling."
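2. Normalize (data standardization)
From the torchvision docstring: "Normalize a tensor image with mean and standard deviation. This transform does not support PIL Image. Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` channels, this transform will normalize each channel of the input ``torch.*Tensor`` i.e., ``output[channel] = (input[channel] - mean[channel]) / std[channel]``"
Here mean is the per-channel mean and std is the per-channel standard deviation. With mean = std = 0.5 for every channel, as in the Compose above, each value x in [0, 1] produced by ToTensor becomes (x - 0.5) / 0.5 = 2x - 1, so the image ends up in [-1, 1].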
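import numpy as np
import matplotlib.pyplot as plt

def imshow(img):
    img = img / 2 + 0.5  # unnormalize: undo Normalize(mean=0.5, std=0.5), mapping [-1, 1] back to [0, 1]
    npimg = img.numpy()
    # the tensor is CHW but matplotlib expects HWC; (1, 2, 0) puts the original axes 1, 2, 0 (H, W, C) into positions 0, 1, 2
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()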
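The two array operations inside imshow can be checked without opening a plot window. A minimal sketch with a small hypothetical normalized tensor (plt.show() is deliberately left out):

```python
import numpy as np
import torch

# Hypothetical normalized image: C x H x W = 3 x 1 x 3, values in [-1, 1].
img = torch.tensor([[[-1.0, 0.0, 1.0]],
                    [[1.0, 1.0, -1.0]],
                    [[0.0, -1.0, 0.0]]])

unnorm = img / 2 + 0.5                 # back to [0, 1]
npimg = unnorm.numpy()                 # shape (C, H, W) = (3, 1, 3)
hwc = np.transpose(npimg, (1, 2, 0))   # new axis order: old 1 (H), old 2 (W), old 0 (C)

print(npimg.shape, "->", hwc.shape)    # (3, 1, 3) -> (1, 3, 3)
# hwc[h, w, c] == npimg[c, h, w], so pixel (0, 0) now holds its three channel values together
print(hwc[0, 0].tolist())              # [0.0, 1.0, 0.5]
```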
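import torch.nn as nn
loss_function = nn.CrossEntropyLoss()
From the PyTorch docstring: "This criterion computes the cross entropy loss between input logits and target." It already includes nn.LogSoftmax (it is equivalent to nn.LogSoftmax followed by nn.NLLLoss), so the network should output raw logits, with no softmax layer at the end.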
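The "already includes LogSoftmax" point can be verified numerically. A minimal sketch with random hypothetical logits for a batch of 4 and 10 classes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 10)          # raw scores, no softmax applied
target = torch.tensor([3, 0, 9, 1])  # class index per sample

ce = nn.CrossEntropyLoss()(logits, target)

# The same loss from its two parts: LogSoftmax over the class dimension, then NLLLoss.
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, target)

print(torch.allclose(ce, nll))  # True
```

Adding an explicit softmax layer before CrossEntropyLoss would apply softmax twice, which is why the model's last layer usually outputs logits directly.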
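optimizer.zero_grad(): clears the gradients left over from the previous iteration. PyTorch accumulates gradients into each parameter's .grad on every backward() call, so this is normally called once per training step, before loss.backward().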
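Gradient accumulation is easy to see with a single scalar parameter. A minimal sketch; the loss 3 * w is a hypothetical example chosen so that d(loss)/dw = 3:

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

def backward_once():
    loss = (3 * w).sum()  # d(loss)/dw = 3
    loss.backward()       # adds 3 to whatever is already in w.grad
    return w.grad.item()

g1 = backward_once()  # 3.0
g2 = backward_once()  # 6.0, gradients accumulated
opt.zero_grad()       # clear the accumulated gradient
g3 = backward_once()  # 3.0 again
print(g1, g2, g3)     # 3.0 6.0 3.0
```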
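Channel layout in TensorFlow: [batch, height, width, channel], i.e. NHWC (channels last) by default. PyTorch uses [batch, channel, height, width], i.e. NCHW.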
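Converting between the two layouts is a single axis permutation. A minimal sketch in PyTorch with a hypothetical NHWC batch (N=2, H=4, W=5, C=3); permute returns a non-contiguous view, so .contiguous() may be needed before a view() call:

```python
import torch

# Hypothetical batch in TensorFlow's default layout: [batch, height, width, channel].
nhwc = torch.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)

nchw = nhwc.permute(0, 3, 1, 2)  # -> [batch, channel, height, width], PyTorch's layout
back = nchw.permute(0, 2, 3, 1)  # and back to NHWC

print(nchw.shape)  # torch.Size([2, 3, 4, 5])
print(back.shape)  # torch.Size([2, 4, 5, 3])
```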