文章目录
- 一、安装配置环境
- 二、准备图像分类数据集
- 三、迁移学习微调训练图像分类模型—基础版
- 四、可视化训练日志
一、安装配置环境
1. 下载相关工具&文件
# 下载各种python库
!pip install numpy pandas matplotlib seaborn plotly requests tqdm opencv-python pillow wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
# 下载安装pytorch
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# 下载中文字体
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf --no-check-certificate
2. 创建目录
import os
# 存放结果文件
os.mkdir('output')
# 存放训练得到的模型权重
os.mkdir('checkpoints')
# 存放生成的图表
os.mkdir('图表')
3. 设置matplotlib中文字体
import matplotlib.pyplot as plt
%matplotlib inline
# windows操作系统
plt.rcParams['font.sans-serif']=['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
# 测试是否设置成功
plt.plot([1,2,3], [100,500,300])
plt.title('matplotlib中文字体测试', fontsize=25)
plt.xlabel('X轴', fontsize=15)
plt.ylabel('Y轴', fontsize=15)
plt.show()
二、准备图像分类数据集
可以使用第一个task中自己构建的分类数据集,这里为了便于复现选择使用教程给的现成的数据集:
# 下载数据集压缩包
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/fruit30_split.zip
# 解压
!unzip fruit30_split.zip
# 删除压缩包
!rm fruit30_split.zip
# 查看数据集目录结构
!sudo snap install tree
!tree fruit30_split -L 2
三、迁移学习微调训练图像分类模型—基础版
1. 设置matplotlib中文字体、导入工具包
import matplotlib as plt
# windows操作系统
plt.rcParams['font.sans-serif']=['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
import time
import os
import numpy as np
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")
2. 获取计算机硬件
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)
3. 定义图像预处理方法(训练集、测试集)
from torchvision import transforms
# 训练集图像预处理:缩放、裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
第一个代码块 train_transform
中包括了以下图像预处理步骤:
-
transforms.RandomResizedCrop(224)
随机裁剪一个大小为 224x224 的图像区域,该操作在训练过程中有助于数据增强和防止过拟合。 -
transforms.RandomHorizontalFlip()
随机水平翻转图像,该操作也是为了数据增强。 -
transforms.ToTensor()
将 PIL 图像转换为 PyTorch 张量,这是神经网络输入数据的常见格式。将图像转换为一个三维数组,第一个维度表示通道数(如果是灰度图像,则通道数为1,如果是 RGB 图像,则通道数为3),第二个维度和第三个维度分别表示图像的高和宽。
-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
对输入张量进行归一化处理,这里的参数是每个通道的均值和标准差,用于将输入数据缩放到固定范围,使模型更容易训练。
第二个代码块 test_transform
中包括了以下图像预处理步骤:
transforms.Resize(256)
将图像的短边缩放为 256 像素,保持图像的长宽比。transforms.CenterCrop(224)
从图像中心裁剪一个 224x224 的图像区域。transforms.ToTensor()
将 PIL 图像转换为 PyTorch 张量。transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
对输入张量进行归一化处理,这里的参数与第一个代码块中的参数相同,用于将输入数据缩放到固定范围,使模型更容易训练。
4. 载入图像分类数据集
# 数据集文件夹路径
dataset_dir = 'fruit30_split'
# 通过路径拼接得到训练/测试数据集完整路径
train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)
from torchvision import datasets
# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
这样我们就得到了经过预处理后的训练集和测试集。
5. 类别和索引号一一对应(映射字典)
# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)
# 映射关系:类别 到 索引号
train_dataset.class_to_idx
# 映射关系:索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}
# 保存为本地的 npy 文件,以便在模型训练过程中进行读取
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)
6. 定义数据加载器DataLoader
from torch.utils.data import DataLoader
BATCH_SIZE = 32
# 训练集的数据加载器
train_loader = DataLoader(train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=4
)
# 测试集的数据加载器
test_loader = DataLoader(test_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=4
)
以上代码使用PyTorch的DataLoader
类定义了用于训练和测试数据集的数据加载器。
-
DataLoader
是PyTorch的一个类,它提供了一个可迭代对象,支持批处理、随机打乱和多进程数据加载。它可以用于并行从磁盘加载数据,这可以帮助加速训练过程。 -
这段代码为训练和测试数据加载器都指定了批大小为32,这意味着在训练和测试过程中,数据将按32个样本为一批进行处理。
-
shuffle
参数对于训练数据加载器被设置为True
,对于测试数据加载器被设置为False
。这意味着训练数据将在每个epoch(完整数据集的一轮)之间随机打乱,这有助于防止模型过度拟合训练数据的顺序。测试数据不被打乱,因为我们想要评估模型在实际数据分布上的性能。 -
num_workers
参数指定用于并行加载数据的子进程数。在本例中,它设置为4,这意味着数据将使用4个子进程加载。这可以帮助加速数据加载过程,特别是在处理大型数据集时。
7. 查看&可视化一个batch的图像和标注
# DataLoader 是 python生成器,每次调用返回一个 batch 的数据
images, labels = next(iter(train_loader))
images.shape
:
labels
:
# 将数据集中的Tensor张量转为numpy的array数据类型
images = images.numpy()
# 可视化第6张图片前50个像素值的直方图
plt.hist(images[5].flatten(), bins=50)
plt.show()
# batch 中经过预处理的图像
idx = 2
plt.imshow(images[idx].transpose((1,2,0))) # 转为(224, 224, 3)
plt.title('label:'+str(labels[idx].item()))
这段代码显示了一个经过预处理的图像。images
是一个batch中的图像张量,labels
是其对应的标签张量。通过transpose
函数,将图像张量的维度从(3, 224, 224)
变为(224, 224, 3)
,以便用于matplotlib库的显示。同时,代码中还在图像标题上显示了这张图像的标签。其中,item()
函数用于获取标签张量中的整数值。
# 原始图像
idx = 2
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
plt.imshow(np.clip(images[idx].transpose((1,2,0)) * std + mean, 0, 1))
plt.title('label:'+ pred_classname)
plt.show()
训练数据集中对图像进行归一化的逆操作得到各像素点的原像素值。
8. 导入训练需使用的工具包
from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler
9. 选择迁移学习训练方式(这里暂选第一个)
选择1:只微调训练模型最后一层(全连接分类层)
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
model.fc = nn.Linear(model.fc.in_features, n_class)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())
这段代码使用了在ImageNet上预训练好的ResNet-18模型,并将其最后一层全连接层替换成一个新的全连接层,该全连接层的输出节点数与当前数据集类别数对应。此外,这段代码还将只微调最后一层全连接层的参数,其他层的参数冻结,以免在微调的过程中影响在ImageNet上预训练好的权重。最后,使用 Adam优化器 来训练模型的最后一层全连接层的参数。
nn.Linear(model.fc.in_features, n_class)
创建了一个输入大小为model.fc.in_features
,输出大小为n_class
的全连接层。
model.fc
代表resnet18
模型的最后一层全连接层。model.fc.in_features
表示全连接层的输入特征数。默认情况下,model.fc
的输出大小是1000,因为该预训练模型是在1000个类别的ImageNet数据集上进行训练的。而在当前数据集中,类别数为30。因此,我们需要将全连接层的输出大小修改为30,以便进行分类。nn.Linear()
是PyTorch中的一个函数,用于构建一个全连接层。第一个参数为输入特征数,即上一层的输出大小;第二个参数n_class
是输出特征数,即当前数据集中的类别数。
选择2:微调训练所有层
model = models.resnet18(pretrained=True) # 载入预训练模型
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
选择3:随机初始化模型全部权重,从头训练所有层
model = models.resnet18(pretrained=False) # 只载入模型结构,不载入预训练权重参数
model.fc = nn.Linear(model.fc.in_features, n_class)
optimizer = optim.Adam(model.parameters())
10. 训练配置
model = model.to(device)
# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 训练轮次 Epoch
EPOCHS = 30
# 学习率降低策略
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
这个学习率降低策略使用了StepLR
调度程序,它将每个epoch的学习率降低为原来的一半。具体来说,每5个epoch,学习率将乘以0.5。optimizer
是优化器对象,step_size
参数指定多少个epoch后将学习率降低,gamma
参数指定学习率降低的倍数。
11. 函数:在训练集上一个batch的训练
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
def train_one_batch(images, labels):
'''
运行一个 batch 的训练,返回当前 batch 的训练日志
'''
# 获得一个 batch 的数据和标注
images = images.to(device)
labels = labels.to(device)
outputs = model(images) # 输入模型,执行前向预测获得当前 batch 所有图像的预测类别 logit 分数
loss = criterion(outputs, labels) # 计算当前 batch 中,每个样本的平均交叉熵损失函数值
# 优化更新权重
optimizer.zero_grad() # 清除梯度
loss.backward() # 反向传播
optimizer.step() # 优化更新
# 获取当前 batch 的标签类别和预测类别
_, preds = torch.max(outputs, 1) # 获得当前 batch 所有图像的预测类别
# 从GPU内存转移到CPU内存,并转换为NumPy数组类型,方便后续的计算和处理
preds = preds.cpu().numpy()
loss = loss.detach().cpu().numpy()
outputs = outputs.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
log_train = {}
log_train['epoch'] = epoch
log_train['batch'] = batch_idx
# 计算分类评估指标
log_train['train_loss'] = loss
log_train['train_accuracy'] = accuracy_score(labels, preds)
# log_train['train_precision'] = precision_score(labels, preds, average='macro')
# log_train['train_recall'] = recall_score(labels, preds, average='macro')
# log_train['train_f1-score'] = f1_score(labels, preds, average='macro')
return log_train
preds
:
labels
:
注:这里预测结果和真实lable基本上对不上是因为这个时候还没训练好模型,只是模拟了一个batch的训练而已。
主要步骤:
-
从训练集数据加载器中获取一个batch的图像数据和对应的标注,将其传入GPU进行计算。
-
输入模型,执行前向预测,得到输出结果
outputs
。outputs
是一个形状为(batch_size, n_class)
的张量,其中batch_size
表示当前batch的大小,n_class
表示分类的类别数。 -
计算当前batch中,所有样本的平均交叉熵损失函数值,即将
outputs
和labels
传入交叉熵损失函数中,得到损失函数值loss
。交叉熵损失函数:
L ( y , y ^ ) = − ∑ i = 1 C y i log y i ^ L(y,\hat{y})=-\sum_{i=1}^{C}y_{i}\log\hat{y_{i}} L(y,y^)=−i=1∑Cyilogyi^
其中, y y y 是真实标签的 one-hot 向量, y ^ \hat{y} y^ 是模型预测的类别概率向量, C C C 是类别数量。 -
执行反向传播,即通过计算
loss
对模型参数的梯度,然后更新模型参数。- 清除梯度:在反向传播前,需要清除上一次计算的梯度,因为PyTorch默认是累加梯度的。
- 反向传播:使用PyTorch中的
loss.backward()
函数进行反向传播,计算损失函数对每个参数的梯度。 - 优化更新:通过优化器中的
optimizer.step()
函数更新网络参数,使损失函数值更小。
-
计算当前batch中,所有样本的预测类别,即将
outputs
的每一行求最大值所在的位置作为预测类别,用torch.max()
函数实现,返回的是一个元组(max_value, max_index)
,max_index
即为预测类别。 -
将 preds、loss、outputs、lables 转换为numpy数组以保存到损失列表中。通过使用numpy数组,可以避免将张量转移到CPU内存并节省显存。
-
计算分类评估指标,包括损失、准确度等,并将它们保存到字典中,作为评估指标的日志,最后返回这个字典。
12. 函数:在整个测试集上评估
def evaluate_testset():
'''
在整个测试集上评估,返回分类评估指标日志
'''
loss_list = []
labels_list = []
preds_list = []
with torch.no_grad():
for images, labels in test_loader: # 生成一个 batch 的数据和标注
images = images.to(device)
labels = labels.to(device)
outputs = model(images) # 输入模型,执行前向预测
# 获取整个测试集的标签类别和预测类别
_, preds = torch.max(outputs, 1) # 获得当前 batch 所有图像的预测类别
preds = preds.cpu().numpy()
loss = criterion(outputs, labels) # 由 logit,计算当前 batch 中,每个样本的平均交叉熵损失函数值
loss = loss.detach().cpu().numpy()
outputs = outputs.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
loss_list.append(loss)
labels_list.extend(labels)
preds_list.extend(preds)
log_test = {}
log_test['epoch'] = epoch
# 计算分类评估指标
log_test['test_loss'] = np.mean(loss)
log_test['test_accuracy'] = accuracy_score(labels_list, preds_list)
log_test['test_precision'] = precision_score(labels_list, preds_list, average='macro')
log_test['test_recall'] = recall_score(labels_list, preds_list, average='macro')
log_test['test_f1-score'] = f1_score(labels_list, preds_list, average='macro')
return log_test
-
定义了三个空列表,用于记录损失、标签和预测结果。
-
然后使用
with torch.no_grad()
上下文管理器,关闭梯度计算,避免在评估过程中消耗过多的显存。 -
使用测试集迭代器生成每个测试batch的数据和标注,将其发送到GPU并将其输入到模型中进行前向预测。
-
然后计算当前batch中,每个样本的平均交叉熵损失函数值,并将其转换为numpy数组以保存到损失列表中。通过使用numpy数组,可以避免将张量转移到CPU内存并节省显存。
-
计算分类评估指标,包括损失、准确度、精确度、召回率和F1分数,并将它们保存到字典中,作为评估指标的日志,最后返回这个字典。
-
准确率(Accuracy):表示分类器正确分类的样本占总样本数的比例。
Accuracy = 正确分类的样本数 总样本数 \text{Accuracy} = \frac{\text{正确分类的样本数}}{\text{总样本数}} Accuracy=总样本数正确分类的样本数 -
精确率(Precision):表示分类器正确分类的正样本占预测为正样本的比例。
Precision = 真正例(TP) 真正例(TP) + 假正例(FP) \text{Precision} = \frac{\text{真正例(TP)}}{\text{真正例(TP)}+\text{假正例(FP)}} Precision=真正例(TP)+假正例(FP)真正例(TP) -
召回率(Recall):表示分类器正确分类的正样本占实际为正样本的比例。
Recall = 真正例(TP) 真正例(TP) + 假负例(FN) \text{Recall} = \frac{\text{真正例(TP)}}{\text{真正例(TP)}+\text{假负例(FN)}} Recall=真正例(TP)+假负例(FN)真正例(TP) -
F1-score:精确率和召回率的加权平均值,综合反映分类器的性能。
F1-score = 2 × Precision × Recall Precision + Recall \text{F1-score} = \frac{2 \times \text{Precision} \times \text{Recall}}{\text{Precision}+\text{Recall}} F1-score=Precision+Recall2×Precision×Recall
其中,TP表示真正例(True Positive),即实际为正样本且被正确分类为正样本的样本数;FP表示假正例(False Positive),即实际为负样本但被错误分类为正样本的样本数;FN表示假负例(False Negative),即实际为正样本但被错误分类为负样本的样本数。
-
13. 训练开始之前,记录日志
epoch = 0
batch_idx = 0
best_test_accuracy = 0
# 训练日志-训练集
df_train_log = pd.DataFrame()
log_train = {}
log_train['epoch'] = 0
log_train['batch'] = 0
images, labels = next(iter(train_loader))
log_train.update(train_one_batch(images, labels))
df_train_log = df_train_log.append(log_train, ignore_index=True)
# 训练日志-测试集
df_test_log = pd.DataFrame()
log_test = {}
log_test['epoch'] = 0
log_test.update(evaluate_testset())
df_test_log = df_test_log.append(log_test, ignore_index=True)
df_train_log
:
df_test_log
:
14. 登录wandb&创建wandb可视化项目
-
安装 wandb:pip install wandb
-
登录 wandb:在命令行中运行wandb login
-
按提示复制粘贴API Key至命令行中
import wandb
wandb.init(project='fruit30', name=time.strftime('%m%d%H%M%S'))
15. 运行训练
for epoch in range(1, EPOCHS+1):
print(f'Epoch {epoch}/{EPOCHS}')
## 训练阶段
model.train()
for images, labels in tqdm(train_loader): # 获得一个 batch 的数据和标注
batch_idx += 1
log_train = train_one_batch(images, labels)
df_train_log = df_train_log.append(log_train, ignore_index=True)
wandb.log(log_train)
lr_scheduler.step()
## 测试阶段
model.eval()
log_test = evaluate_testset()
df_test_log = df_test_log.append(log_test, ignore_index=True)
wandb.log(log_test)
# 保存最新的最佳模型文件
if log_test['test_accuracy'] > best_test_accuracy:
# 删除旧的最佳模型文件(如有)
old_best_checkpoint_path = 'checkpoints/best-{:.3f}.pth'.format(best_test_accuracy)
if os.path.exists(old_best_checkpoint_path):
os.remove(old_best_checkpoint_path)
# 保存新的最佳模型文件
new_best_checkpoint_path = 'checkpoints/best-{:.3f}.pth'.format(log_test['test_accuracy'])
torch.save(model, new_best_checkpoint_path)
print('保存新的最佳模型', 'checkpoints/best-{:.3f}.pth'.format(best_test_accuracy))
best_test_accuracy = log_test['test_accuracy']
df_train_log.to_csv('训练日志-训练集.csv', index=False)
df_test_log.to_csv('训练日志-测试集.csv', index=False)
下面是训练过程的训练日志:
16. 在测试集上评估
# 载入最佳模型作为当前模型
model = torch.load('checkpoints/best-{:.3f}.pth'.format(best_test_accuracy))
model.eval()
print(evaluate_testset())
四、可视化训练日志
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# windows操作系统设置matplotlib中文字体
plt.rcParams['font.sans-serif']=['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
1. 载入训练日志表格
df_train = pd.read_csv('训练日志-训练集.csv')
df_test = pd.read_csv('训练日志-测试集.csv')
df_train
:
df_test
:
2. 训练集损失函数
plt.figure(figsize=(16, 8))
x = df_train['batch']
y = df_train['train_loss']
plt.plot(x, y, label='训练集')
plt.tick_params(labelsize=20)
plt.xlabel('batch', fontsize=20)
plt.ylabel('loss', fontsize=20)
plt.title('训练集损失函数', fontsize=25)
plt.savefig('图表/训练集损失函数.pdf', dpi=120, bbox_inches='tight')
plt.show()
3. 训练集准确率
plt.figure(figsize=(16, 8))
x = df_train['batch']
y = df_train['train_accuracy']
plt.plot(x, y, label='训练集')
plt.tick_params(labelsize=20)
plt.xlabel('batch', fontsize=20)
plt.ylabel('loss', fontsize=20)
plt.title('训练集准确率', fontsize=25)
plt.savefig('图表/训练集准确率.pdf', dpi=120, bbox_inches='tight')
plt.show()
4. 测试集损失函数
plt.figure(figsize=(16, 8))
x = df_test['epoch']
y = df_test['test_loss']
plt.plot(x, y, label='测试集')
plt.tick_params(labelsize=20)
plt.xlabel('epoch', fontsize=20)
plt.ylabel('loss', fontsize=20)
plt.title('测试集损失函数', fontsize=25)
plt.savefig('图表/测试集损失函数.pdf', dpi=120, bbox_inches='tight')
plt.show()
5. 测试集评估指标
from matplotlib import colors as mcolors
import random
random.seed(124)
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan', 'black', 'indianred', 'brown', 'firebrick', 'maroon', 'darkred', 'red', 'sienna', 'chocolate', 'yellow', 'olivedrab', 'yellowgreen', 'darkolivegreen', 'forestgreen', 'limegreen', 'darkgreen', 'green', 'lime', 'seagreen', 'mediumseagreen', 'darkslategray', 'darkslategrey', 'teal', 'darkcyan', 'dodgerblue', 'navy', 'darkblue', 'mediumblue', 'blue', 'slateblue', 'darkslateblue', 'mediumslateblue', 'mediumpurple', 'rebeccapurple', 'blueviolet', 'indigo', 'darkorchid', 'darkviolet', 'mediumorchid', 'purple', 'darkmagenta', 'fuchsia', 'magenta', 'orchid', 'mediumvioletred', 'deeppink', 'hotpink']
markers = [".",",","o","v","^","<",">","1","2","3","4","8","s","p","P","*","h","H","+","x","X","D","d","|","_",0,1,2,3,4,5,6,7,8,9,10,11]
linestyle = ['--', '-.', '-']
def get_line_arg():
'''
随机产生一种绘图线型
'''
line_arg = {}
line_arg['color'] = random.choice(colors)
# line_arg['marker'] = random.choice(markers)
line_arg['linestyle'] = random.choice(linestyle)
line_arg['linewidth'] = random.randint(1, 4)
# line_arg['markersize'] = random.randint(3, 5)
return line_arg
metrics = ['test_accuracy', 'test_precision', 'test_recall', 'test_f1-score']
plt.figure(figsize=(16, 8))
x = df_test['epoch']
for y in metrics:
plt.plot(x, df_test[y], label=y, **get_line_arg())
plt.tick_params(labelsize=20)
plt.ylim([0, 1])
plt.xlabel('epoch', fontsize=20)
plt.ylabel(y, fontsize=20)
plt.title('测试集分类评估指标', fontsize=25)
plt.savefig('图表/测试集分类评估指标.pdf', dpi=120, bbox_inches='tight')
plt.legend(fontsize=20)
plt.show()