加载模型权重
model.load_state_dict(
torch.load('epoch_3_valid_macrof1_95.878_microf1_96.049_weights.bin', map_location=torch.device('cpu'))
)
model.eval()
-
功能:
- 从文件
epoch_3_valid_macrof1_95.878_microf1_96.049_weights.bin
加载训练好的模型权重。 - 将模型切换到评估模式(
eval()
),关闭 Dropout 和 BatchNorm 等训练专用功能。
- 从文件
-
意义:
- 保证模型可以在测试集上进行稳定推理,输出预测结果。
map_location=torch.device('cpu')
参数作用
map_location
是torch.load()
函数的一个参数,用于指定加载权重时,将权重加载到哪种设备(如 CPU 或 GPU)。
背景
- PyTorch 模型权重文件可能是在训练时保存的(通常位于 GPU 上),而加载时,设备可能不同。例如:
- 权重在 GPU 上保存,而加载是在 CPU 环境。
- 权重在一块 GPU 上保存,而加载时需要移动到另一块 GPU。
model.eval()
作用
-
将模型切换到 评估模式(evaluation mode)。
-
评估模式 是 PyTorch 的模型工作模式之一,与训练模式(
train()
)相对。
具体功能
-
关闭训练专用的操作:
-
关闭 Dropout:确保推理时不随机丢弃神经元。
-
固定 Batch Normalization(BN):推理时使用固定的均值和方差,而不是实时更新。
-
-
稳定推理结果:
-
评估模式确保结果具有确定性,不受训练相关操作的影响。
-
其他模式
PyTorch 中,模型主要有以下两种模式:
模式 | 方法 | 功能 |
---|---|---|
训练模式 | model.train() | - 开启 Dropout 和 Batch Normalization。 - 模型处于训练状态,可以进行梯度更新。 |
评估模式 | model.eval() | - 关闭 Dropout 和实时 Batch Normalization。 - 模型处于推理状态,不进行训练。 |
模型评估
with torch.no_grad():
print('evaluating on test set...')
true_labels, true_predictions = [], []
for X, y in tqdm(test_dataloader):
X, y = X.to(device), y.to(device)
pred = model(X)
predictions = pred.argmax(dim=-1).cpu().numpy().tolist()
labels = y.cpu().numpy().tolist()
true_labels += [[id2label[int(l)] for l in label if l != -100] for label in labels]
true_predictions += [
[id2label[int(p)] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
print(classification_report(true_labels, true_predictions, mode='strict', scheme=IOB2))
-
功能:
- 使用测试集的 DataLoader(
test_dataloader
)加载测试数据。 - 对每个批次的数据(
X, y
)进行推理,获得模型的预测结果(pred
)。 - 转换预测值和真实标签为字符串格式,去掉填充标签(
-100
)。 - 使用
classification_report
输出测试集上的分类性能指标(如 Precision、Recall、F1-score)。
- 使用测试集的 DataLoader(
-
意义:
- 评估模型在测试集上的性能,量化其分类能力。
with
的作用是什么?
with
是 Python 的上下文管理器,用于管理资源,确保代码块运行后资源能被正确释放或恢复初始状态。- 通俗易懂的解释:
- 想象我们开了一个房间的灯(资源),用完之后需要关灯(释放资源)。
with
确保在使用完资源后,灯(资源)总会被关掉,不管中间是否发生错误。
- 在代码中的应用:
- with torch.no_grad():
# 在这个代码块中禁用梯度计算
... torch.no_grad()
:关闭 PyTorch 的自动梯度计算机制,减少内存占用并加速推理。- 一旦离开
with
块,梯度计算功能会自动恢复。
- with torch.no_grad():
tqdm(test_dataloader)
tqdm
是一个进度条工具:- 它会在循环中显示当前的迭代进度,让用户知道代码执行到了哪一步。
test_dataloader
:- 通常是 PyTorch 的
DataLoader
,用于加载测试数据。 - 它会将测试集数据分成小批次,每次迭代输出一批数据(
X, y
)。
- 通常是 PyTorch 的
- 结合起来:
tqdm(test_dataloader)
在每次迭代时,会动态更新进度条,让用户直观地看到测试集推理的进度。
X, y = X.to(device), y.to(device)
- 功能:
- 将数据
X
和标签y
从 CPU 转移到目标设备(如 GPU 或 CPU)。
- 将数据
- 解释:
X.to(device)
:- 将
X
转移到device
,这是模型运行所在的设备。
- 将
y.to(device)
:- 将标签
y
转移到同一设备上,确保计算时设备一致。
- 将标签
- 是否是赋值?
- 是的,这是一种就地赋值操作:
X.to(device)
将结果赋值回X
。y.to(device)
将结果赋值回y
。
- 是的,这是一种就地赋值操作:
- 为什么要这样写?
- PyTorch 张量需要与模型在同一设备上进行计算,否则会报错。
pred = model(X)
- 是的,
model
是前面定义的模型实例,并且它已经通过model.eval()
切换到了评估模式。 - 解释:
model(X)
表示将输入X
传递给模型model
,进行前向传播,生成预测结果。- 此时模型的评估模式会关闭 Dropout 和 Batch Normalization 的实时更新。
predictions = pred.argmax(dim=-1).cpu().numpy().tolist() 和 labels = y.cpu().numpy().tolist()
功能对比:
代码 | 含义 |
---|---|
predictions = pred.argmax(dim=-1)... | 获取模型的预测结果,每个 token 的预测类别索引。 |
labels = y.cpu().numpy().tolist() | 获取每个 token 的真实标签索引,将张量格式转换为 NumPy 数组和列表。 |
关键部分解释:
pred.argmax(dim=-1)
:- 从模型输出
pred
中找到每个 token 的预测类别索引。 dim=-1
表示在类别维度上取最大值。
- 从模型输出
.cpu()
:- 将张量从 GPU 移动到 CPU。
.numpy()
:- 将 PyTorch 张量转换为 NumPy 数组。
.tolist()
:- 将 NumPy 数组转换为 Python 列表,方便后续操作。
true_labels += [[id2label[int(l)] for l in label if l != -100] for label in labels]
true_labels += [
[id2label[int(l)] for l in label if l != -100]
for label in labels
]
这段代码可以看作是嵌套的 列表推导式(list comprehension),其中 for
循环写在中括号里,它提供了一种紧凑的方式来生成新的列表。
逐步解析
1. 整体结构
-
外层列表推导式:
-
[<内层操作> for label in labels]
- 遍历
labels
中的每个元素label
。 - 对每个
label
进一步处理,生成一个新列表。
-
- 内层列表推导式:
- [id2label[int(l)] for l in label if l != -100]
- [操作 for 变量 in 可迭代对象 if 条件]
- 遍历
label
中的每个元素l
。 - 如果
l != -100
,将其转换为对应的标签字符串并添加到列表中。
2. 外层循环 for label in labels
- 遍历
labels
中的每个子列表label
,每个label
可能对应一个样本的标签序列。 - 例子:
- 第一次循环时:
- label = [2, -100, 1, 0] # 样本 1
- 第二次循环时:
- label = [0, 3, -100, -100] # 样本 2
- 第一次循环时:
3. 内层循环 for l in label
- 遍历当前子列表
label
中的每个元素l
,处理其中的标签。 if l != -100
:- 过滤掉值为
-100
的无效标签。
- 过滤掉值为
4. 内层操作 id2label[int(l)]
int(l)
:- 将标签
l
转换为整数。
- 将标签
id2label[int(l)]
:- 通过字典
id2label
将标签索引(如2
,1
,0
)映射为字符串标签(如"B-LOC"
,"I-LOC"
,"O"
)。
- 通过字典
5. 整体执行流程
以 labels
的例子来说明:
labels = [
[2, -100, 1, 0], # 样本 1 的标签序列
[0, 3, -100, -100] # 样本 2 的标签序列
]
外层循环: 遍历 labels
中的每个子列表 label
。
- 第一次循环:
- 当前
label = [2, -100, 1, 0]
- 内层循环: 遍历
label
中的每个l
,过滤掉-100
,将其转换为字符串标签:- [id2label[int(l)] for l in [2, -100, 1, 0] if l != -100]
→ ["B-LOC", "I-LOC", "O"]
- [id2label[int(l)] for l in [2, -100, 1, 0] if l != -100]
- 结果:
["B-LOC", "I-LOC", "O"]
- 当前
- 第二次循环:
- 当前
label = [0, 3, -100, -100]
- 内层循环: 遍历
label
中的每个l
,过滤掉-100
,将其转换为字符串标签:- [id2label[int(l)] for l in [0, 3, -100, -100] if l != -100]
→ ["O", "B-PER"] - 结果:
["O", "B-PER"]
- [id2label[int(l)] for l in [0, 3, -100, -100] if l != -100]
- 当前
- 最终结果:
- [
["B-LOC", "I-LOC", "O"], # 样本 1 的标签
["O", "B-PER"] # 样本 2 的标签
]
- [
true_predictions += [...]
true_predictions += [
[id2label[int(p)] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
逐步分解:
-
zip(prediction, label)
:- 将预测结果和真实标签逐对组合(对应同一个 token)。
- 例如:
prediction = [2, 0, 1] label = [2, -100, 1] zip(prediction, label) # [(2, 2), (0, -100), (1, 1)]
-
内层列表解析:
- 遍历每个
(p, l)
对:- 如果真实标签
l != -100
(过滤无效标记)。 - 将预测类别索引
p
转换为字符串标签(id2label[int(p)]
)。
- 如果真实标签
- 遍历每个
-
外层列表解析:
- 对每个句子的预测结果
prediction
和真实标签label
,生成一个列表,包含有效 token 的预测标签。
- 对每个句子的预测结果
- 作用:
- 将预测标签从整数索引转换为字符串标签,同时对齐真实标签,去掉无效标记。
逐条生成预测结果
print('predicting labels...')
for s_idx in tqdm(range(len(test_data))):
example = test_data[s_idx]
inputs = tokenizer(example['sentence'], truncation=True, return_tensors="pt")
inputs = inputs.to(device)
pred = model(inputs)
probabilities = torch.nn.functional.softmax(pred, dim=-1)[0].cpu().numpy().tolist()
predictions = pred.argmax(dim=-1)[0].cpu().numpy().tolist()
-
功能:
- 遍历测试集中的每个样本(
example
)。 - 对样本的句子进行分词(
tokenizer
),将其转化为模型输入格式。 - 输入模型,获取每个 token 的预测概率(
probabilities
)和预测类别索引(predictions
)。
- 遍历测试集中的每个样本(
-
意义:
- 将每个句子逐一送入模型,生成 token 级别的预测结果。
1. for s_idx in tqdm(range(len(test_data))):
解释
-
len(test_data)
:-
获取测试数据集
test_data
中的样本数量。 -
假设
test_data
是一个包含多个句子或样本的列表。
-
-
range(len(test_data))
:-
生成从
0
到len(test_data)-1
的索引序列,用于逐一访问每个样本。
-
-
tqdm(range(...))
:-
为循环添加一个动态进度条,显示测试数据集处理的进度。
-
-
s_idx
:-
循环变量,表示当前处理的样本在
test_data
中的索引。
-
作用
-
遍历测试集中的每个样本,逐一处理并生成预测结果。
2. example = test_data[s_idx]
解释
-
test_data[s_idx]
:-
通过索引
s_idx
获取测试集中第s_idx
个样本。
-
-
example
:-
当前处理的样本,通常是一个字典,包含句子文本和对应的标签信息:
-
example = {
"sentence": "日本外务省发布声明。",
"labels": [{"entity_group": "ORG", "start": 0, "end": 5}]
}
-
作用
-
提取当前样本的内容,以便后续进行分词、推理和实体提取。
3. inputs = tokenizer(example['sentence'], truncation=True, return_tensors="pt")
解释
-
tokenizer
:-
一个分词器实例,用于将文本(
example['sentence']
)转换为模型所需的输入格式。
-
-
example['sentence']
:-
当前样本中的句子文本,例如
"日本外务省发布声明。"
.
-
-
参数:
-
truncation=True
:-
如果句子长度超过模型支持的最大长度,则截断句子。
-
-
return_tensors="pt"
:-
将分词结果转换为 PyTorch 的张量格式,供模型使用。
-
-
作用
-
将原始句子转化为模型的输入格式,包含
input_ids
、attention_mask
等张量。
4. inputs = inputs.to(device)
解释
-
inputs
:-
上一步生成的模型输入,包含
input_ids
、attention_mask
等张量。
-
-
to(device)
:-
将输入数据转移到指定设备(如 GPU 或 CPU)。
-
作用
-
确保输入数据和模型位于同一设备上,避免计算时的设备不匹配错误。
5. pred = model(inputs)
解释
-
model(inputs)
:-
将分词后的输入数据传递给模型,进行前向传播,生成预测结果。
-
-
pred
:-
模型输出的 logits(未经过 softmax 处理的预测值)。
-
形状通常为
[batch_size, sequence_length, num_labels]
,每个 token 对应一个类别的预测值。
-
作用
-
对当前句子进行推理,获得每个 token 的预测结果。
6. probabilities = torch.nn.functional.softmax(pred, dim=-1)[0].cpu().numpy().tolist()
解释
-
torch.nn.functional.softmax(pred, dim=-1)
:-
对
pred
中的 logits 进行 softmax 运算,计算每个类别的预测概率。 -
dim=-1
:在 logits 的最后一维(类别维度)上进行 softmax。
-
-
[0]
:-
提取第一个样本的预测结果(通常是因为 batch_size=1)。
-
-
.cpu()
:-
将张量从 GPU 转移到 CPU,便于后续操作。
-
-
.numpy()
:-
将 PyTorch 张量转换为 NumPy 数组。
-
-
.tolist()
:-
将 NumPy 数组转换为 Python 列表。
-
作用
-
将 logits 转换为每个 token 的概率分布,以列表形式存储。
7. predictions = pred.argmax(dim=-1)[0].cpu().numpy().tolist()
解释
-
pred.argmax(dim=-1)
:-
从 logits 中取最大值对应的索引,表示每个 token 的预测类别索引。
-
-
[0]
:-
提取第一个样本的预测结果。
-
-
.cpu().numpy().tolist()
:-
同样,将结果从 GPU 张量转换为 Python 列表。
-
作用
-
获取每个 token 的预测类别索引,以列表形式存储。
功能分段
- 进度显示:
- 使用
tqdm
显示测试集处理进度。
- 使用
- 样本提取:
- 从
test_data
中逐一提取句子及其相关信息。
- 从
- 分词和输入转换:
- 使用
tokenizer
将句子转化为模型输入格式。
- 使用
- 推理:
- 将输入数据送入模型,生成 logits。
- 结果处理:
- 计算每个 token 的概率分布(
probabilities
)。 - 获取每个 token 的预测类别索引(
predictions
)。
- 计算每个 token 的概率分布(
核心意义
- 逐条处理测试集中的每个样本,生成结构化的预测结果(包括概率分布和预测类别)。这为后续的实体提取和结果分析奠定了基础。
实体提取和后处理
pred_label = []
inputs_with_offsets = tokenizer(example['sentence'], return_offsets_mapping=True)
tokens = inputs_with_offsets.tokens()
offsets = inputs_with_offsets["offset_mapping"]
idx = 0
while idx < len(predictions):
pred = predictions[idx]
label = id2label[pred]
if label != "O":
label = label[2:] # Remove the B- or I-
start, end = offsets[idx]
all_scores = [probabilities[idx][pred]]
while (
idx + 1 < len(predictions) and
id2label[predictions[idx + 1]] == f"I-{label}"
):
all_scores.append(probabilities[idx + 1][predictions[idx + 1]])
_, end = offsets[idx + 1]
idx += 1
score = np.mean(all_scores).item()
word = example['sentence'][start:end]
pred_label.append(
{
"entity_group": label,
"score": score,
"word": word,
"start": start,
"end": end,
}
)
idx += 1
-
功能:
- 根据模型输出的
predictions
和offset_mapping
提取实体:- 如果当前 token 属于某实体(标签非
O
),提取实体类别、文本片段、起止位置和置信度。 - 合并属于同一实体的连续 token。
- 如果当前 token 属于某实体(标签非
- 将提取出的实体信息(类别、置信度、文本、起止位置)存储为字典,添加到
pred_label
列表中。
- 根据模型输出的
-
意义:
- 将 token 级别的预测结果整合为实体级别的结果,生成结构化的输出。
1. pred_label = []
-
作用:
-
初始化一个空列表,用于存储非
"O"
类实体的相关信息。 -
后续会将所有识别出的实体(类别、文本、位置、置信度等)作为字典存入
pred_label
。
-
-
补充说明:
-
它并不是一个空集,而是一个空列表。
-
后续会在循环中用
pred_label.append()
添加实体信息。
-
2. inputs_with_offsets = tokenizer(example['sentence'], return_offsets_mapping=True)
-
作用:
-
使用分词器(
tokenizer
)对句子example['sentence']
进行分词,同时返回每个 token 在原句中的字符偏移位置。
-
-
关键参数:
-
return_offsets_mapping=True
:-
指定返回每个 token 在原始句子中的字符范围(即
offset_mapping
)。
-
-
-
结果:
-
inputs_with_offsets
是一个包含分词结果的对象,常包含以下内容:-
input_ids
:token 的 ID。 -
offset_mapping
:每个 token 的字符起始和结束位置。
-
-
3. tokens = inputs_with_offsets.tokens()
-
作用:
-
提取
inputs_with_offsets
中的 token 列表(即分词结果)。
-
-
结果:
-
tokens
是一个列表,包含句子中所有 token 的字符串形式,例如: -
tokens = ["▁日本", "外务省", "发布", "声明", "。"]
-
4. offsets = inputs_with_offsets["offset_mapping"]
-
作用:
-
提取每个 token 的字符偏移信息。
-
-
结果:
-
offsets
是一个列表,表示每个 token 在原句中的起始和结束位置,例如: -
offsets = [(0, 2), (2, 5), (5, 7), (7, 9), (9, 10)]
-
5. while 循环的功能
idx = 0
while idx < len(predictions):
pred = predictions[idx]
label = id2label[pred]
if label != "O":
label = label[2:] # Remove the B- or I-
start, end = offsets[idx]
all_scores = [probabilities[idx][pred]]
while (
idx + 1 < len(predictions) and
id2label[predictions[idx + 1]] == f"I-{label}"
):
all_scores.append(probabilities[idx + 1][predictions[idx + 1]])
_, end = offsets[idx + 1]
idx += 1
score = np.mean(all_scores).item()
word = example['sentence'][start:end]
pred_label.append(
{
"entity_group": label,
"score": score,
"word": word,
"start": start,
"end": end,
}
)
idx += 1
1. pred = predictions[idx]
- 获取当前索引
idx
处的预测类别索引。
2. label = id2label[pred]
- 使用字典
id2label
将预测类别索引pred
转换为对应的标签(如"B-LOC"
、"O"
)。
3. label = label[2:]
- 移除标签前缀
"B-"
或"I-"
,保留实体类别(如"LOC"
、"PER"
)。
4. all_scores = [probabilities[idx][pred]]
- 初始化一个列表,存储当前 token 对预测类别的概率。
-
probabilities
:- 是一个二维列表或数组,存储每个 token 对所有类别的预测概率。
- 形状通常为
[sequence_length, num_labels]
,即:- 每一行表示一个 token。
- 每一列表示一个类别的概率。
probabilities[idx][pred]
的含义- 取出当前 token 对预测类别的概率:
probabilities[idx]
:- 获取当前 token 的概率分布。
- 例如,当
idx=1
时:- probabilities[idx] → [0.3, 0.4, 0.3]
probabilities[idx][pred]
:- 获取当前 token 对预测类别
pred
的概率。 - 例如,当
pred=1
时:- probabilities[idx][pred] → 0.4
- 获取当前 token 对预测类别
- 取出当前 token 对预测类别的概率:
5. 嵌套 while
循环
- 检查后续 token 是否属于同一实体:
- 如果后续 token 的标签是
"I-<实体类别>"
,则:- 将其概率添加到
all_scores
。 - 更新实体的结束位置为该 token 的结束索引。
- 移动
idx
到下一个 token。
- 将其概率添加到
- 如果后续 token 的标签是
6. score = np.mean(all_scores).item()
- 计算实体中所有 token 的平均概率,作为该实体的置信度。
7. word = example['sentence'][start:end]
- 提取实体对应的文本片段。
8. pred_label.append({...})
- 将当前实体的信息(类别、文本、起始位置、置信度等)存储为字典,添加到
pred_label
列表中。
循环的整体功能
-
处理非
"O"
类 token:- 识别所有非
"O"
类的实体 token,合并属于同一实体的连续 token。
- 识别所有非
-
计算实体信息:
- 记录实体的类别(去掉前缀)、文本、起止位置和置信度。
-
生成实体列表:
- 将所有识别出的实体信息存储到
pred_label
列表中。
- 将所有识别出的实体信息存储到
注意
这段代码的功能不是处理 "O"
类实体:
-
代码中的条件
if label != "O":
表示只处理非"O"
类实体。 -
"O"
类 token(即非实体部分)会被跳过。
主要目标是提取实体信息:
-
将所有标注为
"B-<类别>"
或"I-<类别>"
的 token 合并为完整的实体,并记录其起止位置。
存储结果
results.append(
{
"sentence": example['sentence'],
"pred_label": pred_label,
"true_label": example['labels']
}
)
with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
for exapmle_result in results:
f.write(json.dumps(exapmle_result, ensure_ascii=False) + '\n')
-
功能:
- 将每个句子的预测结果(包括原始句子、预测实体、真实标签)存储为字典。
- 将所有结果写入 JSON 文件
test_data_pred.json
。
-
意义:
- 将模型的预测结果保存为文件,方便后续分析、展示或评估。
example[]
解释
example[]
是一个字典的访问方式,而不是数组。example
通常是一个字典,包含当前样本的信息。例如:- example = {
"sentence": "日本外务省发布声明。",
"labels": [
{"entity_group": "ORG", "start": 0, "end": 5}
]
}
- example = {
example['sentence']
:- 访问字典中键为
"sentence"
的值,这里是句子"日本外务省发布声明。"
.
- 访问字典中键为
example['labels']
:- 访问字典中键为
"labels"
的值,这里是一个包含实体信息的列表。
- 访问字典中键为
with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
for example_result in results:
f.write(json.dumps(example_result, ensure_ascii=False) + '\n')
with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
with
:- 用于上下文管理,确保文件在操作结束后被正确关闭。
- 即使中途代码发生错误,文件也会自动关闭,避免资源泄漏。
-
为什么要用
with
?-
确保资源管理安全:
-
使用
with
可以确保文件在操作完成后自动关闭,即使中途发生错误。
-
-
等价代码:
-
f = open('test_data_pred.json', 'wt', encoding='utf-8')
try:
for example_result in results:
f.write(json.dumps(example_result, ensure_ascii=False) + '\n')
finally:
f.close()
-
-
open('test_data_pred.json', 'wt', encoding='utf-8')
:- 打开一个名为
test_data_pred.json
的文件,准备写入数据。 - 参数解析:
'wt'
:w
表示以写入模式打开文件。t
表示以文本模式打开文件。
encoding='utf-8'
:- 指定文件的编码方式为 UTF-8,支持写入中文等字符。
- 打开一个名为
as f
:- 将打开的文件对象赋值给变量
f
。 - 在
with
语句块内,f
就代表这个文件对象。
- 将打开的文件对象赋值给变量
作用:
- 打开文件,并确保在完成操作后自动关闭。
for example_result in results:
- 功能:
- 遍历
results
列表,将列表中的每个元素依次赋值给变量example_result
。
- 遍历
- 假设
results
的结构:results
是一个列表,每个元素是一个字典,包含句子及其预测结果:- results = [
{
"sentence": "日本外务省发布声明。",
"pred_label": [
{"entity_group": "ORG", "score": 0.99, "word": "日本外务省", "start": 0, "end": 5}
],
"true_label": [
{"entity_group": "ORG", "word": "日本外务省", "start": 0, "end": 5}
]
}
]
for
循环:
- 每次循环,
example_result
会获取results
中的一个字典元素。 - 例如:
- 第一次循环时:
- example_result = {
"sentence": "日本外务省发布声明。",
"pred_label": [
{"entity_group": "ORG", "score": 0.99, "word": "日本外务省", "start": 0, "end": 5}
],
"true_label": [
{"entity_group": "ORG", "word": "日本外务省", "start": 0, "end": 5}
]
}
-
作用:
- 遍历
results
,逐一处理每个预测结果。
- 遍历
f.write(json.dumps(example_result, ensure_ascii=False) + '\n')
- 功能:
- 将当前样本的预测结果写入文件。
- 分解解释:
json.dumps(example_result, ensure_ascii=False)
:- 将字典
example_result
转换为 JSON 格式的字符串。 ensure_ascii=False
:- 确保生成的 JSON 字符串可以直接显示非 ASCII 字符(如中文)。
- 将字典
+ '\n'
:- 在 JSON 字符串后添加换行符,确保每个样本占用文件中的一行。
f.write(...)
:- 将生成的 JSON 字符串写入文件。
作用:
- 将
results
列表中的每个样本结果写入文件,存储为 JSON 格式。
代码的核心意义
-
测试模型性能:
- 对测试集进行推理,评估模型的分类性能(Precision、Recall、F1-score)。
-
生成可视化的预测结果:
- 从 token 级别的预测结果整合为实体级别的结果,提供更直观和实用的输出。
-
结果存储和共享:
- 将预测结果与真实标签一起保存为 JSON 文件,便于后续分析和使用。