加载预训练模型计算测试数据集的LogLoss、AUC和EER,需要根据具体场景选择相应的计算方法。以下是三种常见的方法:
1. 计算LogLoss:
```python
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
# 加载模型和测试数据集
model = torch.load('pretrained_model.pth')
test_data = YourTestData(...)
test_loader = DataLoader(test_data, batch_size=64)
# 计算测试数据集的LogLoss
model.eval()
test_loss = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
test_loss += F.nll_loss(output, target).item() # 累加每个批次的loss
test_loss /= len(test_loader) # loss取平均值
print('Test set: Average Loss: {:.4f}'.format(test_loss))
上述代码中,使用PyTorch提供的F.nll_loss()
函数来计算测试数据集的LogLoss。在计算时需要累加每个批次的loss,最后将总和除以测试集的样本数就可以得到平均Loss。
- 计算AUC:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
# 加载模型和测试数据集
model = torch.load('pretrained_model.pth')
test_data = YourTestData(...)
test_loader = DataLoader(test_data, batch_size=64)
# 计算测试数据集的预测概率和真实标签
model.eval()
y_true, y_score = [], []
with torch.no_grad():
for data, target in test_loader:
output = torch.sigmoid(model(data)) # 使用sigmoid转换到0~1之间的概率
y_true.extend(target.tolist()) # 将真实标签添加到列表中
y_score.extend(output.tolist()) # 将预测概率添加到列表中
# 计算AUC
auc = roc_auc_score(y_true, y_score)
print('Test set: AUC = {:.4f}'.format(auc))
上述代码中,使用sklearn.metrics提供的roc_auc_score()
函数来计算测试数据集的AUC。在计算时需要获取每个样本的预测概率和真实标签,可以使用torch.sigmoid()
将输出转换到0~1之间的概率,然后将它们添加到两个列表中。最后使用roc_auc_score()
函数来计算AUC。
- 计算EER:
import torch
import numpy as np
from scipy.optimize import brentq
from sklearn.metrics import roc_curve
def calculate_eer(y_true, y_score):
fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
eer = brentq(lambda x: 1. - x - np.interp(x, fpr, tpr), 0., 1.)
return eer * 100
# 加载模型和测试数据集
model = torch.load('pretrained_model.pth')
test_data = YourTestData(...)
test_loader = DataLoader(test_data, batch_size=64)
# 计算测试数据集的预测概率和真实标签
model.eval()
y_true, y_score = [], []
with torch.no_grad():
for data, target in test_loader:
output = torch.sigmoid(model(data)) # 使用sigmoid转换到0~1之间的概率
y_true.extend(target.tolist()) # 将真实标签添加到列表中
y_score.extend(output.tolist()) # 将预测概率添加到列表中
# 计算EER
eer = calculate_eer(y_true, y_score)
print('Test set: EER = {:.2f}%'.format(eer))
上述代码中,自定义了一个calculate_eer()
函数来计算测试数据集的EER。首先使用sklearn.metrics提供的roc_curve()
函数来计算FPR和TPR,并获得阈值。然后使用scipy.optimize提供的brentq()
函数来求解ERR对应的阈值,最后将求解结果乘以100,即可得到EER的百分比形式。
``