文章目录
引言
本篇博客可以看做是对:
pytorch实现手写英文字母识别 和 Pytorch搭建预训练VGG16实现10 Monkey Species Classification
这两篇博客中代码的重构和一些细节上的调整,并对功能类似的部分进行了模块化封装。
● 还未实现的部分:混合精度训练、多GPU训练等(待补充)。
● 任何逻辑不完善的地方,欢迎指出或讨论。
● 后续随缘更新。
完整代码可在github获取👇https://github.com/Scienthusiasts/Classification_pytorch
若对你有帮助,不妨star支持一下
1.数据集读取部分dataloader.py
1.1.分类数据集的数据组织形式
images
├─valid
│ ├─apple_pie
│ ├─baby_back_ribs
│ … …
│ └─waffles
└─train
├─apple_pie
├─baby_back_ribs … …
└─waffles
images
为图像根目录,train
为训练集图像, valid
为验证集图像,对应的子目录以类别命名,用于存储不同类别的图像。
1.2自定义数据增强/数据预处理类
数据增强/预处理方法分为三类,分别是训练时增强,验证时增强和测试时增强。训练时增强包含最全的数据增强操作,并依概率随机对每张图像执行;验证时增强只保留最基础的数据预处理方法,不包含数据增强,测试时增强只针对最终的可视化,不包含对图像的归一化处理。
class Transforms():
'''数据预处理/数据增强(基于albumentations库)
'''
def __init__(self, imgSize):
# 训练时增强
self.trainTF = A.Compose([
# 随机旋转
A.Rotate(limit=15, p=0.5),
# 最长边限制为imgSize
A.LongestMaxSize(max_size=imgSize),
# 随机镜像
A.HorizontalFlip(p=0.5),
# 参数:随机色调、饱和度、值变化
A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),
# 随机明亮对比度
A.RandomBrightnessContrast(p=0.2),
# 高斯噪声
A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),
A.OneOf([
# 使用随机大小的内核将运动模糊应用于输入图像
A.MotionBlur(p=0.2),
# 中值滤波
A.MedianBlur(blur_limit=3, p=0.1),
# 使用随机大小的内核模糊输入图像
A.Blur(blur_limit=3, p=0.1),
], p=0.2),
# 较短的边做padding
A.PadIfNeeded(imgSize, imgSize, border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# 验证时增强
self.validTF = A.Compose([
# 最长边限制为imgSize
A.LongestMaxSize(max_size=imgSize),
# 较短的边做padding
A.PadIfNeeded(imgSize, imgSize, border_mode=0, mask_value=[0,0,0]),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# 可视化增强(只reshape)
self.visTF = A.Compose([
# 最长边限制为imgSize
A.LongestMaxSize(max_size=imgSize),
# 较短的边做padding
A.PadIfNeeded(imgSize, imgSize, border_mode=0, mask_value=[0,0,0]),
])
1.3.重写torch.utils.data.Dataset
数据集读取类
基本逻辑就是遍历数据集下每个类别对应的文件夹,并获取文件夹中的图像,图像的类别(标签)根据图像所在的文件夹划分。
class MyDataset(data.Dataset):
'''有监督分类任务对应的数据集读取方式
'''
def __init__(self, dir, mode, imgSize):
'''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径
Args:
:param dir: 图像数据集的根目录
:param mode: 模式(train/valid)
:param imgSize: 网络要求输入的图像尺寸
Returns:
precision, recall
'''
self.tf = Transforms(imgSize = imgSize)
# 记录数据集大小
self.dataSize = 0
# 数据集类别数
self.labelsNum = len(os.listdir(os.path.join(dir, mode)))
# 训练/验证
self.mode = mode
# 数据预处理方法
self.tf = Transforms(imgSize=imgSize)
# 遍历所有类别
self.imgPathList, self.labelList = [], []
'''对类进行排序,很重要!!!,否则会造成分类时标签匹配不上导致评估的精度很低(默认按字符串,如果类是数字还需要更改)'''
catDirs = sorted(os.listdir(os.path.join(dir, mode)))
for idx, cat in enumerate(catDirs):
catPath = os.path.join(dir, mode, cat)
labelFiles = os.listdir(catPath)
# 每个类别里图像数
length = len(labelFiles)
# 存放图片路径
self.imgPathList += [os.path.join(catPath, labelFiles[i]) for i in range(length)]
# 存放图片对应的标签(根据所在文件夹划分)
self.labelList += [idx for _ in range(length)]
self.dataSize += length
def __getitem__(self, item):
'''重载data.Dataset父类方法, 获取数据集中数据内容
'''
# 读取图片
img = Image.open(self.imgPathList[item]).convert('RGB')
img = np.array(img)
# 获取image对应的label
label = self.labelList[item]
# 数据预处理/数据增强
if self.mode=='train':
transformed = self.tf.trainTF(image=img)
if self.mode=='valid':
transformed = self.tf.validTF(image=img)
img = transformed['image']
return img.transpose(2,1,0), torch.LongTensor([label])
def __len__(self):
'''重载data.Dataset父类方法, 返回数据集大小
'''
return self.dataSize
def get_cls_num(self):
'''返回数据集类别数
'''
return self.labelsNum
1.4.模块测试样例
# for test only
if __name__ == '__main__':
datasetDir = 'E:/datasets/Classification/food-101/images'
mode = 'train'
bs = 64
seed = 22
seed_everything(seed)
train_data = MyDataset(datasetDir, mode, imgSize=224)
print(f'数据集大小:{train_data.__len__()}')
print(f'数据集类别数:{train_data.get_cls_num()}')
train_data_loader = data.DataLoader(dataset = train_data, batch_size=bs, shuffle=True)
# 获取label name
catNames = sorted(os.listdir(os.path.join(datasetDir, mode)))
# 可视化一个batch里的图像
from utils import visBatch
visBatch(train_data_loader, catNames)
# 输出数据格式
for step, batch in enumerate(train_data_loader):
print(batch[0].shape, batch[1].shape)
break
输出:
数据集大小:75750
数据集类别数:101
torch.Size([64, 3, 224, 224]) torch.Size([64, 1])
2.网络部分mynet.py
网络模块基于微调timm
库里提供的模型,基本的逻辑就是将原来模型的分类头的分类数替换为当前数据集的分类数,Backbone部分保持不变,并使用ImageNet预训练权重初始化,训练时可以冻结Backbone只训练分类头,或者微调整个网络。
timm
库里提供的模型名称和权重可以从huggingface中获取:https://huggingface.co/timm?sort_models=downloads#models
2.1.自定义分类网络torch.nn.Module
想要添加更多的Backbone,可以在modelList
和分支语句中添加相应内容:
class Model(nn.Module):
'''Backbone
'''
def __init__(self, catNums:int, modelType:str, loadckpt=False, pretrain=True, froze=True):
'''网络初始化
Args:
:param catNums: 数据集类别数
:param modelType: 使用哪个模型(timm库里的模型)
:param loadckpt: 是否导入模型权重
:param pretrain: 是否用预训练模型进行初始化(是则输入权重路径)
:param froze: 是否只训练分类头
Returns:
None
'''
super(Model, self).__init__()
# 模型接到线性层的维度
modelList = {
'mobilenetv3_small_100.lamb_in1k': 1024,
'mobilenetv3_large_100.ra_in1k': 1280,
'vit_base_patch16_224.augreg2_in21k_ft_in1k': 768,
'efficientnet_b5.sw_in12k_ft_in1k': 2048,
'resnet50.a1_in1k': 2048,
'vgg16.tv_in1k': 4096,
}
# 加载模型
self.backbone = timm.create_model(modelType, pretrained=pretrain)
# 删除原来的分类头并添加新的分类头(self.backbone就是去除了分类头的原始完整模型)
baseModel = modelType.split('.')[0]
if(baseModel in ['mobilenetv3_small_100', 'mobilenetv3_large_100', 'efficientnet_b5']):
self.backbone.classifier = nn.Identity()
self.head = nn.Linear(modelList[modelType], catNums)
if(baseModel=='vit_base_patch16_224'):
self.backbone.head = nn.Identity()
self.head = nn.Linear(modelList[modelType], catNums)
if(baseModel=='resnet50'):
self.backbone.fc = nn.Identity()
self.head = nn.Linear(modelList[modelType], catNums)
if(baseModel=='vgg16'):
self.backbone.head.fc = nn.Identity()
self.head = nn.Linear(modelList[modelType], catNums)
# 是否导入预训练权重
if loadckpt:
self.load_state_dict(torch.load(loadckpt))
# 是否只训练线性层
if froze:
for param in self.backbone.parameters():
param.requires_grad_(False)
def forward(self, x):
'''前向传播
'''
feat = self.backbone(x)
out = self.head(feat)
return out
2.2.模块测试样例
# for test only
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = Model(catNums=101, modelType='mobilenetv3_large_100.ra_in1k', pretrain=True, froze=True).to(device)
'''验证 1'''
# print(model)
'''验证 2'''
# summary(model, input_size=[(3, 224, 224)])
'''验证 3'''
x = torch.rand((4, 3, 600, 600)).to(device)
out = model(x)
print(out.shape)
3.训练/验证/测试模块runner.py
所有的训练、验证(一个epoch结束)和测试(推理一张图像)的pipeline集成在自定义的Runner
类中
3.1.Runner
类初始化
在Runner类的初始化阶段,一些与训练有关的模块会被定义,比如日志模块(用于训练时实时打印训练情况)、tensorboard模块、数据集,模型、损失函数、是否恢复断点等等。基于传入参数mode
的不同,初始化的模块也会有所不同。
def __init__(self, timm_model_name, img_size, ckpt_load_path, dataset_dir, epoch, bs, lr, log_dir, log_interval, pretrain, froze, optim_type, mode, resume=None, seed=0):
'''Runner初始化
Args:
:param timm_model_name: 模型名称(timm)
:param img_size: 统一图像尺寸的大小
:param ckpt_load_path: 预加载的权重路径
:param dataset_dir: 数据集根目录
:param eopch: 训练批次
:param bs: 训练batch size
:param lr: 学习率
:param log_dir: 日志文件保存目录
:param log_interval: 训练或验证时隔多少bs打印一次日志
:param pretrain: backbone是否用ImageNet预训练权重初始化
:param froze: 是否冻结Backbone只训练分类头
:param optim_type: 优化器类型
:param mode: 训练模式:train/eval/test
:param resume: 是否从断点恢复训练
:param seed: 固定全局种子
Returns:
None
'''
# 设置全局种子
seed_everything(seed)
self.timm_model_name = timm_model_name
self.img_size = img_size
self.ckpt_load_path = ckpt_load_path
self.dataset_dir = dataset_dir
self.epoch = epoch
self.bs = bs
self.lr = lr
self.log_dir = log_dir
self.log_interval = log_interval
self.pretrain = pretrain
self.froze = froze
self.mode = mode
self.optim_type = optim_type
self.cats = os.listdir(os.path.join(self.dataset_dir, 'valid'))
self.cls_num = len(self.cats)
'''GPU/CPU'''
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
'''日志模块'''
if mode == 'train' or mode == 'eval':
self.logger, self.log_save_path = self.myLogger()
'''训练/验证时参数记录模块'''
json_save_dir, _ = os.path.split(self.log_save_path)
self.argsHistory = ArgsHistory(json_save_dir)
'''实例化tensorboard summaryWriter对象'''
if mode == 'train':
self.tb_writer = SummaryWriter(log_dir=os.path.join(self.log_dir, self.log_save_path.split('.')[0]))
'''导入数据集'''
if mode == 'train':
# 导入训练集
self.train_data = MyDataset(dataset_dir, 'train', imgSize=img_size)
self.train_data_loader = DataLoader(dataset = self.train_data, batch_size=bs, shuffle=True, num_workers=2)
if mode == 'train' or mode == 'eval':
# 导入验证集
self.val_data = MyDataset(dataset_dir, 'valid', imgSize=img_size)
self.val_data_loader = DataLoader(dataset = self.val_data, batch_size=1, shuffle=False, num_workers=2)
'''导入模型'''
self.model = Model(catNums=self.cls_num, modelType=timm_model_name, loadckpt=ckpt_load_path, pretrain=pretrain, froze=froze).to(self.device)
'''定义损失函数(多分类交叉熵损失)'''
if mode == 'train' or mode == 'eval':
self.loss_func = nn.CrossEntropyLoss()
'''定义优化器(自适应学习率的带动量梯度下降方法)'''
if mode == 'train':
self.optimizer, self.scheduler = self.defOptimSheduler()
'''当恢复断点训练'''
self.start_epoch = 0
if resume != None:
checkpoint = torch.load(resume)
self.start_epoch = checkpoint['epoch'] + 1 # +1是因为从当前epoch的下一个epoch开始训练
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# 导入上一次中断训练时的args
json_dir, _ = os.path.split(resume)
self.argsHistory.loadRecord(json_dir)
# 打印日志
if mode == 'train':
self.logger.info('训练集大小: %d' % self.train_data.__len__())
if mode == 'train' or mode == 'eval':
self.logger.info('验证集大小: %d' % self.val_data.__len__())
self.logger.info('数据集类别数: %d' % self.cls_num)
if mode == 'train':
self.logger.info(f'损失函数: {self.loss_func}')
self.logger.info(f'优化器: {self.optimizer}')
if mode == 'train' or mode == 'eval':
self.logger.info(f'全局种子: {seed}')
self.logger.info('='*100)
由于其中有些初始化方法过于冗长,因此封装成为类中的方法:
3.1.1.日志模块初始化
日志模块基于logging
库,会初始化以下内容:
- 定义文件日志(这部分日志会写入日志文件)
- 定义终端日志(这部分日志会打印在终端上)
- 定义文件日志保存路径(根据self.mode的不同而不同)
def myLogger(self):
'''生成日志对象
'''
logger = logging.getLogger('runer')
logger.setLevel(level=logging.DEBUG)
# 日志格式
formatter = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s')
if self.mode == 'train':
# 写入文件的日志
self.log_dir = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_train")
# 日志文件保存路径
log_save_path = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_train.log")
if self.mode == 'eval':
# 写入文件的日志
self.log_dir = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_val")
# 日志文件保存路径
log_save_path = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_val.log")
if not os.path.isdir(self.log_dir):os.makedirs(self.log_dir)
file_handler = logging.FileHandler(log_save_path, encoding="utf-8", mode="a")
file_handler.setLevel(level=logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 终端输出的日志
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
return logger, log_save_path
3.1.2.自定义变量记录类
ArgsHistory
以iter为最小记录单位,记录train或val过程中的一些变量(比如 loss, acc, lr等),并将记录内容在每个epoch结束以json文件保存。可以方便在训练结束后对这些变量进行可视化。
ArgsHistory.recoard
方法通过传参自动添加新变量,或在已有变量列表的末尾进行更新,无需提前定义变量名。
class ArgsHistory():
'''记录train或val过程中的一些变量(比如 loss, acc, lr等)
'''
def __init__(self, json_save_dir):
self.json_save_dir = json_save_dir
self.args_history_dict = {}
def record(self, key, value):
'''记录args
Args:
:param key: 要记录的当前变量的名字
:param value: 要记录的当前变量的数值
Returns:
None
'''
# 可能存在json格式不支持的类型, 因此统一转成float类型
value = float(value)
# 如果日志中还没有这个变量,则新建
if key not in self.args_history_dict.keys():
self.args_history_dict[key] = []
# 更新
self.args_history_dict[key].append(value)
def saveRecord(self):
'''以json格式保存args
'''
if not os.path.isdir(self.json_save_dir):os.makedirs(self.json_save_dir)
json_save_path = os.path.join(self.json_save_dir, 'args_history.json')
# 保存
with open(json_save_path, 'w') as json_file:
json.dump(self.args_history_dict, json_file)
def loadRecord(self, json_load_dir):
'''导入上一次训练时的args(一般用于resume)
'''
json_path = os.path.join(json_load_dir, 'args_history.json')
with open(json_path, "r", encoding="utf-8") as json_file:
self.args_history_dict = json.load(json_file)
3.1.3.定义优化器和学习率衰减策略
优化器支持pytorch官方提供的sgd
, adam
, adamw
优化器,优化策略基于timm.scheduler.CosineLRScheduler
,采用 warmup+余弦退火。
def defOptimSheduler(self):
'''定义优化器和学习率衰减策略
'''
optimizer = {
# adam会导致weight_decay错误,使用adam时建议设置为 0
'adamw' : torch.optim.AdamW(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=0),
'adam' : torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=0),
'sgd' : torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, nesterov=True, weight_decay=0)
}[self.optim_type]
# 使用warmup+余弦退火学习率
scheduler = CosineLRScheduler(
optimizer=optimizer,
t_initial=self.epoch*len(self.train_data_loader), # 总迭代数
lr_min=self.lr*0.01, # 余弦退火最低的学习率
warmup_t=round(self.epoch/12)*len(self.train_data_loader), # 学习率预热阶段的epoch数量
warmup_lr_init=self.lr*0.01, # 学习率预热阶段的lr起始值
)
return optimizer, scheduler
3.2.训练pipeline trainer
训练pipeline的基本的流程如下:
训练一个epoch(训练集)→在验证集上评估→每个epoch结束保存checkpoint→每个epoch结束打印日志(评估结果)
当所有epoch结束时:重新在验证集上使用最佳权重评估→打印或可视化各种评估指标
def trainer(self):
'''把pytorch训练代码独自分装成一个函数
'''
for epoch in range(self.start_epoch, self.epoch):
'''一个epoch的训练'''
self.fitEpoch(epoch)
'''一个epoch的验证'''
self.evaler(epoch, self.log_dir)
'''保存网络权重'''
self.saveCkpt(epoch)
'''打印日志(一个epoch结束)'''
self.printLog('epoch', 0, epoch, len(self.val_data_loader))
'''结果评估'''
self.model.load_state_dict(torch.load(os.path.join(self.log_dir, 'best.pt')))
# 评估各种指标
self.evaler(self.log_dir)
3.2.1.训练一个epoch的pipeline fitEpoch
主要包括前向反向,更新梯度和学习率, 记录训练时变量,打印日志等步骤。
def fitEpoch(self, epoch):
'''对一个epoch进行训练的流程
'''
self.model.train()
# 一个Epoch包含几轮Batch
train_batch_num = len(self.train_data_loader)
for step, batch in enumerate(self.train_data_loader):
# [bs, channel, w, h] -> [bs, w*h, channel]
with torch.no_grad():
x = batch[0].to(self.device)
y = batch[1].to(self.device).reshape(-1) # 标签[batch_size, 1]
# 前向传播
output = self.model(x) # [batchsize, cls_num]
# 计算loss
loss = self.loss_func(output, y)
# 预测结果对应置信最大的那个下标
pre_lab = torch.argmax(output, 1)
# 计算一个batchsize的准确率
train_acc = torch.sum(pre_lab == y.data) / x.shape[0]
# 记录args(lr, loss, acc)
self.recoardArgs(mode='train', loss=loss.item(), acc=train_acc)
# 记录tensorboard
self.recordTensorboardLog('train', epoch, train_batch_num, step)
# 打印日志
self.printLog('train', step, epoch, train_batch_num)
# 将上一次迭代计算的梯度清零
self.optimizer.zero_grad()
# 反向传播计算梯度
loss.backward()
# 更新参数
self.optimizer.step()
# 更新学习率
self.scheduler.step(epoch * train_batch_num + step)
3.2.1.1.recoardArgs
def recoardArgs(self, mode, loss=None, acc=None, mAP=None, mF1Score=None):
'''训练/验证过程中记录变量(每个iter都会记录, 不间断)
Args:
:param mode: 模式(train, epoch)
:param loss: 损失
:param acc: 准确率
Returns:
None
'''
if mode == 'train':
current_lr = self.optimizer.param_groups[0]['lr']
self.argsHistory.record('lr', current_lr)
self.argsHistory.record('train_loss', loss)
self.argsHistory.record('train_acc', acc)
# 一个epoch结束后val评估结果的平均值
if mode == 'epoch':
self.argsHistory.record('mean_val_acc', acc)
self.argsHistory.record('val_mAP', mAP)
self.argsHistory.record('val_mF1Score', mF1Score)
self.argsHistory.saveRecord()
3.2.1.2.recordTensorboardLog
def recordTensorboardLog(self, mode, epoch, batch_num=None, step=None):
'''训练过程中记录tensorBoard日志
Args:
:param mode: 模式(train, val, epoch)
:param step: 当前迭代到第几个batch
:param batch_num: 当前batch的大小
Returns:
None
'''
if mode == 'train':
step = epoch * batch_num + step
loss = self.argsHistory.args_history_dict['train_loss'][-1]
acc = self.argsHistory.args_history_dict['train_acc'][-1]
self.tb_writer.add_scalar('train_loss', loss, step)
self.tb_writer.add_scalar('train_acc', acc, step)
if mode == 'epoch':
acc = self.argsHistory.args_history_dict['mean_val_acc'][-1]
mAP = self.argsHistory.args_history_dict['val_mAP'][-1]
mF1Score = self.argsHistory.args_history_dict['val_mF1Score'][-1]
self.tb_writer.add_scalar('mean_valid_acc', acc, epoch)
self.tb_writer.add_scalar('valid_mAP', mAP, epoch)
self.tb_writer.add_scalar('valid_mF1Score', mF1Score, epoch)
可视化tensorboard:
3.2.1.3.printLog
def printLog(self, mode, step, epoch, batch_num):
'''训练/验证过程中打印日志
Args:
:param mode: 模式(train, val, epoch)
:param step: 当前迭代到第几个batch
:param epoch: 当前迭代到第几个epoch
:param batch_num: 当前batch的大小
:param loss: 当前batch的loss
:param acc: 当前batch的准确率
:param best_epoch: 当前最佳模型所在的epoch
Returns:
None
'''
lr = self.optimizer.param_groups[0]['lr']
if mode == 'train':
# 每间隔self.log_interval个iter才打印一次
if step % self.log_interval == 0:
loss = self.argsHistory.args_history_dict['train_loss'][-1]
acc = self.argsHistory.args_history_dict['train_acc'][-1]
log = ("Epoch(train) [%d][%d/%d] lr: %8f train_loss: %5f train_acc.: %5f") % (epoch+1, step, batch_num, lr, loss, acc)
self.logger.info(log)
elif mode == 'epoch':
acc_list = self.argsHistory.args_history_dict['mean_val_acc']
mAP_list = self.argsHistory.args_history_dict['val_mAP']
mF1Score_list = self.argsHistory.args_history_dict['val_mF1Score']
# 找到最高准确率对应的epoch
best_epoch = acc_list.index(max(acc_list)) + 1
self.logger.info('=' * 100)
log = ("Epoch [%d] mean_val_acc.: %.5f mAP: %.5f mF1Score: %.5f best_val_epoch: %d" % (epoch+1, acc_list[-1], mAP_list[-1], mF1Score_list[-1], best_epoch))
self.logger.info(log)
self.logger.info('=' * 100)
值得注意的是,3.2.1.1,3.2.1.2 和3.2.1.3的基本逻辑是,先使用recoardArgs
将变量记录到字典argsHistory
中去,后续recordTensorboardLog
和printLog
需要打印或记录哪些变量直接从字典中获取即可,省去了再将变量参数作为函数传参。
3.2.2.评估一个epoch的pipeline evaler
这部分的流程直接调用验证pipeline,固将在相应章节介绍。
3.2.3.保存网络权重 saveCkpt
saveCkpt
在一个epoch结束后保存断点信息(权重、优化器断点,学习率等),并根据验证集acc判断当前epoch是否是最佳权重,是则进行保存。
def saveCkpt(self, epoch):
'''保存权重和训练断点
Args:
:param epoch: 当前epoch
:param max_acc: 当前最佳模型在验证集上的准确率
:param mean_val_acc: 当前epoch准确率
:param best_epoch: 当前最佳模型对应的训练epoch
Returns:
None
'''
# checkpoint_dict能够恢复断点训练
checkpoint_dict = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'sched_state_dict': self.scheduler.state_dict()
}
torch.save(checkpoint_dict, os.path.join(self.log_dir, f"epoch_{epoch}.pt"))
# 如果本次Epoch的acc最大,则保存参数(网络权重)
acc_list = self.argsHistory.args_history_dict['mean_val_acc']
if epoch == acc_list.index(max(acc_list)):
torch.save(self.model.state_dict(), os.path.join(self.log_dir, 'best.pt'))
self.logger.info('best checkpoint has saved !')
3.3.验证pipeline evaler
验证集上完整推理一遍(batch size=1), 并评估各种指标,可视化等。这部分同样作为训练一个epoch结束的评估流程。
def evaler(self, epoch, log_dir):
'''把pytorch训练代码独自分装成一个函数
Args:
:param modelType: 模型名称(timm)
:param DatasetDir: 数据集根目录(到images那一层, 子目录是train/valid)
:param BatchSize: BatchSize
:param imgSize: 网络接受的图像输入尺寸
:param ckptPath: 权重路径
:param logSaveDir: 训练日志保存目录
Returns:
None
'''
# 得到网络预测结果
# shape = [val_size,] [val_size,] [val_size, cls_num]
predList, trueList, softList = self.eval()
'''自定义的实现'''
# 准确率
acc = sum(predList==trueList) / predList.shape[0]
self.logger.info(f'acc: {acc}')
# # 可视化混淆矩阵
showComMatrix(trueList, predList, self.cats, self.log_dir)
# 绘制PR曲线
PRs = drawPRCurve(self.cats, trueList, softList, self.log_dir)
# 计算每个类别的 AP, F1Score
mAP, mF1Score, form = clacAP(PRs, self.cats)
self.logger.info(f'\n{form}')
# 记录args(epoch)
self.recoardArgs(mode='epoch', acc=acc, mAP=mAP, mF1Score=mF1Score)
# 绘制损失,学习率,准确率曲线
visArgsHistory(log_dir, self.log_dir)
# 记录tensorboard的log
if self.mode == 'train':
self.recordTensorboardLog('epoch', epoch)
3.3.1.推理得到网络预测结果eval
eval
方法用于得到真实标签true_list
, 预测标签pred_list
, 置信度soft_list
,为后续计算各种评估指标做准备。
def eval(self):
'''得到网络在验证集的真实标签true_list, 预测标签pred_list, 置信度soft_list, 为后续评估做准备
'''
# 记录真实标签和预测标签
pred_list, true_list, soft_list = [], [], []
# 验证模式
self.model.eval()
# 验证时无需计算梯度
with torch.no_grad():
print('evaluating val dataset...')
for batch in tqdm(self.val_data_loader):
x = batch[0].to(self.device) # [batch_size, 3, 64, 64]
y = batch[1].to(self.device).reshape(-1) # [batch_size, 1]
# 前向传播
output = self.model(x)
# 预测结果对应置信最大的那个下标
pre_lab = torch.argmax(output, dim=1)
# 记录(真实标签true_list, 预测标签pred_list, 置信度soft_list)
true_list += list(y.cpu().detach())
pred_list += list(pre_lab.cpu().detach())
soft_list += list(np.array(output.softmax(dim=-1).cpu().detach()))
return np.array(pred_list), np.array(true_list), np.array(soft_list)
3.3.2.可视化混淆矩阵showComMatrix
def showComMatrix(trueList, predList, cat, evalDir):
'''可视化混淆矩阵
Args:
:param trueList: 验证集的真实标签
:param predList: 网络预测的标签
:param cat: 所有类别的字典
Returns:
None
'''
if len(cat)>=50:
# 100类正合适的大小
plt.figure(figsize=(40, 33))
plt.subplots_adjust(left=0.05, right=1, bottom=0.05, top=0.99)
else:
# 10类正合适的大小
plt.figure(figsize=(12, 9))
plt.subplots_adjust(left=0.1, right=1, bottom=0.1, top=0.99)
conf_mat = confusion_matrix(trueList, predList)
df_cm = pd.DataFrame(conf_mat, index=cat, columns=cat)
heatmap = sns.heatmap(df_cm, annot=True, fmt='d', cmap=plt.cm.Blues)
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation = 0, ha = 'right')
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation = 50, ha = 'right')
plt.ylabel('true label')
plt.xlabel('pred label')
if not os.path.isdir(evalDir):os.makedirs(evalDir)
# 保存图像
plt.savefig(os.path.join(evalDir, '混淆矩阵.png'), dpi=200)
plt.clf()
例(food-101验证集):
3.3.3.绘制类别的PR曲线drawPRCurve
def drawPRCurve(cat, trueList, softList, evalDir):
'''绘制类别的PR曲线
Args:
:param cat: 类别索引列表
:param trueList: 验证集的真实标签
:param softList: 网络预测的置信度
:param evalDir: PR曲线图保存路径
Returns:
None
'''
def calcPRThreshold(trueList, softList, clsNum, T):
'''给定一个类别, 单个阈值下的PR值
Args:
:param trueList: 验证集的真实标签
:param predList: 网络预测的标签
:param clsNum: 类别索引
Returns:
precision, recall
'''
label = (trueList==clsNum)
prob = softList[:,clsNum]>T
TP = sum(label*prob) # 正样本预测为正样本
FN = sum(label*~prob) # 正样本预测为负样本
FP = sum(~label*prob) # 负样本预测为正样本
precision = TP / (TP + FP) if (TP + FP)!=0 else 1
recall = TP / (TP + FN)
return precision, recall, T
def clacPRCurve(trueList, softList, clsNum, interval=100):
'''所有类别下的PR曲线值
Args:
:param trueList: 验证集的真实标签
:param predList: 网络预测的标签
:param clsNum: 类别索引列表
:param interval: 阈值变化划分的区间,如interval=100, 则间隔=0.01
Returns:
:param PRs: 不同阈值下的PR值[2, interval, cat_num]
'''
PRs = []
print('calculating PR per classes...')
for cls in trange(clsNum):
PR_value = [calcPRThreshold(trueList, softList, cls, i/interval) for i in range(interval+1)]
PRs.append(np.array(PR_value))
return np.array(PRs)
plt.figure(figsize=(12, 9))
# 计算所有类别下的PR曲线值
PRs = clacPRCurve(trueList, softList, len(cat))
# 绘制每个类别的PR曲线
for i in range(len(cat)):
PR = PRs[i]
plt.plot(PR[:,1], PR[:,0], linewidth=1)
plt.legend(labels=cat)
plt.xlabel('recall')
plt.ylabel('precision')
plt.xlim(0,1)
plt.ylim(0,1)
# 保存图像
plt.savefig(os.path.join(evalDir, '类别PR曲线.png'), dpi=200)
plt.clf()
return PRs
例(food-101验证集):
3.3.4.计算每个类别的 AP, F1ScoreclacAP
def clacAP(PRs, cat):
'''计算每个类别的 AP, F1Score
Args:
:param PRs: 不同阈值下的PR值[2, interval, cat_num]
:param cat: 类别索引列表
Returns:
None
'''
form = [['catagory', 'AP', 'F1_Score']]
# 所有类别的平均AP与平均F1Score
mAP, mF1Score = 0, 0
for i in range(len(cat)):
PR = PRs[i]
AP = 0
for j in range(PR.shape[0]-1):
# 每小条梯形的矩形部分+三角形部分面积
h = PR[j, 0] - PR[j+1, 0]
w = PR[j, 1] - PR[j+1, 1]
AP += (PR[j+1, 0] * w) + (w * h / 2)
if(PR[j, 2]==0.5):
F1Score0_5 = 2 * PR[j, 0] * PR[j, 1] / (PR[j, 0] + PR[j, 1])
form.append([cat[i], AP, F1Score0_5])
mAP += AP
mF1Score += F1Score0_5
form.append(['average', mAP / len(cat), mF1Score / len(cat)])
return mAP, mF1Score, tabulate(form, headers='firstrow') # tablefmt='fancy_grid'
例,输出的逐类别评估指标(food-101验证集):
catagory AP F1_Score
----------------------- -------- ----------
apple_pie 0.634729 0.596413
baby_back_ribs 0.871711 0.815574
baklava 0.924247 0.870445
beef_carpaccio 0.920737 0.877551
beef_tartare 0.853179 0.806794
beet_salad 0.791348 0.746835
beignets 0.929849 0.858824
bibimbap 0.970363 0.931452
bread_pudding 0.628122 0.598778
breakfast_burrito 0.816585 0.759494
bruschetta 0.811589 0.74645
caesar_salad 0.92251 0.862205
cannoli 0.912558 0.859504
caprese_salad 0.883566 0.81409
carrot_cake 0.853388 0.7833
ceviche 0.742306 0.693446
cheesecake 0.91843 0.875
cheese_plate 0.775621 0.711579
chicken_curry 0.865944 0.805785
chicken_quesadilla 0.877484 0.830957
chicken_wings 0.935781 0.883534
chocolate_cake 0.713896 0.676596
chocolate_mousse 0.644484 0.616302
churros 0.953728 0.917505
clam_chowder 0.928124 0.879032
club_sandwich 0.919055 0.866935
crab_cakes 0.81353 0.78
creme_brulee 0.9357 0.901354
croque_madame 0.923631 0.878049
cup_cakes 0.933192 0.886719
deviled_eggs 0.951984 0.927126
donuts 0.894698 0.843177
dumplings 0.935255 0.903491
edamame 0.997601 0.993964
eggs_benedict 0.926735 0.88755
escargots 0.938622 0.904
falafel 0.855487 0.797495
filet_mignon 0.716486 0.666667
fish_and_chips 0.917366 0.866667
foie_gras 0.706522 0.655804
french_fries 0.957856 0.912621
french_onion_soup 0.902586 0.853175
french_toast 0.830571 0.789062
fried_calamari 0.914943 0.878543
fried_rice 0.909526 0.846602
frozen_yogurt 0.961339 0.91945
garlic_bread 0.853865 0.803245
gnocchi 0.811815 0.745174
greek_salad 0.907218 0.847737
grilled_cheese_sandwich 0.802601 0.746507
grilled_salmon 0.823634 0.762887
guacamole 0.932105 0.893443
gyoza 0.914981 0.882353
hamburger 0.872562 0.80167
hot_and_sour_soup 0.965024 0.930693
hot_dog 0.901362 0.856
huevos_rancheros 0.784487 0.716484
hummus 0.895116 0.847107
ice_cream 0.828925 0.763948
lasagna 0.820834 0.774059
lobster_bisque 0.90727 0.858871
lobster_roll_sandwich 0.94466 0.907975
macaroni_and_cheese 0.891181 0.829569
macarons 0.979871 0.95122
miso_soup 0.970756 0.918489
mussels 0.956927 0.919918
nachos 0.880572 0.830266
omelette 0.783796 0.713656
onion_rings 0.947257 0.913725
oysters 0.958218 0.928287
pad_thai 0.957988 0.894027
paella 0.913414 0.854839
pancakes 0.91128 0.866935
panna_cotta 0.798689 0.743434
peking_duck 0.925841 0.864097
pho 0.967159 0.938614
pizza 0.932811 0.870406
pork_chop 0.651566 0.606557
poutine 0.960185 0.918367
prime_rib 0.891828 0.850394
pulled_pork_sandwich 0.876265 0.812245
ramen 0.936075 0.873016
ravioli 0.752236 0.698545
red_velvet_cake 0.898298 0.860041
risotto 0.797334 0.739394
samosa 0.866941 0.8375
sashimi 0.938466 0.899384
scallops 0.783607 0.732919
seaweed_salad 0.957724 0.917172
shrimp_and_grits 0.833112 0.771084
spaghetti_bolognese 0.956515 0.916
spaghetti_carbonara 0.960952 0.900398
spring_rolls 0.885243 0.852459
steak 0.570503 0.524313
strawberry_shortcake 0.859055 0.805726
sushi 0.891105 0.846626
tacos 0.82507 0.771037
takoyaki 0.958959 0.92623
tiramisu 0.845881 0.792079
tuna_tartare 0.75262 0.699411
waffles 0.906239 0.854291
average 0.873475 0.825314
3.3.4.可视化训练过程中保存的参数visArgsHistory
def visArgsHistory(json_dir, save_dir):
'''可视化训练过程中保存的参数
Args:
:param json_dir: 参数的json文件路径
:param logDir: 可视化json文件保存路径
Returns:
None
'''
json_path = os.path.join(json_dir, 'args_history.json')
with open(json_path) as json_file:
args = json.load(json_file)
for args_key in args.keys():
arg = args[args_key]
plt.plot(arg, linewidth=1)
plt.xlabel('Epoch')
plt.ylabel(args_key)
plt.savefig(os.path.join(save_dir, f'{args_key}.png'), dpi=200)
plt.clf()
例(food-101,mobilenet-v3-large):
3.4.测试pipeline tester
可视化图像CAM热图和分类结果(top10)
def tester(self, img_path, save_res_dir):
'''把pytorch测试代码独自分装成一个函数
Args:
:param img_path: 测试图像路径
:param save_res_dir: 推理结果保存目录
Returns:
None
'''
from dataloader import Transforms
# 加载一张图片并进行预处理
image = Image.open(img_path)
image = np.array(image)
tf = Transforms(imgSize = self.img_size)
visImg = tf.visTF(image=image)['image']
img = torch.tensor(tf.validTF(image=image)['image']).permute(2,1,0).unsqueeze(0).to(self.device)
# 加载网络
self.model.eval()
# 预测
logits = self.model(img).softmax(dim=-1).cpu().detach().numpy()[0]
sorted_id = sorted(range(len(logits)), key=lambda k: logits[k], reverse=True)
# 超过10类则只显示top10的类别
logits_top_10 = logits[sorted_id[:10]]
cats_top_10 = [self.cats[i] for i in sorted_id[:10]]
'''CAM'''
# CAM需要网络能反传梯度, 否则会报错
# 要可视化网络哪一层的CAM(以mobilenetv3_large_100.ra_in1k为例, 不同的网络这部分还需更改)
target_layers = [self.model.backbone.blocks[-1]]
cam = GradCAM(model=self.model, target_layers=target_layers)
# 要关注的区域对应的类别
targets = [ClassifierOutputTarget(sorted_id[0])]
grayscale_cam = cam(input_tensor=img, targets=targets)[0].transpose(1,0)
visualization = show_cam_on_image(visImg / 255., grayscale_cam, use_rgb=True)
'''可视化预测结果'''
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
# 在第一个子图中绘制图像
ax1.set_title('image')
ax1.axis('off')
# ax1.imshow(image)
ax1.imshow(visualization)
# 在第二个子图中绘制置信度(横向)
ax2.barh(cats_top_10, logits_top_10.reshape(-1))
ax2.set_title('classification')
ax2.set_xlabel('confidence')
# 将数值最大的条块设置为不同颜色
bar2 = ax2.patches[0]
bar2.set_color('orange')
# y轴上下反转,不然概率最大的在最下面
plt.gca().invert_yaxis()
plt.subplots_adjust(left=0.05, right=0.99, bottom=0.1, top=0.90)
if not os.path.isdir(save_res_dir):os.makedirs(save_res_dir)
plt.savefig(os.path.join(save_res_dir, 'res.jpg'), dpi=200)
plt.clf()
例:
3.5.其他
主函数部分,通过命令行参数获取config文件路径,读取config文件里的参数(字典形式)作为训练的超参。
if __name__ == '__main__':
args = getArgs()
# 使用动态导入的模块
config_path = args.config
config_file = import_module_by_path(config_path)
# 调用动态导入的模块的函数
config = config_file.config
runner = Runner(config['timm_model_name'], config['img_size'], config['ckpt_load_path'],config['dataset_dir'], config['epoch'], config['bs'], config['lr'],
config['log_dir'], config['log_interval'], config['pretrain'], config['froze'], config['optim_type'], config['mode'], config['resume'], config['seed'])
# 训练
if config['mode'] == 'train':
runner.trainer()
# 评估
elif config['mode'] == 'eval':
runner.evaler(epoch=0, log_dir=config['eval_log_dir'])
elif config['mode'] == 'test':
runner.tester(config['img_path'], config['save_res_dir'])
else:
print("mode not valid. it must be 'train', 'eval' or 'test'.")
3.5.1 获取命令行参数
def getArgs():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='config file')
args = parser.parse_args()
return args
3.5.2 根据给定路径动态import模块(config.py)
def import_module_by_path(module_path):
"""根据给定的完整路径动态导入模块(config.py)
"""
spec = importlib.util.spec_from_file_location("module_name", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
3.5.3 config.py样例
config = dict(
# train
mode = 'test',
timm_model_name = 'mobilenetv3_large_100.ra_in1k',
img_size = 224,
ckpt_load_path = 'log/2024-02-14-03-04-03_train/best.pt',
dataset_dir = 'E:/datasets/Classification/food-101/images',
epoch = 36,
bs = 64,
lr = 1e-3,
log_dir = './log',
log_interval = 50,
pretrain = True,
froze = False,
optim_type = 'adamw',
resume = None, # 'log/2024-02-05-21-28-59_train/epoch_9.pt',
seed=22,
# eval
eval_log_dir = 'log/2024-02-14-03-04-03_train',
# test
# french_fries/3171053.jpg 3897130.jpg 3393816.jpg club_sandwich/3143042.jpg
img_path = 'E:/datasets/Classification/food-101/images/valid/french_fries/3393816.jpg',
save_res_dir = './result'
)