任务背景
利用resnet18网络结构及其预训练参数,对102类花卉数据集进行分类:迁移学习时冻结resnet18除输出层(全连接层)以外的所有权重,只训练新替换的输出层,并将验证集准确率最高的模型保存为pt文件。
数据如下:
花数据集+json文件+2个迁移学习模型
一、导入库
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
# torchvision中的transforms模块提供数据增强和数据预处理功能;models提供预训练模型(如resnet);datasets提供数据集加载工具(如按文件夹读取数据的ImageFolder)
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import copy
import json
from PIL import Image
二、数据导入
# Root folder of the flower dataset -- change this to your own local path.
data_dir = r'D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data'
# Sub-folders holding the training and validation splits.
train_dir, valid_dir = (f'{data_dir}/{split}' for split in ('train', 'valid'))
三、数据预处理
# The dataset is small, so the training split is augmented (rotation, crop,
# horizontal/vertical flips). The validation split measures real performance
# and therefore gets no augmentation.

# Per-channel (R, G, B) mean and std used as (x - mean) / std. Statistics
# published for a large dataset are used because values computed on this small
# dataset would not be representative.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline; Compose() applies the steps in order.
_train_steps = [
    # Images come in different sizes, so bring them to one size first.
    # Common choices are 64, 128, 224, 256; smaller inputs train faster but
    # lose detail (96/64 is typical when training on CPU).
    transforms.Resize([96, 96]),
    # Random rotation by an angle between -45 and 45 degrees.
    transforms.RandomRotation(45),
    # Crop from the center; the network input size is the cropped size (64x64).
    transforms.CenterCrop(64),
    # Each image is flipped horizontally with probability 0.5.
    transforms.RandomHorizontalFlip(p=0.5),
    # Each image is flipped vertically with probability 0.5.
    transforms.RandomVerticalFlip(p=0.5),
    # Rarely needed (extreme lighting): brightness, contrast, saturation, hue.
    # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    # Rarely needed: convert to 3-channel grayscale (R=G=B) with this probability.
    # transforms.RandomGrayscale(p=0.025),
    # Convert the image to a tensor.
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]

# Validation pipeline: same final size and same normalization as training.
_valid_steps = [
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]

data_transforms = {
    'train': transforms.Compose(_train_steps),
    'valid': transforms.Compose(_valid_steps),
}
# The inputs are small, so a fairly large batch can be used; reduce it if the
# machine runs out of (GPU) memory.
batch_size = 128
# Pair each split folder with its preprocessing pipeline. The data is stored
# as one sub-folder per class, which is the layout ImageFolder reads.
image_datasets = {
    split: datasets.ImageFolder(
        root=os.path.join(data_dir, split),
        transform=data_transforms[split],
    )
    for split in ('train', 'valid')
}
# print(image_datasets)
'''
{'train': Dataset ImageFolder
Number of datapoints: 6552
Root location: D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data\train
StandardTransform
Transform: Compose(
Resize(size=[96, 96], interpolation=bilinear, max_size=None, antialias=True)
RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
CenterCrop(size=(64, 64))
RandomHorizontalFlip(p=0.5)
RandomVerticalFlip(p=0.5)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
), 'valid': Dataset ImageFolder
Number of datapoints: 818
Root location: D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data\valid
StandardTransform
Transform: Compose(
Resize(size=[64, 64], interpolation=bilinear, max_size=None, antialias=True)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)}
'''
# Wrap each split in a DataLoader that shuffles the samples and groups them
# into batches; the result is a dict keyed by split name.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=True,
    )
    for split in ('train', 'valid')
}
# print(dataloaders)
# {'train': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07130>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07100>}
# Number of samples in each split.
dataset_sizes = {split: len(image_datasets[split]) for split in ('train', 'valid')}
# Class labels (the class folder names) as discovered in the training split.
class_names = image_datasets['train'].classes
# print(class_names)
'''
['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26',
'27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47',
'48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68',
'69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89',
'9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
'''
# Load the label file that maps each numeric class label to the actual
# flower name.
# The path uses forward slashes throughout: the earlier literal mixed in one
# backslash, and '\配' is an invalid escape sequence in a non-raw string.
# The encoding is given explicitly so the result does not depend on the
# platform's default code page.
with open('D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json', 'r', encoding='utf-8') as f:
    cat_to_name = json.load(f)
