# 参数配置 / Parameter configuration
from major_models.LeNet import LeNet
import torch
# Class sub-directory name ("0" .. "9") -> integer label.
dict_label = {str(digit): digit for digit in range(10)}

# Training hyper-parameters.
batchsize = 2
num_epoch = 2
crop_size = (32, 32)

# Dataset split locations.
train_image = r"./major_dataset_repo/split_data/train"
val_image = r'./major_dataset_repo/split_data/valid'
test_image = r'./major_dataset_repo/split_data/test'
dataset_image = r'./major_dataset_repo/image'

# Checkpoint locations (currently all point at the same file).
path_test_model = "./major_saved_models_repo/common/weights/best_model.pth"
path_predict_model = "./major_saved_models_repo/common/weights/best_model.pth"
path_saved_model = './major_saved_models_repo/common/weights/best_model.pth'

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Per-channel normalization statistics.
norm_mean = [0.45115253, 0.48260283, 0.49052352]
norm_std = [0.26216552, 0.24431673, 0.24694261]

# NOTE(review): dict_label defines 10 classes but the model is built with
# num_classes=6 — confirm against the trained checkpoint before changing.
# NOTE(review): 44944 == 16 * 53 * 53; if num_linear is the flattened feature
# size fed to LeNet's first Linear layer, it does not look consistent with a
# 32x32 crop_size — verify against major_models.LeNet.
model = LeNet(num_classes=6, num_linear=44944)
import random
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
# Seed Python's `random` module so any use of it is reproducible.
random.seed(1)
# Class sub-directory name ("0" .. "9") -> integer label.
dict_label = {str(digit): digit for digit in range(10)}
def get_img_info(data_dir, label_map=None, extensions=('.png',)):
    """Collect ``(image_path, label)`` pairs from a class-per-folder tree.

    For every directory visited by ``os.walk(data_dir)``, each of its
    sub-directories is treated as a class folder: files whose names end with
    one of ``extensions`` are recorded with the label ``label_map[sub_dir]``.

    Args:
        data_dir: Root directory to scan. A missing directory yields ``[]``
            (``os.walk`` ignores the listing error by default).
        label_map: Mapping from sub-directory name to label. Defaults to the
            module-level ``dict_label``.
        extensions: File-name suffixes to include (case-sensitive).

    Returns:
        List of ``(path, int(label))`` tuples, ordered by directory name and
        then file name, so the result is reproducible across runs/platforms.

    Raises:
        KeyError: If a sub-directory that contains matching files is not a
            key of ``label_map``.
    """
    if label_map is None:
        label_map = dict_label
    suffixes = tuple(extensions)  # str.endswith accepts a tuple, not a list
    data_info = []
    for root, dirs, _ in os.walk(data_dir):
        # Sorting in place also fixes the order in which os.walk descends.
        dirs.sort()
        for sub_dir in dirs:
            class_dir = os.path.join(root, sub_dir)
            img_names = sorted(
                name for name in os.listdir(class_dir) if name.endswith(suffixes)
            )
            if not img_names:
                # Folders without images are skipped before the label lookup,
                # so unknown but empty folders do not raise.
                continue
            label = int(label_map[sub_dir])
            data_info.extend(
                (os.path.join(class_dir, name), label) for name in img_names
            )
    return data_info
class MyDataset(Dataset):
    """Map-style dataset over a class-per-folder image tree.

    Samples are the ``(path, label)`` pairs returned by ``get_img_info``.
    Each item is ``(image, label)``, where ``image`` is the RGB PIL image or
    whatever ``transform`` returns for it.
    """

    def __init__(self, data_dir, transform=None):
        """Index all images under ``data_dir``.

        Args:
            data_dir: Root directory passed to ``get_img_info``.
            transform: Optional callable applied to each RGB PIL image.
        """
        self.label_name = dict_label  # sub-directory name -> label
        self.data_info = get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        """Load sample ``index`` as ``(image, label)``."""
        path_img, label = self.data_info[index]
        # Use a context manager so the file handle is released as soon as the
        # in-memory RGB copy exists; Pillow may otherwise keep it open
        # (e.g. for multi-frame files).
        with Image.open(path_img) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data_info)
# Root of the training split: one sub-folder per class label.
# NOTE(review): the configuration section uses
# './major_dataset_repo/split_data/train'; this relative path only resolves
# when run from a directory that contains 'split_data' — confirm which is meant.
train_dir = os.path.join('.', 'split_data', "train")
# Resize + ToTensor only: the statistics must be measured on un-normalized
# tensors (ToTensor scales pixel values into [0, 1]).
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
train_data = MyDataset(data_dir=train_dir, transform=train_transform)
if len(train_data) == 0:
    raise RuntimeError(f"no training images found under {train_dir!r}")
# Every sample is visited exactly once below, so shuffling is unnecessary.
train_loader = DataLoader(dataset=train_data, batch_size=3000, shuffle=False)

# Iterate the whole loader (not just its first batch) so datasets larger than
# batch_size are fully covered; accumulate per-channel sums in float64.
channel_sum = 0.0
channel_sq_sum = 0.0
pixels_per_channel = 0
for images, _ in train_loader:
    batch = images.numpy().astype(np.float64)  # channel axis is 1
    channel_sum = channel_sum + batch.sum(axis=(0, 2, 3))
    channel_sq_sum = channel_sq_sum + np.square(batch).sum(axis=(0, 2, 3))
    pixels_per_channel += batch.shape[0] * batch.shape[2] * batch.shape[3]

train_mean = channel_sum / pixels_per_channel
# Population std (ddof=0, np.std's default); clip tiny negative round-off
# before taking the square root.
train_std = np.sqrt(
    np.maximum(channel_sq_sum / pixels_per_channel - train_mean ** 2, 0.0)
)
print("train_mean:",train_mean)
print("train_std:",train_std)
# PyTorch 的 transforms.Normalize(mean, std) 对每个通道先减去该通道的均值，再除以该通道的标准差，得到归一化图像：output[c] = (input[c] - mean[c]) / std[c]
import torch
from torchvision import transforms
# Small (C, H, W) tensor used to compare transforms.Normalize with the
# hand-written per-channel formula.
C = 2
H = 2
W = 2
arry = torch.arange(C * H * W, dtype=torch.float32).view([C, H, W])
print('输入矩阵:', arry)

mean = [2, 2]
std = [2, 2]
n = transforms.Normalize(mean=mean, std=std)
print('pytorch标准化:', n(arry))

print('公式标准化:')
# Iterating a tensor walks its first (channel) dimension: one H x W plane each.
for ch, plane in enumerate(arry):
    print((plane - mean[ch]) / std[ch])