PyTorch 中的 Normalize（torchvision.transforms.Normalize）
公式为：
output[channel] = (input[channel] - mean[channel]) / std[channel]
其中 mean[channel] 和 std[channel] 分别是该通道的均值和标准差（注意是标准差，不是方差）。
实例：
import numpy as np
import cv2.cv2 as cv
import os
import torch
from torch import random
from torchvision import transforms
import torchvision
# Demo: per-channel standardization of an image with transforms.Normalize,
# then rescale the result to [0, 1] so OpenCV can display it.

# cv.imread returns a uint8 array in (H, W, C) layout with BGR channel order,
# or None when the file cannot be read (it does not raise).
src = cv.imread(r'code/test.jpg')
if src is None:
    raise FileNotFoundError("could not read image: code/test.jpg")
cv.imshow('src', src)

# Scale 8-bit values into [0, 1]; dividing by 255 (not 256) maps 255 -> 1.0.
src = torch.from_numpy(src).float() / 255.0

# Per-channel statistics over the spatial axes (H, W): one value each for B, G, R.
# Normalize divides by the *standard deviation* (see the formula above), so use
# std rather than var. Clamp keeps a perfectly flat channel from hitting the
# zero-std ValueError raised by Normalize.
mean = src.mean(dim=(0, 1))
std = src.std(dim=(0, 1)).clamp_min(1e-8)

# transforms.Normalize operates on (C, H, W) tensors, so move channels first.
src = src.permute(2, 0, 1)
transform = transforms.Compose([
    transforms.Normalize(mean.tolist(), std.tolist())
])
dst = transform(src)

# Back to (H, W, C) for OpenCV; contiguous() gives a plain C-ordered array.
dst = dst.permute(1, 2, 0).contiguous().numpy()

# Min-max rescale to [0, 1]: cv.imshow maps float images from [0, 1] to [0, 255].
# A constant result has no range to stretch, so show it as black instead of
# dividing by zero.
lo, hi = float(np.min(dst)), float(np.max(dst))
dst = (dst - lo) / (hi - lo) if hi > lo else np.zeros_like(dst)
print(dst)
cv.imshow('test', dst)
cv.waitKey(0)