# print(cat_to_name)
'''
{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers',
'7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger',
'67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist',
'9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy',
'15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower',
'62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper',
'42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 'sunflower', '66': 'osteospermum', '70': 'tree poppy', '85': 'desert-rose',
'99': 'bromelia', '87': 'magnolia', '5': 'english marigold', '92': 'bee balm', '28': 'stemless gentian', '97': 'mallow', '57': 'gaura', '40': 'lenten rose',
'47': 'marigold', '59': 'orange dahlia', '48': 'buttercup', '55': 'pelargonium', '36': 'ruby-lipped cattleya', '91': 'hippeastrum', '29': 'artichoke', '71': 'gazania',
'90': 'canna lily', '18': 'peruvian lily', '98': 'mexican petunia', '8': 'bird of paradise', '30': 'sweet william', '17': 'purple coneflower', '52': 'wild pansy',
'84': 'columbine', '12': "colt's foot", '11': 'snapdragon', '96': 'camellia', '23': 'fritillary', '50': 'common dandelion', '44': 'poinsettia', '53': 'primula',
'72': 'azalea', '65': 'californian poppy', '80': 'anthurium', '76': 'morning glory', '37': 'cape flower', '56': 'bishop of llandaff', '60': 'pink-yellow dahlia',
'82': 'clematis', '58': 'geranium', '75': 'thorn apple', '41': 'barbeton daisy', '95': 'bougainvillea', '43': 'sword lily', '83': 'hibiscus', '78': 'lotus lotus',
'88': 'cyclamen', '94': 'foxglove', '81': 'frangipani', '74': 'rose', '89': 'watercress', '73': 'water lily', '46': 'wallflower', '77': 'passion flower',
'51': 'petunia'}
'''
四、导入Resnet网络——迁移学习
# 迁移学习:使用前人的网络结构和模型做训练
# 数据量较小时:对模型进行微小改动,如冻住某一部分不进行迭代更新训练,只训练更新较少的网络层
# 数据量中等时:对模型进行改动,如冻住少量部分不进行迭代更新训练,其他部分进行迭代更新训练
# 数据量较大时,整个模型不进行冻结,全部更新训练
# 加载models中提供的模型,并且直接用训练好的权重当作初始化参数
# Many architectures are available ('resnet', 'alexnet', 'vgg', 'squeezenet',
# 'densenet', 'inception'); resnet gives good results.
model_name = 'resnet'
# Reuse the pretrained features: in this project everything before the output
# layer is frozen and not updated during training.
feature_extract = True
# Decide between GPU and CPU.
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
    print('CUDA is available! Training on GPU ...')
else:
    print('CUDA is not available. Training on CPU ...')
