matlab vision mean,Pytorch的mean和std调查代码实例

本篇文章小编给大家分享一下Pytorch的mean和std调查代码实例,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。

代码如下所示:

# coding: utf-8

from __future__ import print_function

import copy

import click

import cv2

import numpy as np

import torch

from torch.autograd import Variable

from torchvision import models, transforms

import matplotlib.pyplot as plt

import load_caffemodel

import scipy.io as sio

# if model has LSTM

# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'

imgname = 'S055_002_00000025.png' # anger

image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]

std_file = [0.229, 0.224, 0.225]

raw_image = cv2.imread(image_path)[..., ::-1]

print(raw_image.shape)

raw_image = cv2.resize(raw_image, (224, ) * 2)

image = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(

mean=mean_file,

std =std_file,

#mean = mean_file,

#std = std_file,

)

])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()

convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W

convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))

convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C

print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1

err = np.max(diff)

print(err)

plt.imshow(np.uint8(convert_image1))

plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]

std_file = [0.229, 0.224, 0.225]

# coding: utf-8

import matplotlib.pyplot as plt

import argparse

import os

import numpy as np

import torchvision

import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')

parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,

help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)

args.dataset = 'cifar10'

if args.dataset == "cifar10":

train_transform = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)

#print(vars(train_set))

print(train_set.train_data.shape)

print(train_set.train_data.mean(axis=(0,1,2))/255)

print(train_set.train_data.std(axis=(0,1,2))/255)

# imshow image

train_data = train_set.train_data

ind = 100

img0 = train_data[ind,...]

## test channel number, in total , the correct channel is : RGB,not like BGR in caffe

# error produce

#b,g,r=cv2.split(img0)

#img0=cv2.merge([r,g,b])

print(img0.shape)

print(type(img0))

plt.imshow(img0)

plt.show() # in ship in sea

#img0 = cv2.resize(img0,(224,224))

#cv2.imshow('img0',img0)

#cv2.waitKey()

elif args.dataset == "cifar100":

train_transform = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)

#print(vars(train_set))

print(train_set.train_data.shape)

print(np.mean(train_set.train_data, axis=(0,1,2))/255)

print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":

train_transform = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)

#print(vars(train_set))

print(list(train_set.train_data.size()))

print(train_set.train_data.float().mean()/255)

print(train_set.train_data.float().std()/255)

结果:

cifar10

Files already downloaded and verified

(50000, 32, 32, 3)

[ 0.49139968 0.48215841 0.44653091]

[ 0.24703223 0.24348513 0.26158784]

(32, 32, 3)

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');

train_data = data.train_data;

disp(size(train_data));

temp = mean(train_data,1);

disp(size(temp));

train_data = double(train_data);

% compute mean_file

mean_val = mean(mean(mean(train_data,1),2),3)/255;

% compute std_file

temp1 = train_data(:,:,:,1);

std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);

std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);

std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);

std_val = [std_val1, std_val2, std_val3];

disp(mean_val);

disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]

% std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值