代码来自博客的transforms.normalize如何对特定数据集设定标准化参数,但报错了,代码段和报错如下:
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
def getStat(train_data):
'''
Compute mean and variance for training data
:param train_data: 自定义类Dataset(或ImageFolder即可)
:return: (mean, std)