device = torch.device("cuda:0" if train_on_gpu else "cpu")
# 18-layer ResNet: the fastest variant; with better hardware a deeper one
# (e.g. 152 layers) can be used instead.
model_ft = models.resnet18()
# resnet18网络结构
# print(model_ft)
'''
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
# resnet18输出层为1000个类别,此次项目预测类别为102,后续修改输出层
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when doing feature extraction.

    Args:
        model: Module whose ``parameters()`` are visited.
        feature_extracting: When truthy, ``requires_grad`` is set to ``False``
            on every parameter so backpropagation no longer updates it. When
            falsy, the model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False
# for param in model_ft.parameters():
# param.requires_grad = False
# for name, param in model_ft.named_parameters():
# # 输出所有权重参数名称及参数值
# print(name)
# print(param)
# 输出层权重参数如下:
'''
fc.weight
Parameter containing:
tensor([[ 0.0018, -0.0213, -0.0421, ..., 0.0064, -0.0055, -0.0221],
[-0.0298, -0.0036, 0.0277, ..., 0.0181, -0.0311, -0.0340],
[-0.0434, -0.0196, -0.0026, ..., 0.0229, 0.0110, 0.0345],
...,
[-0.0097, 0.0290, 0.0391, ..., -0.0091, -0.0028, -0.0373],
[ 0.0073, 0.0256, 0.0238, ..., 0.0004, -0.0275, -0.0347],
[-0.0356, -0.0144, -0.0418, ..., 0.0020, -0.0338, 0.0351]])
fc.bias
Parameter containing:
tensor([ 0.0226, 0.0271, -0.0001, -0.0043, -0.0339, -0.0216, 0.0056, -0.0067,
-0.0024, 0.0062, -0.0425, 0.0148, 0.0340, -0.0365, 0.0084, 0.0407,
-0.0335, 0.0133, 0.0155, -0.0303, -0.0304, 0.0291, -0.0440, -0.0278,
-0.0324, -0.0041, 0.0068, -0.0349, -0.0380, 0.0169, 0.0203, -0.0439,
0.0185, -0.0206, 0.0164, -0.0020, 0.0067, -0.0101, -0.0166, 0.0155,
-0.0302, 0.0095, -0.0119, 0.0124, 0.0065, 0.0435, -0.0298, -0.0022,
-0.0379, 0.0406, -0.0065, -0.0210, 0.0151, 0.0111, -0.0251, -0.0374,
-0.0246, -0.0052, -0.0139, -0.0246, 0.0008, 0.0318, 0.0358, 0.0370,
0.0277, -0.0124, 0.0391, -0.0158, 0.0052, -0.0437, 0.0076, -0.0352,
-0.0045, 0.0270, -0.0332, -0.0177, -0.0420, -0.0116, -0.0224, 0.0410,
-0.0063, 0.0408, -0.0062, -0.0034, -0.0084, -0.0306, 0.0077, 0.0284,
0.0437, -0.0336, -0.0037, -0.0261, -0.0234, -0.0166, 0.0153, -0.0213,
-0.0396, 0.0089, 0.0307, 0.0027, -0.0179, -0.0098, 0.0343, 0.0375,
-0.0432, 0.0126, -0.0060, 0.0370, -0.0030, 0.0219, 0.0140, -0.0165,
0.0409, -0.0134, -0.0352, -0.0265, -0.0312, -0.0205, 0.0268, 0.0429,
-0.0260, 0.0252, 0.0016, -0.0357, 0.0052, -0.0040, 0.0173, -0.0442,
-0.0089, -0.0116, 0.0166, -0.0381, 0.0143, 0.0032, -0.0010, -0.0092,
0.0130, -0.0224, -0.0258, 0.0187, 0.0179, 0.0061, 0.0297, 0.0404,
0.0218, 0.0116, 0.0345, -0.0209, -0.0256, 0.0091, -0.0302, -0.0306,
0.0239, 0.0412, 0.0408, 0.0253, -0.0016, -0.0182, -0.0015, 0.0042,
0.0150, -0.0024, 0.0138, -0.0164, 0.0076, -0.0181, 0.0339, 0.0179,
-0.0289, -0.0327, -0.0399, 0.0419, 0.0386, -0.0345, -0.0321, -0.0413,
0.0390, -0.0339, 0.0359, 0.0012, 0.0298, 0.0134, -0.0128, -0.0251,
0.0435, -0.0231, -0.0432, 0.0385, -0.0423, -0.0137, 0.0170, -0.0412,
-0.0413, 0.0114, 0.0428, -0.0425, -0.0089, -0.0290, -0.0112, 0.0277,
-0.0109, 0.0083, 0.0324, -0.0163, -0.0389, -0.0206, 0.0052, -0.0091,
0.0151, 0.0244, 0.0010, -0.0173, -0.0015, -0.0437, -0.0377, -0.0107,
-0.0185, 0.0059, -0.0198, -0.0395, -0.0129, -0.0363, 0.0408, 0.0418,
-0.0230, -0.0122, -0.0168, 0.0281, 0.0338, 0.0321, 0.0219, -0.0041,
0.0307, -0.0244, -0.0007, -0.0307, 0.0387, 0.0051, 0.0417, -0.0241,
-0.0165, 0.0186, 0.0210, 0.0373, 0.0192, 0.0415, -0.0230, -0.0091,
-0.0324, -0.0416, 0.0304, 0.0065, 0.0398, 0.0036, -0.0232, -0.0392,
0.0109, -0.0108, -0.0320, -0.0032, -0.0138, 0.0357, -0.0247, 0.0363,
-0.0185, -0.0197, -0.0068, -0.0120, -0.0377, -0.0101, 0.0210, -0.0243,
0.0269, 0.0128, -0.0142, -0.0385, 0.0185, 0.0233, -0.0051, -0.0172,
0.0224, -0.0434, 0.0078, 0.0159, -0.0201, -0.0363, -0.0246, 0.0044,
-0.0306, -0.0377, -0.0313, 0.0366, 0.0368, -0.0303, 0.0066, 0.0322,
-0.0143, 0.0343, 0.0233, -0.0337, -0.0211, -0.0060, -0.0167, -0.0189,
-0.0236, 0.0292, 0.0194, 0.0372, -0.0055, 0.0430, 0.0243, -0.0126,
0.0208, 0.0273, 0.0145, 0.0269, 0.0020, -0.0070, -0.0102, 0.0016,
-0.0191, 0.0397, -0.0001, -0.0044, -0.0360, 0.0095, 0.0357, 0.0089,
0.0235, -0.0244, 0.0088, 0.0222, 0.0259, 0.0096, 0.0189, 0.0390,
0.0401, -0.0208, -0.0358, -0.0197, -0.0248, -0.0088, 0.0085, 0.0018,
-0.0132, 0.0289, 0.0287, -0.0400, 0.0063, -0.0054, 0.0399, 0.0123,
0.0102, -0.0326, 0.0194, -0.0049, 0.0104, -0.0171, -0.0080, 0.0429,
-0.0056, -0.0298, 0.0064, 0.0341, -0.0191, 0.0132, -0.0174, 0.0435,
0.0035, 0.0093, -0.0321, -0.0366, 0.0307, 0.0088, -0.0395, 0.0357,
0.0032, -0.0149, -0.0247, 0.0124, 0.0436, 0.0126, -0.0156, 0.0050,
0.0109, 0.0183, -0.0404, -0.0018, -0.0104, -0.0395, 0.0390, -0.0306,
0.0261, -0.0244, -0.0253, -0.0329, 0.0334, 0.0429, -0.0138, -0.0190,
-0.0235, -0.0204, -0.0393, 0.0217, -0.0332, -0.0438, -0.0294, -0.0158,
-0.0103, 0.0238, -0.0419, -0.0408, 0.0113, 0.0366, 0.0221, -0.0190,
-0.0244, 0.0335, 0.0102, -0.0101, -0.0111, -0.0284, -0.0155, -0.0114,
0.0137, 0.0019, -0.0006, -0.0074, 0.0137, -0.0260, -0.0037, 0.0199,
0.0155, -0.0296, 0.0173, 0.0224, 0.0091, -0.0167, 0.0004, 0.0206,
-0.0237, -0.0195, 0.0387, -0.0045, 0.0088, 0.0261, -0.0418, -0.0144,
-0.0375, -0.0106, -0.0354, 0.0411, -0.0053, -0.0248, -0.0010, 0.0323,
-0.0203, 0.0012, 0.0204, -0.0320, 0.0321, 0.0137, 0.0064, -0.0329,
0.0051, -0.0340, 0.0171, 0.0422, 0.0266, 0.0238, -0.0164, 0.0103,
-0.0413, -0.0355, 0.0127, 0.0207, -0.0240, 0.0398, 0.0323, 0.0217,
0.0030, 0.0396, 0.0327, -0.0060, 0.0312, -0.0117, -0.0079, 0.0095,
0.0423, 0.0010, 0.0018, 0.0233, 0.0434, -0.0210, 0.0049, -0.0072,
0.0031, 0.0052, -0.0292, -0.0217, -0.0253, 0.0274, -0.0230, -0.0342,
-0.0149, 0.0137, -0.0057, 0.0344, -0.0327, 0.0370, 0.0142, -0.0194,
0.0109, 0.0366, -0.0046, -0.0203, 0.0088, -0.0117, 0.0263, 0.0020,
-0.0335, 0.0387, -0.0196, 0.0386, -0.0433, -0.0012, 0.0138, -0.0383,
0.0059, -0.0313, -0.0404, -0.0103, -0.0105, -0.0196, -0.0297, 0.0331,
0.0372, -0.0275, 0.0007, 0.0266, 0.0006, 0.0255, -0.0035, -0.0365,
0.0165, -0.0423, 0.0054, 0.0041, 0.0026, -0.0338, 0.0436, -0.0385,
0.0346, -0.0295, -0.0315, 0.0096, 0.0376, 0.0166, -0.0351, 0.0077,
0.0028, 0.0139, 0.0394, 0.0115, -0.0231, -0.0134, 0.0167, 0.0160,
-0.0003, -0.0152, 0.0399, 0.0367, -0.0424, 0.0104, -0.0025, -0.0118,
0.0020, -0.0115, -0.0253, -0.0290, 0.0042, -0.0025, -0.0155, 0.0101,
-0.0170, 0.0135, 0.0283, 0.0441, 0.0294, -0.0083, -0.0428, -0.0267,
-0.0247, -0.0344, 0.0177, 0.0173, 0.0029, -0.0369, -0.0150, -0.0193,
0.0428, -0.0401, 0.0377, 0.0138, -0.0136, -0.0104, 0.0325, 0.0335,
-0.0152, -0.0014, -0.0287, 0.0375, -0.0426, 0.0393, 0.0016, -0.0244,
0.0394, 0.0212, -0.0019, -0.0024, -0.0212, 0.0249, -0.0244, -0.0144,
0.0227, 0.0073, 0.0412, -0.0194, -0.0300, 0.0084, -0.0273, 0.0157,
-0.0270, -0.0411, 0.0153, -0.0040, -0.0268, 0.0048, 0.0164, -0.0165,
-0.0089, 0.0328, 0.0345, -0.0013, -0.0226, -0.0097, -0.0234, 0.0156,
0.0236, 0.0076, -0.0349, -0.0442, -0.0293, 0.0132, -0.0420, -0.0020,
0.0059, -0.0231, -0.0187, -0.0304, 0.0270, 0.0382, 0.0407, -0.0244,
0.0394, -0.0204, -0.0356, -0.0287, -0.0221, -0.0330, 0.0249, 0.0247,
0.0250, -0.0180, -0.0182, -0.0358, -0.0196, -0.0441, 0.0407, -0.0027,
0.0177, 0.0167, -0.0258, -0.0097, -0.0210, -0.0106, -0.0331, 0.0008,
-0.0392, -0.0378, 0.0418, 0.0093, -0.0442, 0.0419, -0.0233, 0.0076,
0.0394, 0.0428, 0.0027, -0.0117, -0.0161, 0.0153, 0.0035, 0.0222,
0.0375, -0.0300, 0.0200, 0.0285, -0.0120, 0.0074, 0.0054, -0.0117,
-0.0161, -0.0088, 0.0428, 0.0103, 0.0136, 0.0376, -0.0172, 0.0086,
0.0183, 0.0381, 0.0109, -0.0177, -0.0416, -0.0351, 0.0338, -0.0404,
0.0324, 0.0041, -0.0172, 0.0002, -0.0415, 0.0423, 0.0172, -0.0068,
0.0370, 0.0321, 0.0040, 0.0401, 0.0030, 0.0238, 0.0102, -0.0437,
0.0204, 0.0432, 0.0399, -0.0186, -0.0067, 0.0205, -0.0098, 0.0009,
0.0019, 0.0317, -0.0237, 0.0062, 0.0097, -0.0312, -0.0033, 0.0028,
-0.0021, 0.0006, -0.0320, -0.0196, 0.0363, 0.0285, 0.0321, -0.0132,
0.0344, -0.0232, 0.0379, -0.0166, 0.0311, -0.0174, 0.0431, 0.0006,
-0.0066, -0.0344, -0.0076, 0.0245, 0.0286, -0.0388, 0.0114, 0.0204,
0.0137, 0.0387, 0.0206, -0.0392, -0.0109, 0.0375, 0.0269, 0.0232,
-0.0362, 0.0235, -0.0137, 0.0303, 0.0389, -0.0068, 0.0306, 0.0273,
0.0264, -0.0074, 0.0315, -0.0291, -0.0027, -0.0061, 0.0188, -0.0123,
-0.0360, -0.0266, 0.0292, 0.0248, 0.0127, -0.0251, -0.0426, 0.0066,
-0.0005, -0.0162, -0.0236, -0.0330, 0.0339, 0.0319, 0.0135, 0.0260,
-0.0389, -0.0375, -0.0192, -0.0079, -0.0066, -0.0261, -0.0441, -0.0042,
0.0086, 0.0291, 0.0283, 0.0028, -0.0137, -0.0218, 0.0109, -0.0052,
-0.0213, -0.0108, -0.0354, -0.0012, 0.0006, -0.0111, -0.0066, 0.0401,
-0.0313, 0.0203, 0.0060, 0.0339, 0.0072, 0.0148, 0.0047, 0.0363,
-0.0117, 0.0338, 0.0238, -0.0367, 0.0093, 0.0048, 0.0222, 0.0164,
0.0399, 0.0005, 0.0056, 0.0023, -0.0330, 0.0241, -0.0403, 0.0047,
0.0312, -0.0148, 0.0146, 0.0268, 0.0439, -0.0265, -0.0044, 0.0441,
0.0395, -0.0164, 0.0117, 0.0343, 0.0355, 0.0243, -0.0091, 0.0083,
-0.0213, -0.0196, -0.0082, 0.0070, 0.0403, -0.0407, 0.0270, 0.0375,
0.0207, 0.0207, -0.0049, 0.0020, 0.0429, -0.0432, -0.0368, -0.0040,
0.0035, 0.0114, -0.0345, -0.0286, 0.0232, 0.0342, 0.0437, 0.0193,
0.0030, 0.0375, 0.0161, 0.0172, -0.0069, -0.0118, -0.0235, -0.0155,
-0.0029, -0.0035, 0.0372, -0.0343, 0.0183, -0.0057, -0.0093, -0.0322,
-0.0303, 0.0275, -0.0364, -0.0240, -0.0090, -0.0058, -0.0055, 0.0315,
-0.0020, 0.0268, -0.0305, -0.0286, -0.0083, 0.0015, -0.0226, 0.0249,
-0.0133, -0.0359, 0.0393, 0.0058, -0.0354, 0.0011, 0.0424, 0.0363,
0.0405, 0.0006, -0.0422, 0.0363, -0.0298, -0.0319, -0.0131, 0.0021,
0.0276, -0.0302, -0.0350, -0.0433, 0.0185, 0.0263, 0.0307, -0.0093,
0.0377, 0.0031, -0.0115, -0.0297, -0.0327, 0.0103, 0.0179, -0.0071,
0.0029, -0.0345, -0.0335, -0.0184, 0.0426, 0.0169, 0.0039, -0.0071,
0.0421, -0.0185, -0.0235, -0.0288, -0.0305, -0.0199, 0.0091, -0.0162,
-0.0269, 0.0172, 0.0330, -0.0416, -0.0347, -0.0392, -0.0148, -0.0167])
'''
# 输出可以进行梯度更新的权重参数名称——————————因为前面修改了所有的参数更新为False,此判断无输出
# if param.requires_grad == True:
# print('\t', name)
五、模型训练及参数配置
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a ResNet-18 whose output layer is replaced for ``num_classes``.

    Args:
        model_name: Object exposing a ``resnet18`` factory; the caller in this
            file passes the ``torchvision.models`` module.
        num_classes: Number of outputs of the new fully connected layer.
        feature_extract: Forwarded to ``set_parameter_requires_grad``; when
            truthy every existing parameter is frozen.
        use_pretrained: When true, load the pretrained ImageNet weights
            (downloaded on first use); otherwise only the architecture with
            randomly initialized weights is created.

    Returns:
        The model with its ``fc`` layer replaced by a new ``nn.Linear``.
    """
    # ``pretrained=`` is deprecated in current torchvision; ``weights=`` is the
    # supported argument. "IMAGENET1K_V1" is the checkpoint that
    # ``pretrained=True`` used to select, so the loaded weights are unchanged.
    weights = "IMAGENET1K_V1" if use_pretrained else None
    model_ft = model_name.resnet18(weights=weights)
    # Freeze the existing parameters (only when feature_extract is truthy).
    set_parameter_requires_grad(model_ft, feature_extract)
    # Input width of the original fully connected layer (512 for ResNet-18:
    # Linear(in_features=512, out_features=1000)).
    num_ftrs = model_ft.fc.in_features
    # Replace the 1000-class head with one sized for this task. The new layer
    # is created after the freeze, so its parameters remain trainable.
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    return model_ft
# input_size好像没用处,删去
# model_ft, input_size = initialize_model(models, 102, feature_extract, use_pretrained=True)
# Build the 102-class model from the pretrained ResNet-18.
model_ft = initialize_model(models, 102, feature_extract, use_pretrained=True)
# Move the model to the selected device (GPU or CPU).
model_ft = model_ft.to(device)
# Where the best checkpoint is saved; the file name is up to you.
filename = 'D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/best_my.pt'
# Work out which parameters the optimizer should update and list their names.
print("Params to learn:")
trainable = [
    (name, param)
    for name, param in model_ft.named_parameters()
    if param.requires_grad
]
for name, _ in trainable:
    print("\t", name)
if feature_extract:
    # In this project only the newly added fully connected layer is trainable.
    params_to_update = [param for _, param in trainable]
else:
    # Train every layer.
    params_to_update = model_ft.parameters()
# Optimizer: updates only the parameters collected in params_to_update.
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: every step_size epochs the rate is multiplied by gamma.
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
# Loss function.
criterion = nn.CrossEntropyLoss()
# 训练函数
def train_model(model, dataloaders, criterion, optimizer, filename, num_epochs=25, lr_scheduler=None):
    """Train and validate ``model`` every epoch, checkpointing the best one.

    Args:
        model: Network to train; it is moved to the module-level ``device``.
        dataloaders: Mapping with ``'train'`` and ``'valid'`` loaders that
            yield ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        filename: Path the best checkpoint is written to with ``torch.save``.
        num_epochs: Number of epochs to run.
        lr_scheduler: Scheduler stepped once at the end of each epoch. When
            ``None`` the module-level ``scheduler`` is used, which is what
            this function relied on before the parameter existed.

    Returns:
        ``(model, val_acc_history, train_acc_history, valid_losses,
        train_losses, LRs)``. ``model`` holds the weights of the epoch with
        the highest validation accuracy. The four histories get one entry per
        epoch; ``LRs`` starts with the initial learning rate and then gets one
        entry per epoch.
    """
    if lr_scheduler is None:
        # Backward-compatible default: fall back to the module-level object.
        lr_scheduler = scheduler

    # Start time, used for the elapsed-time reports.
    since = time.time()
    # Best validation accuracy seen so far.
    best_acc = 0
    model.to(device)

    # Per-epoch metrics.
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    # optimizer.param_groups is a list of dicts (one per parameter group);
    # each dict carries 'params', 'lr', 'betas', 'eps', 'weight_decay', ...
    # Here there is a single group, so index 0 gives the current rate.
    LRs = [optimizer.param_groups[0]['lr']]

    # Snapshot of the best weights; initialized with the current ones and
    # replaced whenever validation accuracy improves.
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch runs a training pass followed by a validation pass.
        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            # Accumulated loss and number of correct predictions.
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Clear gradients left over from the previous batch.
                optimizer.zero_grad()

                # Gradients are only tracked while training; validation does
                # not need an autograd graph, which saves memory and time.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    # Index of the highest score among the class outputs.
                    _, preds = torch.max(outputs, 1)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # criterion returns the batch mean, so weight it by the batch
                # size (inputs is (batch, c, h, w)) to sum over the dataset.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            # Average loss and accuracy over the whole split.
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep and save the model whenever validation accuracy improves.
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # Keys are the layer names, values the trained weights.
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
            if is_train:
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
        # Learning-rate decay, once per epoch.
        lr_scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # '{:.4f}' (four decimals) matches the per-epoch reports; the earlier
    # '{:4f}' only set a minimum width.
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Return the model with the best weights loaded, ready for evaluation.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
# Run the training for 20 epochs and unpack the model plus its histories.
(
    model_ft,
    val_acc_history,
    train_acc_history,
    valid_losses,
    train_losses,
    LRs,
) = train_model(
    model=model_ft,
    dataloaders=dataloaders,
    criterion=criterion,
    optimizer=optimizer_ft,
    filename=filename,
    num_epochs=20,
)
训练过程展示:
六、验证最佳训练模型
# Reload the best checkpoint written during training.
# map_location lets a checkpoint saved from a GPU run be loaded on a
# CPU-only machine as well.
checkpoint = torch.load(filename, map_location=device)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# Take one batch from the (shuffled) validation loader for a quick check.
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)
model_ft.eval()
# Inference only, so gradient tracking is switched off. The batch goes to the
# same ``device`` the model was moved to, which covers both GPU and CPU.
with torch.no_grad():
    output = model_ft(images.to(device))
# Index of the highest-scoring class for every image.
_, preds_tensor = torch.max(output, 1)
# Convert to a numpy array. .cpu() returns the tensor itself when it already
# lives on the CPU, so one expression works for both devices.
preds = np.squeeze(preds_tensor.cpu().numpy())
# 展示预测结果
def im_convert(tensor):
    """Turn a normalized image tensor into an (H, W, C) array for display.

    The tensor is copied to the CPU, converted to a numpy array with
    single-element dimensions removed, reordered from (C, H, W) to (H, W, C),
    de-normalized with ``x * std + mean`` and finally clipped to [0, 1].
    The array is printed once before and once after clipping.
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))
    # Work on a detached CPU copy so the caller's tensor is left alone.
    pixels = tensor.to("cpu").clone().detach().numpy().squeeze()
    # Axes (0, 1, 2) = (c, h, w) become (1, 2, 0) = (h, w, c); then undo the
    # normalization.
    pixels = pixels.transpose(1, 2, 0) * channel_std + channel_mean
    print(pixels)
    # Values outside [0, 1] are cut off at the nearest bound.
    pixels = pixels.clip(0, 1)
    print(pixels)
    return pixels
# Show up to 8 images of the batch in a 2x4 grid, titled
# "predicted (actual)": green when the prediction is right, red otherwise.
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2
# Never index past the end of the batch if it holds fewer than 8 images.
for idx in range(min(columns * rows, len(images))):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # preds and labels hold ImageFolder class *indices* (positions in
    # class_names), not the folder names that cat_to_name is keyed by.
    # They must be mapped through class_names first: looking up str(index)
    # directly names the wrong flower and raises KeyError for index 0.
    pred_name = cat_to_name[class_names[preds[idx]]]
    true_name = cat_to_name[class_names[labels[idx].item()]]
    ax.set_title("{} ({})".format(pred_name, true_name),
                 color=("green" if pred_name == true_name else "red"))
plt.show()
验证结果展示: