先附上代码,代码注释中会有一点讲解,详细的讲解在代码下面
import argparse
import os
import numpy as np
import pandas as pd
from pandas import DataFrame
import os
import json
from collections import Counter, OrderedDict
import logging
logger = logging.getLogger(__name__) # 1
def get_labels(path, name, negative_label="no_relation"): # 参数negative_label的默认值是no_relation
"""
这个函数的作用:
这个函数将打开数据集文件目录下的train.txt文件,这个文件用来训练模型
文件中每一行都是用字符串表示的一个字典,字典中4个键值对,分别是token、头实体、尾实体、关系
将这些字符串表示的字典转成python数据类型字典,然后存到列表feature中
然后返回feature列表
"""
count = Counter() # 2
with open(path + "/" + name, "r") as f: # 3
features = []
for line in f.readlines(): # 4
line = line.rstrip() # 移除字符串末尾的空白字符,包括空格、制表符和换行符等等
if len(line) > 0:
# count[line['relation']] += 1
features.append(eval(line)) # 5
# logger.info("label distribution as list: %d labels" % len(count))
# # Make sure the negative label is alwyas 0
# labels = []
# for label, count in count.most_common():
# logger.info("%s: %d 个 %.2f%%" % (label, count, count * 100.0 / len(dataset)))
# if label not in labels:
# labels.append(label)
"""
注释掉的代码似乎是用来统计文件中各种标签出现的频率并打印
相关信息。然而,这部分代码目前不参与实际的功能实现。可能因此被注释掉
"""
return features
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--k", type=int, default=16,
help="Training examples for each class.") # 每个类别训练k个样本
# parser.add_argument("--task", type=str, nargs="+",
# default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'],
# help="Task names")
parser.add_argument("--seed", type=int, nargs="+",
default=[1, 2, 3, 4, 5],
help="Random seeds") # nargs="+"表示seed这个参数允许接收多个值,这些值存在一个列表中
# data_dir是存放所有数据集的那个目录的名字,比如dataset
parser.add_argument("--data_dir", type=str, default="dataset/",
help="Path to original data")
# dataset是我们要用的数据集的名字
parser.add_argument("--dataset", type=str, default="semeval",
help="Path to original data")
# data_file这个参数的值只能是train.txt或val.txt
parser.add_argument("--data_file", type=str, default='train.txt',
choices=['train.txt', 'val.txt'], help="k-shot or k-shot-10x (10x dev set)")
# k shot表示每轮训练中,每个类别用k个样本;10x表示用10*k个样本
parser.add_argument("--mode", type=str, default='k-shot',
choices=['k-shot', 'k-shot-10x'], help="k-shot or k-shot-10x (10x dev set)")
args = parser.parse_args()
path = os.path.join(args.data_dir, args.dataset) # 默认值dataset/semeval
output_dir = os.path.join(path, args.mode) # 默认值dataset/semeval/k-shot
dataset = get_labels(path, args.data_file)
for seed in args.seed:
# Other datasets
np.random.seed(seed) # 设置Numpy随机数生成器的种子
np.random.shuffle(dataset) # 对列表dataset进行随机打乱
# 设置输出路径
k = args.k
setting_dir = os.path.join(output_dir, f"{k}-{seed}") # 默认值dataset/semeval/k-shot/k-seed
os.makedirs(setting_dir, exist_ok=True) # 6
label_list = {}
for line in dataset: # dataset:[{"token":[...],"h":[...],"e":[...],"relation":[...]},{...},...]
label = line['relation'] # 取出键“relation”对应的值,relation的值只有1个
if label not in label_list:
label_list[label] = [line] # 7
else:
label_list[label].append(line)
with open(os.path.join(setting_dir, "train.txt"), "w") as f:
file_list = []
# 8
for label in label_list:
for line in label_list[label][:k]:
f.writelines(json.dumps(line)) # 将line这个字典转换成json格式字符串,然后写入文件f中
# 不用writelines直接用write也可以,用write就可以省略下面那句代码了
f.write('\n')
# 结束上面这个双重循环后,我们就得到每个标签对应的k个训练样本了
f.close()
if __name__ == "__main__":
main()
generate_k_shot.py文件概述
数据集semeval目录下的train.txt文件中存放着用来训练模型的数据,这个文件长这样:
每一行都由4部分组成:对一个句子进行分词后得到的token、头实体、尾实体、实体间的关系
然后把这个文件中,“关系”一样的数据提出来,用一个字典去存放,键是关系名,值是一个列表,列表中存放着train.txt文件中所有该关系下的数据
然后对于每一个关系,只取其前k条数据,用于训练,这些数据就存在semeval目录下的k-shot目录中,k就代表的k个shot,也就是一个标签下用k条数据
1、logger = logging.getLogger(_ _name_ _)
创建一个名为 __name__ 的日志记录器对象,通常用于记录程序的运行信息和错误信息。logger 对象是 Python 内置的 logging 模块提供的一种机制,用于控制日志的记录级别、格式和输出目标
2、count = Counter()
count = Counter() 表示创建了一个名为 count 的变量,并将其初始化为一个空的计数器对象。在这里,Counter 是 Python 标准库中的一个类,用于统计可迭代对象中元素的出现次数
具体来说,Counter 可以用来统计列表、元组、字符串等可迭代对象中每个元素的出现次数,并以字典的形式返回结果,其中键是元素,值是元素的出现次数
举例:
3、从下面的main()中看看这个path和name具体代表的是什么
data_dir这个命令行参数的默认值有问题,我们的目录名是“dataset”,这里的默认值是“../datasets”,用“../”都返回到根目录KnowPromptCode的上一级上去了,应该是不对的,改过来,就改成dataset就行
dataset这个命令行参数的默认值也要改一下,我们在生成标签词的时候用的semeval这个数据集,这里用的tacred这个数据集,改过来
再往下看可以看到“path = os.path.join(args.data_dir, args.dataset)”,于是path=dataset+semeval,即path是dataset/semeval
调用这个函数的时候使用的“dataset = get_labels(path, args.data_file)”这句代码,args.data_file的默认值是train.txt。所以这句代码将打开dataset/semeval/train.txt这个文件,这个文件太长了,分成两张图来展示
可以看到这个文件每一行都是一个字典,有4个键值对。键名分别是token、h、t、relation
token一般表示,文本中的基本单元,就是比如说有个句子是“i feel happy”,对这个句子进行分词后,得到“i”、“feel”、“happy”这个三个单词,每个单词就是一个token
h和t分别代表“头实体”(head entity)和“尾实体”(tail entity)的信息,包括名称和位置(在标记列表中的索引范围
relation表示头实体和尾实体之间关系的标签
4、for line in f.readlines():
f.readlines()读取文件中的所有行并返回一个包含这些行的列表。例如:
当我们使用readlines()这个方法读取文件的时候
得到的输出是
5、eval(line)
eval函数的参数是字符串,eval函数的作用就是执行这个字符串表示的python代码
比如line是字符串“5+3”,那执行函数将得到8
再比line是字符串“{‘name’:’Alice’ , ’age’:’30’ , ’city’:’NY’}”
那么执行eval函数后就会得到python中的“字典”这一数据类型
我们上面说了train.txt中的每一行都是一个表示字典的字符串,那么执行这个代码后就会得到字典这个python数据类型。然后把这个字典加入到features这个列表中
6、os.makedirs(setting_dir, exist_ok=True)
创建一个名为 setting_dir 的目录。如果该目录已经存在,它会继续执行而不引发错误;如果目录不存在,则创建该目录
7、label_list[label] = [line]
举例解释:假设有一个dataset长这样
line遍历到第一行时
line[‘relation’]=label=founder_of
label_list[label]=label_list[founder_of]=[line]
于是label_list变成
可以看到,label_list这个字典多出了一个键值对,键就是“founder_of”,值是一个列表,列表中的元素一个字典,这个字典来源于dataset,是dataset中的第一个字典
遍历完上例所示的整个dataset后,得到字典label_list
8、举例解释这个双重for循环
假设label_list如上图所示
“for label in label_list:”的第一次遍历得到“founder_of”
然后for line in label_list[label][:k]:
假设k=2,那就是
for line in label_list[founder_of][:2]:
label_list[founder_of]就是键founder_of的值,即一个列表,然后“[:2]”就是对这个列表做切片,取这个列表的0、1号元素,于是得到下面这个列表
然后每一次遍历得到的line就是上面这个列表中的一个字典