# 9. Compute the per-channel mean and standard deviation of train_ds.
# NOTE(review): assumes each item of train_ds is an (image, label) pair whose
# image is a tensor laid out channels-first, (C, H, W) with C >= 3 -- i.e. the
# dataset's current transform ends with transforms.ToTensor(); confirm.
import numpy as np

# Collect per-image statistics in a single pass: every item access loads and
# transforms an image, so iterating train_ds twice would double that cost.
meanRGB = []  # one array of per-channel means per image
stdRGB = []   # one array of per-channel standard deviations per image
for x, _ in train_ds:
    img = x.numpy()
    # Reduce over the two spatial axes (1, 2), keeping the channel axis.
    meanRGB.append(np.mean(img, axis=(1, 2)))
    stdRGB.append(np.std(img, axis=(1, 2)))

# np.mean([]) would silently yield NaN, which would then poison normalization.
if not meanRGB:
    raise ValueError("train_ds is empty; cannot compute normalization statistics")

# Average the per-image statistics over the whole dataset.
# NOTE: averaging per-image standard deviations ignores the spread between the
# per-image means, so stdR/stdG/stdB are at most the pooled (dataset-wide)
# standard deviation rather than equal to it.
meanR = np.mean([m[0] for m in meanRGB])
meanG = np.mean([m[1] for m in meanRGB])
meanB = np.mean([m[2] for m in meanRGB])
stdR = np.mean([s[0] for s in stdRGB])
stdG = np.mean([s[1] for s in stdRGB])
stdB = np.mean([s[2] for s in stdRGB])
print(meanR, meanG, meanB)
print(stdR, stdG, stdB)
# Output observed when this script was written:
# 0.4467106 0.43980986 0.40664646
# 0.22414584 0.22148906 0.22389975
# 10. Define the data transforms for train_ds and test0_ds.
# Both pipelines normalize with the per-channel statistics from step 9, so
# meanR/meanG/meanB and stdR/stdG/stdB must already be defined.
# NOTE(review): `transforms` is assumed to be torchvision.transforms, imported
# elsewhere in this file -- confirm.
train_transformer = transforms.Compose([
    # Random flips (each with probability 0.5) act as training-time augmentation.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([meanR, meanG, meanB], [stdR, stdG, stdB]),
])

# The evaluation pipeline is deterministic: no random flips, only tensor
# conversion and the same normalization used for training.
test0_transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([meanR, meanG, meanB], [stdR, stdG, stdB]),
])
# 11. Replace the `transform` attribute of train_ds and test0_ds with the
# train/test pipelines.
# NOTE(review): presumably both datasets apply `self.transform` when an item is
# fetched (as torchvision datasets do), so this affects subsequent item
# access -- confirm against the dataset class in use.
train_ds.transform = train_transformer
test0_ds.transform = test0_transformer
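# 计算数据集的均值和标准差 (Compute the mean and standard deviation of a dataset)
# 最新推荐文章于 2023-06-19 14:15:26 发布 (web-page residue: "latest recommended article published 2023-06-19 14:15:26")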