《KnowPrompt》论文代码复现4-generate_k_shot.py代码讲解(超级详细!)

先附上代码,代码注释中会有一点讲解,详细的讲解在代码下面

import argparse
import os
import numpy as np
import pandas as pd
from pandas import DataFrame
import os
import json
from collections import Counter, OrderedDict
import logging
logger = logging.getLogger(__name__) # 1


def get_labels(path, name,  negative_label="no_relation"): # 参数negative_label的默认值是no_relation
    """
        这个函数的作用:
        这个函数将打开数据集文件目录下的train.txt文件,这个文件用来训练模型
        文件中每一行都是用字符串表示的一个字典,字典中4个键值对,分别是token、头实体、尾实体、关系
        将这些字符串表示的字典转成python数据类型字典,然后存到列表feature中
        然后返回feature列表
    """
    count = Counter() # 2
    with open(path + "/" + name, "r") as f: # 3
        features = []
        for line in f.readlines(): # 4
            line = line.rstrip() # 移除字符串末尾的空白字符,包括空格、制表符和换行符等等
            if len(line) > 0:
                # count[line['relation']] += 1
                features.append(eval(line)) # 5

    # logger.info("label distribution as list: %d labels" % len(count))
    # # Make sure the negative label is alwyas 0
    # labels = []
    # for label, count in count.most_common():
    #     logger.info("%s: %d 个 %.2f%%" % (label, count,  count * 100.0 / len(dataset)))
    #     if label not in labels:
    #         labels.append(label)
    """
        注释掉的代码似乎是用来统计文件中各种标签出现的频率并打印
        相关信息。然而,这部分代码目前不参与实际的功能实现。可能因此被注释掉
    """
    return features

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--k", type=int, default=16,
        help="Training examples for each class.") # 每个类别训练k个样本
    # parser.add_argument("--task", type=str, nargs="+",
    #     default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'],
    #     help="Task names")
    parser.add_argument("--seed", type=int, nargs="+",
        default=[1, 2, 3, 4, 5],
        help="Random seeds") # nargs="+"表示seed这个参数允许接收多个值,这些值存在一个列表中

    # data_dir是存放所有数据集的那个目录的名字,比如dataset
    parser.add_argument("--data_dir", type=str, default="dataset/",
                        help="Path to original data")

    # dataset是我们要用的数据集的名字
    parser.add_argument("--dataset", type=str, default="semeval",
                        help="Path to original data")

    # data_file这个参数的值只能是train.txt或val.txt
    parser.add_argument("--data_file", type=str, default='train.txt',
                        choices=['train.txt', 'val.txt'], help="k-shot or k-shot-10x (10x dev set)")

    # k shot表示每轮训练中,每个类别用k个样本;10x表示用10*k个样本
    parser.add_argument("--mode", type=str, default='k-shot',
                        choices=['k-shot', 'k-shot-10x'], help="k-shot or k-shot-10x (10x dev set)")

    args = parser.parse_args()

    path = os.path.join(args.data_dir, args.dataset) # 默认值dataset/semeval
    output_dir = os.path.join(path, args.mode) # 默认值dataset/semeval/k-shot
    dataset = get_labels(path, args.data_file)

    for seed in args.seed:

        # Other datasets
        np.random.seed(seed) # 设置Numpy随机数生成器的种子
        np.random.shuffle(dataset) # 对列表dataset进行随机打乱

        # 设置输出路径
        k = args.k
        setting_dir = os.path.join(output_dir, f"{k}-{seed}") # 默认值dataset/semeval/k-shot/k-seed
        os.makedirs(setting_dir, exist_ok=True) # 6

        label_list = {}
        for line in dataset: # dataset:[{"token":[...],"h":[...],"e":[...],"relation":[...]},{...},...]
            label = line['relation'] # 取出键“relation”对应的值,relation的值只有1个

            if label not in label_list:
                label_list[label] = [line] # 7
            else:
                label_list[label].append(line)

        with open(os.path.join(setting_dir, "train.txt"), "w") as f:
            file_list = []

            # 8
            for label in label_list:
                for line in label_list[label][:k]:

                    f.writelines(json.dumps(line)) # 将line这个字典转换成json格式字符串,然后写入文件f中
                    # 不用writelines直接用write也可以,用write就可以省略下面那句代码了
                    f.write('\n')

            # 结束上面这个双重循环后,我们就得到每个标签对应的k个训练样本了

            f.close()


if __name__ == "__main__":
    main()

generate_k_shot.py文件概述

数据集semeval目录下的train.txt文件中存放着用来训练模型的数据,这个文件长这样:

每一行都由4部分组成:对一个句子进行分词后得到的token、头实体、尾实体、实体间的关系

然后把这个文件中,“关系”一样的数据提出来,用一个字典去存放,键是关系名,值是一个列表,列表中存放着train.txt文件中所有该关系下的数据

然后对于每一个关系,只取其前k条数据,用于训练,这些数据就存在semeval目录下的k-shot目录中,k就代表的k个shot,也就是一个标签下用k条数据

1、logger = logging.getLogger(_ _name_ _)

创建一个名为 __name__ 的日志记录器对象,通常用于记录程序的运行信息和错误信息。logger 对象是 Python 内置的 logging 模块提供的一种机制,用于控制日志的记录级别、格式和输出目标

2、count = Counter()

count = Counter() 表示创建了一个名为 count 的变量,并将其初始化为一个空的计数器对象。在这里,Counter 是 Python 标准库中的一个类,用于统计可迭代对象中元素的出现次数

具体来说,Counter 可以用来统计列表、元组、字符串等可迭代对象中每个元素的出现次数,并以字典的形式返回结果,其中键是元素,值是元素的出现次数

举例:

3、从下面的main()中看看这个path和name具体代表的是什么

data_dir这个命令行参数的默认值有问题,我们的目录名是“dataset”,这里的默认值是“../datasets”,用“../”都返回到根目录KnowPromptCode的上一级上去了,应该是不对的,改过来,就改成dataset就行

dataset这个命令行参数的默认值也要改一下,我们在生成标签词的时候用的semeval这个数据集,这里用的tacred这个数据集,改过来

再往下看可以看到“path = os.path.join(args.data_dir, args.dataset)”,于是path=dataset+semeval,即path是dataset/semeval

调用这个函数的时候使用的“dataset = get_labels(path, args.data_file)”这句代码,args.data_file的默认值是train.txt。所以这句代码将打开dataset/semeval/train.txt这个文件,这个文件太长了,分成两张图来展示

可以看到这个文件每一行都是一个字典,有4个键值对。键名分别是token、h、t、relation

token一般表示,文本中的基本单元,就是比如说有个句子是“i feel happy”,对这个句子进行分词后,得到“i”、“feel”、“happy”这个三个单词,每个单词就是一个token

h和t分别代表“头实体”(head entity)和“尾实体”(tail entity)的信息,包括名称和位置(在标记列表中的索引范围

relation表示头实体和尾实体之间关系的标签

4、for line in f.readlines():

f.readlines()读取文件中的所有行并返回一个包含这些行的列表。例如:

当我们使用readlines()这个方法读取文件的时候

得到的输出是

5、eval(line)

eval函数的参数是字符串,eval函数的作用就是执行这个字符串表示的python代码

比如line是字符串“5+3”,那执行函数将得到8

再比line是字符串“{‘name’:’Alice’ , ’age’:’30’ , ’city’:’NY’}”

那么执行eval函数后就会得到python中的“字典”这一数据类型

我们上面说了train.txt中的每一行都是一个表示字典的字符串,那么执行这个代码后就会得到字典这个python数据类型。然后把这个字典加入到features这个列表中

6、os.makedirs(setting_dir, exist_ok=True)

创建一个名为 setting_dir 的目录。如果该目录已经存在,它会继续执行而不引发错误;如果目录不存在,则创建该目录

7、label_list[label] = [line]

举例解释:假设有一个dataset长这样

line遍历到第一行时

line[‘relation’]=label=founder_of

label_list[label]=label_list[founder_of]=[line]

于是label_list变成

可以看到,label_list这个字典多出了一个键值对,键就是“founder_of”,值是一个列表,列表中的元素一个字典,这个字典来源于dataset,是dataset中的第一个字典

遍历完上例所示的整个dataset后,得到字典label_list

8、举例解释这个双重for循环

假设label_list如上图所示

“for label in label_list:”的第一次遍历得到“founder_of”

然后for line in label_list[label][:k]:

假设k=2,那就是

for line in label_list[founder_of][:2]:

label_list[founder_of]就是键founder_of的值,即一个列表,然后“[:2]”就是对这个列表做切片,取这个列表的0、1号元素,于是得到下面这个列表

然后每一次遍历得到的line就是上面这个列表中的一个字典

  • 30
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值